├── CITATION.cff
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── balloon_learning_environment
├── __init__.py
├── acme_utils.py
├── agents
│ ├── __init__.py
│ ├── acme_eval_agent.py
│ ├── agent.py
│ ├── agent_registry.py
│ ├── agent_test.py
│ ├── configs
│ │ ├── __init__.py
│ │ ├── dqn.gin
│ │ ├── finetune_perciatelli.gin
│ │ ├── mlp.gin
│ │ └── quantile.gin
│ ├── dopamine_utils.py
│ ├── dopamine_utils_test.py
│ ├── dqn_agent.py
│ ├── dqn_agent_test.py
│ ├── exploration.py
│ ├── exploration_test.py
│ ├── marco_polo_exploration.py
│ ├── marco_polo_exploration_test.py
│ ├── mlp_agent.py
│ ├── mlp_agent_test.py
│ ├── networks.py
│ ├── networks_test.py
│ ├── perciatelli44.py
│ ├── perciatelli44_test.py
│ ├── quantile_agent.py
│ ├── quantile_agent_test.py
│ ├── random_walk_agent.py
│ ├── random_walk_agent_test.py
│ ├── station_seeker_agent.py
│ └── station_seeker_agent_test.py
├── colab
│ ├── BLE_Generative_Wind_Field.ipynb
│ ├── BLE_view_flight_paths.ipynb
│ └── summarize_eval.ipynb
├── distributed_train_acme_qrdqn.py
├── env
│ ├── __init__.py
│ ├── balloon
│ │ ├── __init__.py
│ │ ├── acs.py
│ │ ├── acs_test.py
│ │ ├── altitude_safety.py
│ │ ├── altitude_safety_test.py
│ │ ├── balloon.py
│ │ ├── balloon_test.py
│ │ ├── control.py
│ │ ├── envelope_safety.py
│ │ ├── envelope_safety_test.py
│ │ ├── power_safety.py
│ │ ├── power_safety_test.py
│ │ ├── power_table.py
│ │ ├── power_table_test.py
│ │ ├── pressure_range_builder.py
│ │ ├── pressure_range_builder_test.py
│ │ ├── solar.py
│ │ ├── solar_test.py
│ │ ├── stable_init.py
│ │ ├── stable_init_test.py
│ │ ├── standard_atmosphere.py
│ │ ├── standard_atmosphere_test.py
│ │ └── thermal.py
│ ├── balloon_arena.py
│ ├── balloon_arena_test.py
│ ├── balloon_env.py
│ ├── balloon_env_test.py
│ ├── features.py
│ ├── features_test.py
│ ├── generative_wind_field.py
│ ├── grid_based_wind_field.py
│ ├── grid_based_wind_field_test.py
│ ├── grid_wind_field_sampler.py
│ ├── gym.py
│ ├── rendering
│ │ ├── __init__.py
│ │ ├── matplotlib_renderer.py
│ │ └── renderer.py
│ ├── simplex_wind_noise.py
│ ├── simulator_data.py
│ ├── wind_field.py
│ ├── wind_field_test.py
│ ├── wind_gp.py
│ └── wind_gp_test.py
├── eval
│ ├── README.md
│ ├── __init__.py
│ ├── combine_eval_shards.py
│ ├── eval.py
│ ├── eval_lib.py
│ ├── eval_lib_test.py
│ ├── strata_seeds.py
│ ├── suites.py
│ └── suites_test.py
├── generated
│ ├── multi_balloon.mp4
│ └── wind_field.mp4
├── generative
│ ├── __init__.py
│ ├── dataset_wind_field_reservoir.py
│ ├── dataset_wind_field_reservoir_test.py
│ ├── learn_wind_field_generator.py
│ ├── vae.py
│ ├── vae_test.py
│ └── wind_field_reservoir.py
├── metrics
│ ├── __init__.py
│ ├── collector.py
│ ├── collector_dispatcher.py
│ ├── collector_dispatcher_test.py
│ ├── collector_test.py
│ ├── console_collector.py
│ ├── console_collector_test.py
│ ├── pickle_collector.py
│ ├── pickle_collector_test.py
│ ├── statistics_instance.py
│ ├── tensorboard_collector.py
│ └── tensorboard_collector_test.py
├── models
│ ├── __init__.py
│ ├── models.py
│ ├── models_test.py
│ ├── offlineskies22_decoder.msgpack
│ └── perciatelli44.pb
├── train.py
├── train_acme_qrdqn.py
├── train_lib.py
├── train_lib_test.py
└── utils
│ ├── __init__.py
│ ├── constants.py
│ ├── run_helpers.py
│ ├── run_helpers_test.py
│ ├── sampling.py
│ ├── sampling_test.py
│ ├── spherical_geometry.py
│ ├── spherical_geometry_test.py
│ ├── test_helpers.py
│ ├── transforms.py
│ ├── transforms_test.py
│ ├── units.py
│ ├── units_test.py
│ ├── wind.py
│ └── wind_test.py
├── docs
├── Makefile
├── README.md
├── about.rst
├── benchmarks.rst
├── changelist.rst
├── conf.py
├── environment.rst
├── getting_started.rst
├── imgs
│ ├── balloon_schematic.jpg
│ ├── ble_logo.png
│ ├── ble_logo_small.png
│ ├── reward_function.png
│ ├── station_keeping.gif
│ ├── training_curve.jpg
│ └── wind_field.gif
├── index.rst
├── make.bat
├── new_agent.rst
├── requirements.txt
└── src
│ ├── agents.rst
│ ├── agents
│ ├── agent.rst
│ ├── dqn_agent.rst
│ ├── perciatelli44.rst
│ ├── quantile_agent.rst
│ └── station_seeker_agent.rst
│ ├── balloon_env.rst
│ ├── env.rst
│ ├── eval_lib.rst
│ ├── features.rst
│ ├── metrics.rst
│ ├── metrics
│ ├── collector.rst
│ ├── collector_dispatcher.rst
│ ├── console_collector.rst
│ ├── pickle_collector.rst
│ ├── statistics_instance.rst
│ └── tensorboard_collector.rst
│ └── train_lib.rst
├── pyproject.toml
├── requirements.txt
├── setup.py
└── style_guidelines.md
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Greaves"
5 | given-names: "Joshua"
6 | - family-names: "Candido"
7 | given-names: "Salvatore"
8 | - family-names: "Dumoulin"
9 | given-names: "Vincent"
10 | - family-names: "Goroshin"
11 | given-names: "Ross"
12 | - family-names: "Ponda"
13 | given-names: "Sameera S."
14 | - family-names: "Bellemare"
15 | given-names: "Marc G."
16 | - family-names: "Castro"
17 | given-names: "Pablo Samuel"
18 | title: "Balloon Learning Environment"
19 | version: 1.0.0
20 | date-released: 2021-12-06
21 | url: "https://github.com/google/balloon-learning-environment"
22 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Balloon Learning Environment
2 | [Docs][docs]
3 |
4 |
5 |

7 |
8 |
9 |
10 |
11 | The Balloon Learning Environment (BLE) is a simulator for stratospheric
12 | balloons. It is designed as a benchmark environment for deep reinforcement
13 | learning algorithms, and is a followup to the Nature paper
14 | ["Autonomous navigation of stratospheric balloons using reinforcement learning"](https://www.nature.com/articles/s41586-020-2939-8).
15 |
16 | ## Getting Started
17 |
18 | Note: The BLE requires Python >= 3.7.
19 |
20 | The BLE can easily be installed with pip:
21 |
22 | ```
23 | $ pip install --upgrade pip
24 | $ pip install balloon_learning_environment
25 | ```
26 |
27 | To install with the `acme` package:
28 |
29 | ```
30 | $ pip install --upgrade pip
31 | $ pip install balloon_learning_environment[acme]
32 | ```
33 |
34 | Once the package has been installed, you can test that it runs correctly by
35 | evaluating one of the benchmark agents:
36 |
37 | ```
38 | python -m balloon_learning_environment.eval.eval \
39 | --agent=station_seeker \
40 | --renderer=matplotlib \
41 | --suite=micro_eval \
42 | --output_dir=/tmp/ble/eval
43 | ```
44 |
45 | To install from GitHub directly, run the following commands from the root
46 | directory where you cloned the repository:
47 |
48 | ```
49 | $ pip install --upgrade pip
50 | $ pip install .[acme]
51 | ```
52 |
53 | ## Ensure the BLE is Using Your GPU/TPU
54 |
55 | The BLE contains a VAE for generating winds, which you will probably want
56 | to run on your accelerator. See the jax documentation for installing with
57 | [GPU](https://github.com/google/jax#pip-installation-gpu-cuda) or
58 | [TPU](https://github.com/google/jax#pip-installation-google-cloud-tpu).
59 |
60 | As a sanity check, you can open interactive python and run:
61 |
62 | ```
63 | from balloon_learning_environment.env import balloon_env
64 | env = balloon_env.BalloonEnv()
65 | ```
66 |
67 | If you are not running with GPU/TPU, you should see a log like:
68 |
69 | ```
70 | WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
71 | ```
72 |
73 | If you don't see this log, you should be good to go!
74 |
75 | ## Next Steps
76 |
77 | For more information, see the [docs][docs].
78 |
79 | ## Giving credit
80 |
81 | If you use the Balloon Learning Environment in your work, we ask that you use
82 | the following BibTeX entry:
83 |
84 | ```
85 | @software{Greaves_Balloon_Learning_Environment_2021,
86 | author = {Greaves, Joshua and Candido, Salvatore and Dumoulin, Vincent and Goroshin, Ross and Ponda, Sameera S. and Bellemare, Marc G. and Castro, Pablo Samuel},
87 | month = {12},
88 | title = {{Balloon Learning Environment}},
89 | url = {https://github.com/google/balloon-learning-environment},
90 | version = {1.0.0},
91 | year = {2021}
92 | }
93 | ```
94 |
95 | If you use the `ble_wind_field` dataset, you should also cite:
96 |
97 | ```
98 | Hersbach, H., Bell, B., Berrisford, P., Hirahara, S., Horányi, A.,
99 | Muñoz‐Sabater, J., Nicolas, J., Peubey, C., Radu, R., Schepers, D., Simmons, A.,
100 | Soci, C., Abdalla, S., Abellan, X., Balsamo, G., Bechtold, P., Biavati, G.,
101 | Bidlot, J., Bonavita, M., De Chiara, G., Dahlgren, P., Dee, D., Diamantakis, M.,
102 | Dragani, R., Flemming, J., Forbes, R., Fuentes, M., Geer, A., Haimberger, L.,
103 | Healy, S., Hogan, R.J., Hólm, E., Janisková, M., Keeley, S., Laloyaux, P.,
104 | Lopez, P., Lupu, C., Radnoti, G., de Rosnay, P., Rozum, I., Vamborg, F.,
105 | Villaume, S., Thépaut, J-N. (2017): Complete ERA5: Fifth generation of ECMWF
106 | atmospheric reanalyses of the global climate. Copernicus Climate Change Service
107 | (C3S) Data Store (CDS). (Accessed on 01-04-2021)
108 | ```
109 |
110 |
111 | [docs]: https://balloon-learning-environment.readthedocs.io/en/latest/
112 |
--------------------------------------------------------------------------------
/balloon_learning_environment/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Balloon Learning Environment root package exports."""
17 | from balloon_learning_environment.env import gym as _ble_gym
18 |
# Register Gym environment. Users can import the environment as:
# gym.make('balloon_learning_environment:BalloonLearningEnvironment-v0')
# NOTE: this call runs at import time, so importing the root package is
# what triggers the registration.
_ble_gym.register_env()

# We don't export anything: the gym helper above is imported under a private
# alias and `__all__` is left empty, so `from balloon_learning_environment
# import *` brings no names into the caller's namespace.
__all__ = []
25 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/agent_registry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """The registry of agents.
17 |
18 | This is where you add new agents; we provide some examples to get you started.
19 |
20 | When writing a new agent, follow the API specified by the base class
21 | `agent.Agent` and implement the abstract methods.
22 | The provided agents include (see `REGISTRY` below for the full list):
23 | RandomAgent: Ignores all observations and picks actions randomly.
24 | MLPAgent: Uses a simple multi-layer perceptron (MLP) to learn the mapping of
25 | states to Q-values. The number of layers and hidden units in the MLP is
26 | configurable.
27 | """
28 |
29 | from typing import Callable, Optional
30 |
31 | from balloon_learning_environment.agents import agent
32 | from balloon_learning_environment.agents import dqn_agent
33 | from balloon_learning_environment.agents import mlp_agent
34 | from balloon_learning_environment.agents import perciatelli44
35 | from balloon_learning_environment.agents import quantile_agent
36 | from balloon_learning_environment.agents import random_walk_agent
37 | from balloon_learning_environment.agents import station_seeker_agent
38 |
# Directory holding the default gin configs. This is a relative path.
# NOTE(review): presumably resolved against the repository root by whoever
# consumes the config path -- confirm against the callers.
BASE_DIR = 'balloon_learning_environment/agents/configs'

# Maps an agent name to a `(constructor, default_gin_config)` pair. The second
# element is the path of the agent's default gin file, or None when the agent
# has no default config.
REGISTRY = {
    'random': (agent.RandomAgent, None),
    'mlp': (mlp_agent.MLPAgent, f'{BASE_DIR}/mlp.gin'),
    'dqn': (dqn_agent.DQNAgent, f'{BASE_DIR}/dqn.gin'),
    'perciatelli44': (perciatelli44.Perciatelli44, None),
    'quantile': (quantile_agent.QuantileAgent, f'{BASE_DIR}/quantile.gin'),
    # Same agent class as 'quantile', but with a different default config.
    'finetune_perciatelli': (quantile_agent.QuantileAgent,
                             f'{BASE_DIR}/finetune_perciatelli.gin'),
    'station_seeker': (station_seeker_agent.StationSeekerAgent, None),
    'random_walk': (random_walk_agent.RandomWalkAgent, None),
}

# The acme agent is optional: it is only added to the registry when its module
# can be imported.
try:
  from balloon_learning_environment.agents import acme_eval_agent  # pylint: disable=g-import-not-at-top
  REGISTRY['acme_eval_agent'] = (acme_eval_agent.AcmeEvalAgent, None)
except ModuleNotFoundError:
  # This is most likely because acme dependencies aren't installed.
  pass
# Names of agents that need the acme dependencies. Defined unconditionally so
# that lookups can tell "acme agent that failed to import" apart from an
# unknown name.
_ACME_AGENTS = frozenset({'acme_eval_agent'})
59 |
60 |
61 |
def agent_constructor(name: str) -> Callable[..., agent.Agent]:
  """Looks up the constructor registered for the agent called `name`.

  Args:
    name: Key of the agent in `REGISTRY`.

  Returns:
    The first element of the agent's registry entry, i.e. its constructor.

  Raises:
    ValueError: If `name` is not in `REGISTRY`. The message suggests
      installing the acme dependencies when `name` is a known acme agent.
  """
  if name in REGISTRY:
    entry = REGISTRY[name]
    return entry[0]

  # Known acme agents are only missing when their optional import failed.
  if name in _ACME_AGENTS:
    raise ValueError(f'Agent {name} not available. '
                     'Have you tried installing the acme dependencies?')
  raise ValueError(f'Agent {name} not recognized')
70 |
71 |
def get_default_gin_config(name: str) -> Optional[str]:
  """Returns the default gin config path registered for the agent `name`.

  Args:
    name: Key of the agent in `REGISTRY`.

  Returns:
    The second element of the agent's registry entry: the path of its default
    gin file, or None if the agent was registered without one.

  Raises:
    ValueError: If `name` is not in `REGISTRY`. As in `agent_constructor`, the
      message suggests installing the acme dependencies when `name` is a known
      acme agent whose optional import failed.
  """
  if name not in REGISTRY:
    # Give the same hint as `agent_constructor` rather than claiming that a
    # known (but unavailable) acme agent is unrecognized.
    if name in _ACME_AGENTS:
      raise ValueError(f'Agent {name} not available. '
                       'Have you tried installing the acme dependencies?')
    raise ValueError(f'Agent {name} not recognized')
  return REGISTRY[name][1]
76 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/agent_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.agents.agent."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.agents import agent
20 | import numpy as np
21 |
22 |
class AgentTest(absltest.TestCase):
  """Checks what a minimal concrete subclass of `agent.Agent` inherits."""

  def setUp(self):
    super().setUp()
    self._observation_shape = (3, 4)
    self._na = 5

  def test_valid_subclass(self):
    """A subclass implementing the abstract methods can be built and queried."""

    # Smallest possible agent: implements the abstract methods and always
    # selects action 0.
    class SimpleAgent(agent.Agent):

      def begin_episode(self, unused_obs: None) -> int:
        return 0

      def step(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
          self, reward: float, observation: None) -> int:
        return 0

      def end_episode(self, reward: float, terminal: bool) -> None:
        pass

    subject = SimpleAgent(self._na, self._observation_shape)

    # The name is derived from the subclass, and the constructor arguments are
    # stored on the instance.
    self.assertEqual('SimpleAgent', subject.get_name())
    self.assertEqual(self._na, subject._num_actions)
    self.assertEqual(self._observation_shape, subject._observation_shape)
    # The inherited checkpoint reload reports -1 here.
    self.assertEqual(subject.reload_latest_checkpoint(''), -1)
50 |
51 |
class RandomAgentTest(absltest.TestCase):
  """Exercises construction and action selection of `agent.RandomAgent`."""

  def setUp(self):
    super().setUp()
    self._na = 5
    self._observation_shape = (3, 4)
    # An all-zero observation, reused for every call in these tests.
    self._observation = np.zeros(self._observation_shape, dtype=np.float32)

  def test_create_agent(self):
    """The agent reports its name and stores its constructor arguments."""
    subject = agent.RandomAgent(self._na, self._observation_shape)
    self.assertEqual('RandomAgent', subject.get_name())
    self.assertEqual(self._na, subject._num_actions)
    self.assertEqual(self._observation_shape, subject._observation_shape)

  def test_action_selection(self):
    """Every action returned over several episodes is a valid action index."""
    num_episodes = 10
    steps_per_episode = 20
    subject = agent.RandomAgent(self._na, self._observation_shape)

    for _ in range(num_episodes):
      first_action = subject.begin_episode(self._observation)
      self.assertGreaterEqual(first_action, 0)
      self.assertLess(first_action, self._na)

      for _ in range(steps_per_episode):
        next_action = subject.step(0.0, self._observation)
        self.assertIn(next_action, range(self._na))

      subject.end_episode(0.0, True)
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
80 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/configs/dqn.gin:
--------------------------------------------------------------------------------
1 | # Hyperparameters for a simple DQN-style agent.
2 | import dopamine.jax.agents.dqn.dqn_agent
3 | import balloon_learning_environment.agents.dqn_agent
4 | import balloon_learning_environment.agents.networks
5 | import dopamine.replay_memory.circular_replay_buffer
6 |
7 | balloon_learning_environment.agents.dqn_agent.DQNAgent.network = @networks.MLPNetwork
8 | balloon_learning_environment.agents.dqn_agent.DQNAgent.checkpoint_duration = 5
9 | networks.MLPNetwork.num_layers = 8
10 | networks.MLPNetwork.hidden_units = 600
11 | JaxDQNAgent.gamma = 0.993
12 | JaxDQNAgent.update_horizon = 5
13 | JaxDQNAgent.min_replay_history = 500
14 | JaxDQNAgent.update_period = 4
15 | JaxDQNAgent.target_update_period = 100
16 | JaxDQNAgent.epsilon_fn = @dqn_agent.identity_epsilon
17 | JaxDQNAgent.epsilon_train = 0.01
18 | JaxDQNAgent.epsilon_eval = 0.0
19 | JaxDQNAgent.optimizer = 'adam'
20 | JaxDQNAgent.loss_type = 'mse' # MSE works better with Adam.
21 | JaxDQNAgent.summary_writing_frequency = 1
22 | JaxDQNAgent.allow_partial_reload = True
23 | dqn_agent.create_optimizer.learning_rate = 2e-6
24 | dqn_agent.create_optimizer.eps = 0.00002
25 |
26 | OutOfGraphReplayBuffer.replay_capacity = 2000000
27 | OutOfGraphReplayBuffer.batch_size = 32
28 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/configs/finetune_perciatelli.gin:
--------------------------------------------------------------------------------
1 | # Hyperparameters for a Quantile-style agent that uses MarcoPolo exploration.
2 | # This is the same configuration used in our Nature paper, and further
3 | # starts from the trained parameters used in the Nature paper.
4 | import dopamine.jax.agents.dqn.dqn_agent
5 | import dopamine.jax.agents.quantile.quantile_agent
6 | import balloon_learning_environment.agents.marco_polo_exploration
7 | import balloon_learning_environment.agents.quantile_agent
8 | import balloon_learning_environment.agents.networks
9 | import balloon_learning_environment.agents.random_walk_agent
10 | import dopamine.replay_memory.prioritized_replay_buffer
11 |
12 | QuantileAgent.network = @agents.networks.QuantileNetwork
13 | QuantileAgent.exploration_wrapper_constructor = @marco_polo_exploration.MarcoPoloExploration
14 | QuantileAgent.reload_perciatelli = True
15 | QuantileAgent.checkpoint_duration = 5
16 | MarcoPoloExploration.exploratory_episode_probability = 0.8
17 | MarcoPoloExploration.exploratory_agent_constructor = @random_walk_agent.RandomWalkAgent
18 | agents.networks.QuantileNetwork.num_layers = 8
19 | agents.networks.QuantileNetwork.hidden_units = 600
20 | JaxQuantileAgent.gamma = 0.993
21 | JaxQuantileAgent.update_horizon = 5
22 | JaxQuantileAgent.min_replay_history = 500
23 | JaxQuantileAgent.update_period = 4
24 | JaxQuantileAgent.target_update_period = 100
25 | JaxQuantileAgent.epsilon_train = 0.0
26 | JaxQuantileAgent.epsilon_eval = 0.0
27 | JaxQuantileAgent.epsilon_fn = @dqn_agent.identity_epsilon
28 | JaxQuantileAgent.optimizer = 'adam'
29 | JaxQuantileAgent.summary_writing_frequency = 1
30 | JaxQuantileAgent.num_atoms = 51
31 | JaxQuantileAgent.allow_partial_reload = True
32 | dqn_agent.create_optimizer.learning_rate = 2e-6
33 | dqn_agent.create_optimizer.eps = 0.00002
34 |
35 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 2000000
36 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32
37 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/configs/mlp.gin:
--------------------------------------------------------------------------------
1 | import balloon_learning_environment.agents.networks
2 |
3 | networks.MLPNetwork.num_layers = 1
4 | networks.MLPNetwork.hidden_units = 256
5 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/configs/quantile.gin:
--------------------------------------------------------------------------------
1 | # Hyperparameters for a Quantile-style agent that uses MarcoPolo exploration.
2 | # This is the same configuration used in our Nature paper.
3 | import dopamine.jax.agents.dqn.dqn_agent
4 | import dopamine.jax.agents.quantile.quantile_agent
5 | import balloon_learning_environment.agents.marco_polo_exploration
6 | import balloon_learning_environment.agents.quantile_agent
7 | import balloon_learning_environment.agents.networks
8 | import balloon_learning_environment.agents.random_walk_agent
9 | import dopamine.replay_memory.prioritized_replay_buffer
10 |
11 | QuantileAgent.network = @agents.networks.QuantileNetwork
12 | QuantileAgent.exploration_wrapper_constructor = @marco_polo_exploration.MarcoPoloExploration
13 | QuantileAgent.reload_perciatelli = False
14 | QuantileAgent.checkpoint_duration = 5
15 | MarcoPoloExploration.exploratory_episode_probability = 0.8
16 | MarcoPoloExploration.exploratory_agent_constructor = @random_walk_agent.RandomWalkAgent
17 | agents.networks.QuantileNetwork.num_layers = 8
18 | agents.networks.QuantileNetwork.hidden_units = 600
19 | JaxQuantileAgent.gamma = 0.993
20 | JaxQuantileAgent.update_horizon = 5
21 | JaxQuantileAgent.min_replay_history = 500
22 | JaxQuantileAgent.update_period = 4
23 | JaxQuantileAgent.target_update_period = 100
24 | JaxQuantileAgent.epsilon_train = 0.0
25 | JaxQuantileAgent.epsilon_eval = 0.0
26 | JaxQuantileAgent.epsilon_fn = @dqn_agent.identity_epsilon
27 | JaxQuantileAgent.optimizer = 'adam'
28 | JaxQuantileAgent.summary_writing_frequency = 1
29 | JaxQuantileAgent.num_atoms = 51
30 | JaxQuantileAgent.allow_partial_reload = True
31 | dqn_agent.create_optimizer.learning_rate = 2e-6
32 | dqn_agent.create_optimizer.eps = 0.00002
33 |
34 | OutOfGraphPrioritizedReplayBuffer.replay_capacity = 2000000
35 | OutOfGraphPrioritizedReplayBuffer.batch_size = 32
36 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/dopamine_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Common utilities for Dopamine-based agents."""
17 |
18 | import os.path as osp
19 | import pathlib
20 | import pickle
21 | from typing import Any, Dict, Callable
22 |
23 | from absl import logging
24 | import tensorflow as tf
25 |
26 |
def _make_checkpoint_filename(checkpoint_dir: str,
                              iteration_number: int) -> str:
  """Builds the pickle path for `iteration_number` inside `checkpoint_dir`.

  The iteration is zero-padded to a minimum of five digits, giving names such
  as `checkpoint_00042.pkl`.
  """
  basename = 'checkpoint_{:05d}.pkl'.format(iteration_number)
  return osp.join(checkpoint_dir, basename)
30 |
31 |
def _is_checkpoint_name(file_name: str) -> bool:
  """Returns True if `file_name` has the form `checkpoint_<iteration>.pkl`.

  `<iteration>` must consist solely of ASCII digits and be at least five
  characters long. Checkpoint names are written with `{iteration:05d}`, which
  zero-pads to five digits but produces longer names once the iteration
  reaches 100000, so longer digit runs are accepted too. Requiring digits
  guarantees that the iteration can subsequently be parsed with `int()`.

  Args:
    file_name: A file name or path; only its final component is inspected.

  Returns:
    Whether the name matches the checkpoint naming scheme.
  """
  path = pathlib.Path(file_name)
  if path.suffix != '.pkl':
    return False

  # Any further underscore ends up in `iteration` and fails the digit check,
  # so names with more than two underscore-separated parts are rejected.
  prefix, _, iteration = path.stem.partition('_')
  if prefix != 'checkpoint':
    return False

  return len(iteration) >= 5 and iteration.isascii() and iteration.isdigit()
42 |
43 |
def _get_checkpoint_iteration(checkpoint_name: str) -> int:
  """Parses the iteration number out of a `checkpoint_<iteration>.pkl` name."""
  stem = pathlib.Path(checkpoint_name).stem
  # The iteration is the second underscore-separated token of the stem.
  iteration_token = stem.split('_')[1]
  return int(iteration_token)
48 |
49 |
def save_checkpoint(checkpoint_dir: str,
                    iteration_number: int,
                    bundle_fn: Callable[[str, int], Any]) -> None:
  """Saves a checkpoint using the provided bundling function.

  Args:
    checkpoint_dir: Directory to write the checkpoint into; creation is
      attempted if it does not exist.
    iteration_number: Iteration the checkpoint belongs to; it determines the
      file name.
    bundle_fn: Called with `(checkpoint_dir, iteration_number)`. Its return
      value is pickled to disk; if it returns None nothing is written and a
      warning is logged.
  """
  # Try to create checkpoint directory if it doesn't exist.
  try:
    tf.io.gfile.makedirs(checkpoint_dir)
  except tf.errors.PermissionDeniedError:
    # Best effort: a denied `makedirs` is ignored here. If the directory is
    # genuinely unusable, the write below will surface the problem.
    pass

  bundle = bundle_fn(checkpoint_dir, iteration_number)
  if bundle is None:
    logging.warning('Unable to checkpoint to %s at iteration %d.',
                    checkpoint_dir, iteration_number)
    return

  filename = _make_checkpoint_filename(checkpoint_dir, iteration_number)
  # Pickle produces bytes, so open the file in binary mode. This mirrors the
  # 'rb' mode used when the checkpoint is loaded.
  with tf.io.gfile.GFile(filename, 'wb') as fout:
    pickle.dump(bundle, fout)
70 |
71 |
def load_checkpoint(
    checkpoint_dir: str,
    iteration_number: int,
    unbundle_fn: Callable[[str, int, Dict[Any, Any]], bool]) -> None:
  """Restores the checkpoint saved at `iteration_number`, if it exists.

  Args:
    checkpoint_dir: Directory the checkpoint was written to.
    iteration_number: Iteration of the checkpoint to load.
    unbundle_fn: Called with `(checkpoint_dir, iteration_number, bundle)`
      where `bundle` is the unpickled file content. A falsy return value is
      reported with a warning.
  """
  checkpoint_path = _make_checkpoint_filename(checkpoint_dir, iteration_number)

  if tf.io.gfile.exists(checkpoint_path):
    with tf.io.gfile.GFile(checkpoint_path, 'rb') as checkpoint_file:
      bundle = pickle.load(checkpoint_file)
    restored = unbundle_fn(checkpoint_dir, iteration_number, bundle)
    if not restored:
      logging.warning('Call to parent `unbundle` failed.')
  else:
    # A missing file is only logged; nothing is raised.
    logging.warning('Unable to restore bundle from %s', checkpoint_path)
86 |
87 |
def get_latest_checkpoint(checkpoint_dir: str) -> int:
  """Finds the episode ID of the latest checkpoint, if any.

  Files matching `checkpoint_*.pkl` whose `*` part is not an integer are
  skipped with a warning, so that a stray file cannot hide valid checkpoints.

  Args:
    checkpoint_dir: Directory to search for checkpoint files.

  Returns:
    The largest episode number found among the checkpoint files, or -1 if the
    directory cannot be searched or holds no parseable checkpoint.
  """
  prefix = 'checkpoint_'
  suffix = '.pkl'
  pattern = osp.join(checkpoint_dir, f'{prefix}*{suffix}')

  try:
    checkpoint_files = tf.io.gfile.glob(pattern)
  except tf.errors.NotFoundError:
    logging.warning('Unable to reload checkpoint at %s', checkpoint_dir)
    return -1

  episodes = []
  for file_name in checkpoint_files:
    # The episode is whatever sits between the last 'checkpoint_' and the
    # '.pkl' extension.
    start = file_name.rfind(prefix) + len(prefix)
    episode_text = file_name[start:-len(suffix)]
    try:
      episodes.append(int(episode_text))
    except ValueError:
      logging.warning('Ignoring file with unparseable checkpoint name: %s',
                      file_name)

  return max(episodes, default=-1)
105 |
106 |
def clean_up_old_checkpoints(checkpoint_dir: str,
                             episode_number: int,
                             checkpoint_duration: int = 50) -> None:
  """Deletes every checkpoint that has outlived `checkpoint_duration`.

  A checkpoint is stale when its iteration is strictly smaller than
  `episode_number - checkpoint_duration`. Directory entries that do not look
  like checkpoint files are left untouched.

  Args:
    checkpoint_dir: Directory where checkpoints are stored.
    episode_number: Current episode number.
    checkpoint_duration: How long (in terms of episodes) a checkpoint should
      last.
  """
  oldest_kept_iteration = episode_number - checkpoint_duration

  for entry in tf.io.gfile.listdir(checkpoint_dir):
    if not _is_checkpoint_name(entry):
      continue
    if _get_checkpoint_iteration(entry) < oldest_kept_iteration:
      stale_path = osp.join(checkpoint_dir, entry)
      tf.io.gfile.remove(stale_path)
123 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/exploration.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Base class for exploration modules which can be used by the agents.
17 |
18 | An Agent can wrap an Exploration module during its action selection.
19 | So rather than simply issuing
20 | `return action`
21 | It can issue:
22 | `return self.exploration_wrapper(action)`
23 | """
24 |
25 | from typing import Sequence
26 | import numpy as np
27 |
28 |
class Exploration:
  """Pass-through exploration module.

  Both hooks hand back exactly the action they were given, so wrapping an
  agent's action selection with this class changes nothing.
  """

  def __init__(self, unused_num_actions: int,
               unused_observation_shape: Sequence[int]):
    """Accepts the standard constructor arguments; neither one is used."""

  def begin_episode(self, observation: np.ndarray, a: int) -> int:
    """Echoes `a`, the action chosen by the agent, at the start of an episode.

    Args:
      observation: First observation of the episode (ignored).
      a: Action selected by the wrapping agent.

    Returns:
      `a`, unchanged.
    """
    del observation  # Unused.
    return a

  def step(self, reward: float, observation: np.ndarray, a: int) -> int:
    """Echoes `a`, the action chosen by the agent, on a regular step.

    Args:
      reward: Latest reward (ignored).
      observation: Latest observation (ignored).
      a: Action selected by the wrapping agent.

    Returns:
      `a`, unchanged.
    """
    del reward, observation  # Unused.
    return a
46 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/exploration_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.agents.exploration."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.agents import exploration
20 | import numpy as np
21 |
22 |
class ExplorationTest(absltest.TestCase):
  """Checks that the base Exploration module is a pure pass-through."""

  def test_exploration_class(self):
    num_actions = 5
    observation_shape = (3, 4)
    echo = exploration.Exploration(num_actions, observation_shape)

    def random_observation():
      return np.random.rand(*observation_shape)

    # Whatever reward or observation is supplied, the action handed in must
    # come straight back out.
    for episode_action in range(5):
      returned = echo.begin_episode(random_observation(), episode_action)
      self.assertEqual(episode_action, returned)
      for step_action in range(10):
        returned = echo.step(
            np.random.rand(), random_observation(), step_action)
        self.assertEqual(step_action, returned)


if __name__ == '__main__':
  absltest.main()
41 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/marco_polo_exploration.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Exploration strategy used in the Nature paper.
17 |
18 | Specifically, it interleaves between two phases:
19 | - RL phase (where the parent agent's actions are maintained).
20 | - Exploration phase (where a second agent picks actions).
21 | """
22 |
23 | import datetime as dt
24 | import time
25 | from typing import Callable, Optional, Sequence
26 |
27 | from balloon_learning_environment.agents import agent
28 | from balloon_learning_environment.agents import exploration
29 | from balloon_learning_environment.utils import constants
30 | import gin
31 | import jax
32 | import numpy as np
33 |
34 |
# Duration of each phase within an exploratory episode.
_RL_PHASE_LENGTH = dt.timedelta(hours=4)
_EXPLORATORY_PHASE_LENGTH = dt.timedelta(hours=2)


@gin.configurable
class MarcoPoloExploration(exploration.Exploration):
  """Exploration strategy used in the Nature paper.

  At the start of each episode a coin is flipped to decide whether the episode
  is exploratory. Exploratory episodes alternate between an RL phase (the
  parent agent's action is kept) and an exploration phase (the action comes
  from a second, exploratory agent). Non-exploratory episodes always keep the
  parent agent's action.
  """

  def __init__(self, num_actions: int, observation_shape: Sequence[int],
               exploratory_episode_probability: float = gin.REQUIRED,
               exploratory_agent_constructor: Callable[
                   [int, Sequence[int]], agent.Agent] = gin.REQUIRED,
               seed: Optional[int] = None):
    """Builds the exploratory agent and initializes the phase bookkeeping.

    Args:
      num_actions: Number of actions, forwarded to the exploratory agent.
      observation_shape: Observation shape, forwarded to the exploratory agent.
      exploratory_episode_probability: Probability that an episode is
        exploratory.
      exploratory_agent_constructor: Callable building the exploratory agent
        from `(num_actions, observation_shape)`.
      seed: Optional seed for the PRNG. When None, a seed is derived from the
        current time.
    """
    super().__init__(num_actions, observation_shape)
    self._exploratory_agent = exploratory_agent_constructor(
        num_actions, observation_shape)
    self._exploratory_episode_probability = exploratory_episode_probability
    self._exploratory_episode = False
    self._exploratory_phase = False
    self._phase_time_elapsed = dt.timedelta()
    seed = int(time.time() * 1e6) if seed is None else seed
    self._rng = jax.random.PRNGKey(seed)

  def begin_episode(self, observation: np.ndarray, action: int) -> int:
    """Initialize episode, which always starts in RL phase.

    Args:
      observation: First observation of the episode; forwarded to the
        exploratory agent so it can start its own episode.
      action: Action selected by the parent agent.

    Returns:
      `action`, unchanged.
    """
    self._exploratory_agent.begin_episode(observation)
    self._phase_time_elapsed = dt.timedelta()
    rng, self._rng = jax.random.split(self._rng)
    # jax.random.uniform samples from [0, 1), so a strict comparison is needed:
    # with `<=` a probability of exactly 0.0 could still yield an exploratory
    # episode whenever the sample happens to be 0.0.
    self._exploratory_episode = bool(
        jax.random.uniform(rng) < self._exploratory_episode_probability)
    # We always start in the RL phase.
    self._exploratory_phase = False
    return action

  def _phase_expired(self) -> bool:
    """Returns True once the current phase has lasted its full length."""
    phase_length = (
        _EXPLORATORY_PHASE_LENGTH if self._exploratory_phase
        else _RL_PHASE_LENGTH)
    return self._phase_time_elapsed >= phase_length

  def _update_phase(self) -> None:
    """Advances the phase clock and flips the phase when it has expired."""
    self._phase_time_elapsed += constants.AGENT_TIME_STEP
    if not self._exploratory_episode:
      # Non-exploratory episodes never leave the RL phase.
      return

    if self._phase_expired():
      self._exploratory_phase = not self._exploratory_phase
      self._phase_time_elapsed = dt.timedelta()

  def step(self, reward: float, observation: np.ndarray, action: int) -> int:
    """Return `action` if in RL phase, otherwise query _exploratory_agent.

    Args:
      reward: Latest reward; forwarded to the exploratory agent when it is
        queried.
      observation: Latest observation; forwarded to the exploratory agent when
        it is queried.
      action: Action selected by the parent agent.

    Returns:
      The action to execute.
    """
    self._update_phase()
    if self._exploratory_phase:
      return self._exploratory_agent.step(reward, observation)

    return action
94 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/networks.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A common set of networks available for agents."""
17 |
18 | from absl import logging
19 | from dopamine.discrete_domains import atari_lib
20 | from flax import linen as nn
21 | import gin
22 | import jax
23 | import jax.numpy as jnp
24 |
25 |
@gin.configurable
class MLPNetwork(nn.Module):
  """Fully-connected network emitting one value per action."""
  # Size of the output layer.
  num_actions: int
  # Total number of Dense layers, counting the output layer.
  num_layers: int = gin.REQUIRED
  # Width of every hidden layer.
  hidden_units: int = gin.REQUIRED
  # When True, the output is wrapped in dopamine's DQNNetworkType.
  is_dopamine: bool = False

  @nn.compact
  def __call__(self, x: jnp.ndarray):
    """Flattens `x` and runs it through the MLP.

    Args:
      x: Input array; it is cast to float32 and flattened to one dimension.

    Returns:
      The output of the final Dense layer (`num_actions` values), wrapped in
      `atari_lib.DQNNetworkType` when `is_dopamine` is set.
    """
    logging.info('Creating MLP network with %d layers and %d hidden units',
                 self.num_layers, self.hidden_units)

    glorot_init = jax.nn.initializers.glorot_uniform()
    hidden = x.astype(jnp.float32).reshape(-1)

    # The output layer below counts towards `num_layers`, hence one fewer
    # hidden layer here.
    num_hidden_layers = self.num_layers - 1
    for _ in range(num_hidden_layers):
      dense = nn.Dense(features=self.hidden_units, kernel_init=glorot_init)
      hidden = nn.relu(dense(hidden))

    # Output layer: one value per action.
    output_layer = nn.Dense(features=self.num_actions, kernel_init=glorot_init)
    q_values = output_layer(hidden)

    if not self.is_dopamine:
      return q_values
    return atari_lib.DQNNetworkType(q_values)
60 |
61 |
@gin.configurable
class QuantileNetwork(nn.Module):
  """Network used to compute the agent's return quantiles."""
  # Number of actions; the output holds `num_atoms` values for each of them.
  num_actions: int
  # Total number of Dense layers, counting the output layer.
  num_layers: int = gin.REQUIRED
  # Width of every hidden layer.
  hidden_units: int = gin.REQUIRED
  num_atoms: int = 51  # Normally set by JaxQuantileAgent.
  # Not read anywhere inside this class.
  inputs_preprocessed: bool = False

  @nn.compact
  def __call__(self, x: jnp.ndarray):
    """Flattens `x` and computes per-action atoms.

    Args:
      x: Input array; it is cast to float32 and flattened to one dimension.

    Returns:
      An `atari_lib.RainbowNetworkType` holding, in order: the mean of the
      logits over the atom axis, the `(num_actions, num_atoms)` logits, and
      the softmax of those logits.
    """
    logging.info('Creating MLP network with %d layers, %d hidden units, and '
                 '%d atoms', self.num_layers, self.hidden_units, self.num_atoms)

    uniform_fan_in_init = nn.initializers.variance_scaling(
        scale=1.0 / jnp.sqrt(3.0),
        mode='fan_in',
        distribution='uniform')
    hidden = x.astype(jnp.float32).reshape(-1)

    # The output layer below counts towards `num_layers`, hence one fewer
    # hidden layer here.
    num_hidden_layers = self.num_layers - 1
    for _ in range(num_hidden_layers):
      dense = nn.Dense(features=self.hidden_units,
                       kernel_init=uniform_fan_in_init)
      hidden = nn.relu(dense(hidden))

    output_layer = nn.Dense(features=self.num_actions * self.num_atoms,
                            kernel_init=uniform_fan_in_init)
    logits = output_layer(hidden).reshape((self.num_actions, self.num_atoms))
    probabilities = nn.softmax(logits)
    q_values = jnp.mean(logits, axis=1)
    return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
99 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/networks_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.agents.networks."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.agents import networks
20 | import gin
21 | import jax
22 | import jax.numpy as jnp
23 |
24 |
class NetworksTest(absltest.TestCase):
  """Tests construction and the forward pass of networks.MLPNetwork."""

  def setUp(self):
    super().setUp()
    self._num_actions = 4
    self._observation_shape = (6, 7)
    self._example_state = jnp.zeros(self._observation_shape)
    # MLPNetwork declares these as gin.REQUIRED, so bind defaults up front.
    gin.bind_parameter('MLPNetwork.num_layers', 1)
    gin.bind_parameter('MLPNetwork.hidden_units', 256)

  def _create_network(self):
    self._network_def = networks.MLPNetwork(num_actions=self._num_actions)

  def test_default_network_parameters(self):
    self._create_network()
    self.assertEqual(self._num_actions, self._network_def.num_actions)
    self.assertEqual(1, self._network_def.num_layers)
    self.assertEqual(256, self._network_def.hidden_units)
    network_params = self._network_def.init(jax.random.PRNGKey(0),
                                            self._example_state)
    self.assertIn('Dense_0', network_params['params'])

  def test_custom_network(self):
    num_layers = 5
    hidden_units = 64
    gin.bind_parameter('MLPNetwork.num_layers', num_layers)
    gin.bind_parameter('MLPNetwork.hidden_units', hidden_units)
    self._create_network()
    self.assertEqual(num_layers, self._network_def.num_layers)
    self.assertEqual(hidden_units, self._network_def.hidden_units)
    network_params = self._network_def.init(jax.random.PRNGKey(0),
                                            self._example_state)
    # The network builds `num_layers - 1` hidden layers plus one output layer,
    # i.e. exactly `num_layers` Dense modules. Check all of them (including
    # the output layer) and that no extra one was created.
    for i in range(num_layers):
      self.assertIn(f'Dense_{i}', network_params['params'])
    self.assertNotIn(f'Dense_{num_layers}', network_params['params'])

  def test_call_network(self):
    self._create_network()
    network_params = self._network_def.init(jax.random.PRNGKey(0),
                                            self._example_state)
    # All zeros in should produce all zeros out, since the default initializer
    # for bias is all zeros.
    zeros_out = self._network_def.apply(network_params,
                                        jnp.zeros_like(self._example_state))
    self.assertTrue(jnp.array_equal(zeros_out, jnp.zeros(self._num_actions)))
    # All ones in should produce something that is non-zero at the output, since
    # we are using Glorot initialization.
    ones_out = self._network_def.apply(network_params,
                                       jnp.ones_like(self._example_state))
    self.assertFalse(jnp.array_equal(ones_out, jnp.zeros(self._num_actions)))


if __name__ == '__main__':
  absltest.main()
79 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/perciatelli44.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A frozen Perciatelli44 agent."""
17 |
18 | from typing import Sequence
19 |
20 | from balloon_learning_environment.agents import agent
21 | from balloon_learning_environment.models import models
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 |
def load_perciatelli_session() -> tf.compat.v1.Session:
  """Creates a TF1 session and imports the serialized Perciatelli44 graph.

  The serialized graph is obtained from `models.load_perciatelli44()`, parsed
  into a `GraphDef`, and imported with `tf.compat.v1.import_graph_def`.

  Returns:
    The session that was created before the graph was imported.
  """
  graph_def = tf.compat.v1.GraphDef()
  graph_def.ParseFromString(models.load_perciatelli44())

  session = tf.compat.v1.Session()
  tf.compat.v1.import_graph_def(graph_def)
  return session
35 |
36 |
class Perciatelli44(agent.Agent):
  """Perciatelli44 Agent.

  This is the agent which was reported as state of the art in
  "Autonomous navigation of stratospheric balloons using reinforcement
  learning" (Bellemare, Candido, Castro, Gong, Machado, Moitra, Ponda,
  and Wang, 2020).

  This agent has its weights frozen, and is intended for comparison in
  evaluation, not for retraining.
  """

  def __init__(self, num_actions: int, observation_shape: Sequence[int]):
    """Validates the problem dimensions and loads the frozen graph.

    Args:
      num_actions: Number of actions; must be 3.
      observation_shape: Shape of the observations; must be `[1099]`.

    Raises:
      ValueError: If `num_actions` or `observation_shape` is unsupported.
    """
    super().__init__(num_actions, observation_shape)

    if num_actions != 3:
      raise ValueError('Perciatelli44 only supports 3 actions.')
    if list(observation_shape) != [1099]:
      raise ValueError('Perciatelli44 only supports 1099 dimensional input.')

    # TODO(joshgreaves): It would be nice to use the saved_model API
    # for loading the Perciatelli graph.
    # TODO(joshgreaves): We wanted to avoid a dependency on TF, but adding
    # this to the agent registry makes TF a necessity.
    self._sess = load_perciatelli_session()
    graph = self._sess.graph
    self._action = graph.get_tensor_by_name('sleepwalk_action:0')
    self._q_vals = graph.get_tensor_by_name('q_values:0')
    self._observation = graph.get_tensor_by_name('observation:0')

  def _greedy_action(self, observation: np.ndarray) -> int:
    """Runs the graph on `observation` and returns the argmax Q-value index."""
    batched_observation = observation.reshape((1, 1099))
    q_vals = self._sess.run(
        self._q_vals, feed_dict={self._observation: batched_observation})
    return np.argmax(q_vals).item()

  def begin_episode(self, observation: np.ndarray) -> int:
    """Returns the greedy action for the first observation of an episode."""
    return self._greedy_action(observation)

  def step(self, reward: float, observation: np.ndarray) -> int:
    """Returns the greedy action for `observation`; `reward` is not used."""
    return self._greedy_action(observation)

  def end_episode(self, reward: float, terminal: bool = True) -> None:
    """Does nothing."""
80 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/perciatelli44_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for perciatelli44."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.agents import perciatelli44
20 | import numpy as np
21 |
22 |
class Perciatelli44Test(absltest.TestCase):
  """Tests for the frozen Perciatelli44 agent."""

  def setUp(self):
    super().setUp()
    self._perciatelli44 = perciatelli44.Perciatelli44(3, [1099])
    self._observation = np.ones(1099, dtype=np.float32)

  def test_begin_episode_returns_valid_action(self):
    action = self._perciatelli44.begin_episode(self._observation)
    self.assertIn(action, [0, 1, 2])

  def test_step_returns_valid_action(self):
    # Start an episode first, then exercise `step` itself. Previously this
    # test called `begin_episode` again, so `step` was never covered.
    self._perciatelli44.begin_episode(self._observation)
    action = self._perciatelli44.step(0.0, self._observation)
    self.assertIn(action, [0, 1, 2])

  def test_perciatelli_only_accepts_3_actions(self):
    with self.assertRaises(ValueError):
      perciatelli44.Perciatelli44(4, [1099])

  def test_perciatelli_only_accepts_1099_dim_observation(self):
    with self.assertRaises(ValueError):
      perciatelli44.Perciatelli44(3, [1100])


if __name__ == '__main__':
  absltest.main()
49 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/random_walk_agent.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """An exploratory agent that selects actions based on a random walk.
17 |
18 | Note that this class assumes the features passed in correspond to the
19 | Perciatelli features (see balloon_learning_environment.env.features).
20 | """
21 |
22 |
23 | import datetime as dt
24 | import time
25 | from typing import Optional, Sequence
26 |
27 | from balloon_learning_environment.agents import agent
28 | from balloon_learning_environment.env import features
29 | from balloon_learning_environment.env.balloon import control
30 | from balloon_learning_environment.utils import constants
31 | from balloon_learning_environment.utils import sampling
32 | import gin
33 | import jax
34 | import numpy as np
35 |
36 |
_PERCIATELLI_FEATURES_SHAPE = (1099,)  # Expected shape of Perciatelli features.
_HYSTERESIS = 100  # In Pascals.
_STDDEV = 0.1666  # ~ 10 [Pa/min].


# Although this class does not have any gin-configurable parameters, it is
# decorated as gin-configurable so it can be injected into other classes
# (e.g. MarcoPoloExploration).
@gin.configurable
class RandomWalkAgent(agent.Agent):
  """An exploratory agent that selects actions based on a random walk.

  The agent keeps a target pressure that is perturbed with Gaussian noise on
  every step, and steers the balloon towards it.
  """

  def __init__(self, num_actions: int, observation_shape: Sequence[int],
               seed: Optional[int] = None):
    """Seeds the PRNG and samples an initial target pressure.

    Args:
      num_actions: Unused.
      observation_shape: Unused.
      seed: Optional seed for the PRNG. When None, a seed is derived from the
        current time.
    """
    del num_actions
    del observation_shape
    if seed is None:
      seed = int(time.time() * 1e6)
    self._rng = jax.random.PRNGKey(seed)
    self._time_elapsed = dt.timedelta()
    self._sample_new_target_pressure()

  def _sample_new_target_pressure(self):
    """Draws a fresh target pressure using a newly split PRNG key."""
    self._rng, sample_key = jax.random.split(self._rng)
    self._target_pressure = sampling.sample_pressure(sample_key)

  def _select_action(self, features_as_vector: np.ndarray) -> int:
    """Picks the altitude command that moves the balloon towards the target.

    Args:
      features_as_vector: Feature vector of shape
        `_PERCIATELLI_FEATURES_SHAPE`.

    Returns:
      UP, DOWN or STAY, depending on where the balloon pressure sits relative
      to the target pressure and the `_HYSTERESIS` band.
    """
    assert features_as_vector.shape == _PERCIATELLI_FEATURES_SHAPE
    named_features = features.NamedPerciatelliFeatures(features_as_vector)
    balloon_pressure = named_features.balloon_pressure
    # Note: higher pressures means lower altitude.
    if balloon_pressure - _HYSTERESIS > self._target_pressure:
      command = control.AltitudeControlCommand.UP
    elif balloon_pressure + _HYSTERESIS < self._target_pressure:
      command = control.AltitudeControlCommand.DOWN
    else:
      command = control.AltitudeControlCommand.STAY
    return command

  def begin_episode(self, observation: np.ndarray) -> int:
    """Resets the clock, resamples the target and returns the first action."""
    self._time_elapsed = dt.timedelta()
    self._sample_new_target_pressure()
    return self._select_action(observation)

  def step(self, reward: float, observation: np.ndarray) -> int:
    """Perturbs the target pressure and returns the next action.

    Args:
      reward: Unused.
      observation: Feature vector handed to `_select_action`.

    Returns:
      The selected altitude command.
    """
    del reward
    self._time_elapsed += constants.AGENT_TIME_STEP
    # Random walk over the target pressure: zero-mean Gaussian noise scaled by
    # the elapsed time (in seconds) and `_STDDEV`.
    # NOTE(review): `_time_elapsed` is only reset in `begin_episode`, so the
    # scale grows with time since the episode started rather than since the
    # previous update - confirm this is intended.
    self._rng, noise_key = jax.random.split(self._rng)
    noise_scale = self._time_elapsed.total_seconds() * _STDDEV
    self._target_pressure += noise_scale * jax.random.normal(noise_key)
    return self._select_action(observation)

  def end_episode(self, reward: float, terminal: bool = True) -> None:
    """Does nothing."""
95 |
--------------------------------------------------------------------------------
/balloon_learning_environment/agents/station_seeker_agent_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for station_seeker_agent."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.agents import station_seeker_agent
20 | import numpy as np
21 |
22 |
class StationSeekerAgentTest(absltest.TestCase):
  """Smoke tests for StationSeekerAgent."""

  def setUp(self):
    super().setUp()
    self._na = 3
    self._observation_shape = (3, 4)
    self._ss_agent = station_seeker_agent.StationSeekerAgent(
        self._na, self._observation_shape)

  def test_create_agent(self):
    agent_name = self._ss_agent.get_name()
    self.assertEqual('StationSeekerAgent', agent_name)

  def test_action_selection(self):
    mock_observation = np.zeros(1099)
    # The observation is uniform everywhere, so the controller is expected to
    # stay put (action = 1) at every decision point.
    stay_action = 1
    num_episodes = 10
    steps_per_episode = 20

    for _ in range(num_episodes):
      first_action = self._ss_agent.begin_episode(mock_observation)
      self.assertEqual(first_action, stay_action)
      for _ in range(steps_per_episode):
        next_action = self._ss_agent.step(0.0, mock_observation)
        self.assertEqual(next_action, stay_action)

  def test_end_episode(self):
    # end_episode exists only to satisfy the Agent interface; this just checks
    # that calling it does not raise.
    self._ss_agent.end_episode(0.0, True)

  # TODO(bellemare): Test wind score: decreases as bearing increases.


if __name__ == '__main__':
  absltest.main()
56 |
--------------------------------------------------------------------------------
/balloon_learning_environment/colab/summarize_eval.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Summarize Eval",
7 | "provenance": [],
8 | "collapsed_sections": []
9 | },
10 | "kernelspec": {
11 | "name": "python3",
12 | "display_name": "Python 3"
13 | },
14 | "language_info": {
15 | "name": "python"
16 | }
17 | },
18 | "cells": [
19 | {
20 | "cell_type": "markdown",
21 | "source": [
22 | "Copyright 2021 The Balloon Learning Environment Authors.\n",
23 | "\n",
24 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
25 | "you may not use this file except in compliance with the License.\n",
26 | "You may obtain a copy of the License at\n",
27 | "\n",
28 | " http://www.apache.org/licenses/LICENSE-2.0\n",
29 | "\n",
30 | "Unless required by applicable law or agreed to in writing, software\n",
31 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
32 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
33 | "See the License for the specific language governing permissions and\n",
34 | "limitations under the License."
35 | ],
36 | "metadata": {
37 | "id": "u8hGsWm_qiGT"
38 | }
39 | },
40 | {
41 | "cell_type": "code",
42 | "metadata": {
43 | "cellView": "form",
44 | "id": "VoqfM73Tmji_"
45 | },
46 | "source": [
47 | "# @title Imports\n",
48 | "import collections\n",
49 | "import json\n",
50 | "\n",
51 | "from google.colab import files\n",
52 | "import matplotlib.pyplot as plt\n",
53 | "import pandas as pd\n",
54 | "\n",
55 | "from IPython.display import HTML"
56 | ],
57 | "execution_count": null,
58 | "outputs": []
59 | },
60 | {
61 | "cell_type": "code",
62 | "metadata": {
63 | "cellView": "form",
64 | "id": "GgEqoO-9mpkt"
65 | },
66 | "source": [
67 | "# @title Upload Data\n",
68 | "uploaded_files = files.upload()\n",
69 | "dataframes = dict()\n",
70 | "\n",
71 | "for name, data in uploaded_files.items():\n",
72 | " name = name.rsplit('.', maxsplit=1)[0]\n",
73 | " json_data = json.loads(data)\n",
74 | " dataframes[name] = pd.DataFrame.from_dict(json_data)"
75 | ],
76 | "execution_count": null,
77 | "outputs": []
78 | },
79 | {
80 | "cell_type": "code",
81 | "metadata": {
82 | "cellView": "form",
83 | "id": "JGMAmrANm0an"
84 | },
85 | "source": [
86 | "# @title Print Aggregate results\n",
87 | "aggregate_data = collections.defaultdict(list)\n",
88 | "for name, df in dataframes.items():\n",
89 | " aggregate_data['num episodes'].append(len(df.out_of_power))\n",
90 | " aggregate_data['out of power'].append(df.out_of_power.sum())\n",
91 | " aggregate_data['zeropressure'].append(df.zeropressure.sum())\n",
92 | " aggregate_data['envelope burst'].append(df.envelope_burst.sum())\n",
93 | "\n",
94 | " finished_runs = df.loc[df.out_of_power == False]\n",
95 | " finished_runs = finished_runs.loc[finished_runs.zeropressure == False]\n",
96 | " finished_runs = finished_runs.loc[finished_runs.envelope_burst == False]\n",
97 | " aggregate_data['mean cumulative reward (finished episodes)'].append(\n",
98 | " finished_runs.cumulative_reward.mean())\n",
99 | " aggregate_data['mean TWR50 (finished episodes)'].append(\n",
100 | " finished_runs.time_within_radius.mean())\n",
101 | " aggregate_data['mean cumulative reward (all episodes)'].append(\n",
102 | " df.cumulative_reward.mean())\n",
103 | " aggregate_data['mean TWR50 (all episodes)'].append(\n",
104 | " df.time_within_radius.mean())\n",
105 | "\n",
106 | "df = pd.DataFrame(aggregate_data)\n",
107 | "df.index = dataframes.keys()\n",
108 | "\n",
109 | "# This is a little hacky, but it works 🤷♂️ \n",
110 | "html = df.to_html()\n",
111 | "html += \"\"\n",
112 | "HTML(html)"
113 | ],
114 | "execution_count": null,
115 | "outputs": []
116 | }
117 | ]
118 | }
--------------------------------------------------------------------------------
/balloon_learning_environment/env/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/acs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """ACS power as a function of superpressure."""
17 |
18 | from balloon_learning_environment.utils import constants
19 | from balloon_learning_environment.utils import units
20 | import numpy as np
21 | from scipy import interpolate
22 |
23 |
# Lookup from pressure ratio to ACS power, built from five knots spanning
# pressure ratios 1.0 to 1.35. fill_value='extrapolate' makes queries outside
# that range extrapolate from the table instead of raising.
_PRESSURE_RATIO_TO_POWER: interpolate.interp1d = (
    interpolate.interp1d(
        np.array([1.0, 1.05, 1.2, 1.25, 1.35]),  # pressure_ratio
        np.array([100.0, 100.0, 300.0, 400.0, 400.0]),  # power
        fill_value='extrapolate'))
29 |
30 |
# Fan efficiency sampled on a regular grid of 13 pressure ratios (1.05 to
# 1.35) by 4 power levels (100 to 400), passed as a flat array of
# 13 * 4 = 52 values.
# NOTE(review): scipy.interpolate.interp2d was deprecated in SciPy 1.10 and
# removed in SciPy 1.14, so this table needs migrating (for example to
# RegularGridInterpolator or RectBivariateSpline) before SciPy is upgraded.
_PRESSURE_RATIO_POWER_TO_EFFICIENCY: interpolate.interp2d = (
    interpolate.interp2d(
        np.linspace(1.05, 1.35, 13),  # pressure_ratio
        np.linspace(100.0, 400.0, 4),  # power
        np.array([0.4, 0.4, 0.3, 0.2, 0.2, 0.00000, 0.00000, 0.00000, 0.00000,
                  0.00000, 0.00000, 0.00000, 0.00000, 0.4, 0.3, 0.3, 0.30, 0.25,
                  0.23, 0.20, 0.15, 0.12, 0.10, 0.00000, 0.00000, 0.00000,
                  0.00000, 0.3, 0.25, 0.25, 0.25, 0.20, 0.20, 0.20, 0.2, 0.15,
                  0.13, 0.12, 0.11, 0.00000, 0.23, 0.23, 0.23, 0.23, 0.23, 0.20,
                  0.20, 0.20, 0.18, 0.16, 0.15, 0.13]),  # efficiency
        # Per the SciPy docs, fill_value=None extrapolates points outside the
        # grid rather than filling them with a constant.
        fill_value=None))
42 |
43 |
def get_most_efficient_power(pressure_ratio: float) -> units.Power:
  """Returns the ACS operating power that is optimal for a pressure ratio.

  The power [W] is read from a static table, where "optimal" is measured in
  kg of air moved per unit of energy.

  Args:
    pressure_ratio: Ratio of (balloon pressure + superpressure) to balloon
      pressure.

  Returns:
    The optimal ACS power at the given pressure_ratio.
  """
  return units.Power(watts=_PRESSURE_RATIO_TO_POWER(pressure_ratio))
59 |
60 |
def get_fan_efficiency(pressure_ratio: float, power: units.Power) -> float:
  """Looks up the efficiency of air flow for a pressure ratio and power.

  Args:
    pressure_ratio: Ratio of (balloon pressure + superpressure) to balloon
      pressure.
    power: The power the ACS is operating at.

  Returns:
    The efficiency read from the static efficiency table, as a Python float.
  """
  efficiency = _PRESSURE_RATIO_POWER_TO_EFFICIENCY(pressure_ratio, power.watts)
  # The interpolator hands back a numpy array rather than a scalar. Calling
  # float() on an array with ndim > 0 is deprecated since NumPy 1.25, so pull
  # out the single element explicitly instead.
  return float(np.asarray(efficiency).item())
65 |
66 |
def get_mass_flow(power: units.Power, efficiency: float) -> float:
  """Computes the ACS mass flow from its power and fan efficiency.

  Args:
    power: The power the ACS is operating at.
    efficiency: The fan efficiency, e.g. as given by get_fan_efficiency.

  Returns:
    efficiency * power [W] divided by the number of seconds per hour.
    NOTE(review): presumably efficiency is in kg of air per watt-hour, which
    would make this kg/s -- confirm.
  """
  flow_per_hour = efficiency * power.watts
  return flow_per_hour / constants.NUM_SECONDS_PER_HOUR
69 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/acs_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
"""Tests for balloon_learning_environment.env.balloon.acs."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from balloon_learning_environment.env.balloon import acs
21 | from balloon_learning_environment.utils import units
22 |
23 |
class AcsTest(parameterized.TestCase):
  """Checks the ACS power, fan efficiency and mass flow computations."""

  def _assertion_for(self, comparator):
    """Maps a comparator name onto the matching assertion method."""
    named_assertions = {
        'eq': self.assertEqual,
        'lt': self.assertLessEqual,
    }
    # Any other name (the cases below use 'gt') means "greater or equal".
    return named_assertions.get(comparator, self.assertGreaterEqual)

  @parameterized.named_parameters(
      dict(testcase_name='at_min', pressure_ratio=1.0, power=100.0,
           comparator='eq'),
      dict(testcase_name='at_mid', pressure_ratio=1.2, power=300.0,
           comparator='eq'),
      dict(testcase_name='at_max', pressure_ratio=1.35, power=400.0,
           comparator='eq'),
      dict(testcase_name='below_min', pressure_ratio=0.01, power=100.0,
           comparator='lt'),
      dict(testcase_name='above_max', pressure_ratio=2.0, power=400.0,
           comparator='gt'))
  def test_get_most_efficient_power(self, pressure_ratio, power, comparator):
    check = self._assertion_for(comparator)
    actual_watts = acs.get_most_efficient_power(pressure_ratio).watts
    check(actual_watts, power)

  @parameterized.named_parameters(
      dict(testcase_name='at_min', pressure_ratio=1.05, power=100.0,
           efficiency=0.4, comparator='eq'),
      dict(testcase_name='at_max', pressure_ratio=1.35, power=400.0,
           efficiency=0.13, comparator='eq'),
      dict(testcase_name='below_min', pressure_ratio=0.01, power=10.0,
           efficiency=0.4, comparator='gt'),
      dict(testcase_name='above_max', pressure_ratio=2.0, power=500.0,
           efficiency=0.13, comparator='lt'))
  def test_get_fan_efficiency(self, pressure_ratio, power, efficiency,
                              comparator):
    check = self._assertion_for(comparator)
    actual_efficiency = acs.get_fan_efficiency(
        pressure_ratio, units.Power(watts=power))
    check(actual_efficiency, efficiency)

  def test_get_mass_flow(self):
    mass_flow = acs.get_mass_flow(units.Power(watts=3.6), 10.0)
    self.assertEqual(mass_flow, 0.01)
69 |
70 |
# Run this module's tests when it is executed as a script.
if __name__ == '__main__':
  absltest.main()
73 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/altitude_safety.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Prevents the balloon navigating to unsafe altitudes.
17 |
18 | This safety layer prevents the balloon from navigating to below 50,000 feet
19 | of altitude. Internally, it maintains a state machine to remember whether
20 | the balloon is close to the altitude limit. The balloon only returns to the
21 | nominal state once the balloon has navigated sufficiently far from the
22 | altitude limit. If the balloon moves below the altitude limit, the safety
23 | layer will issue the ascend command.
24 | """
25 |
26 | import enum
27 | import logging
28 |
29 | from balloon_learning_environment.env.balloon import control
30 | from balloon_learning_environment.env.balloon import standard_atmosphere
31 | from balloon_learning_environment.utils import units
32 | import transitions
33 |
# TODO(joshgreaves): This may require some tuning.
# Width of the band directly above MIN_ALTITUDE in which the safety layer
# moves to the LOW state.
BUFFER = units.Distance(feet=500.0)
# Extra margin above MIN_ALTITUDE + BUFFER that has to be cleared before a
# layer in the LOW or VERY_LOW state goes back to NOMINAL.
RESTART_HYSTERESIS = units.Distance(feet=500.0)
# Altitude below which the safety layer moves to the VERY_LOW state.
MIN_ALTITUDE = units.Distance(feet=50_000.0)
38 |
39 |
class _AltitudeState(enum.Enum):
  """States of the altitude safety state machine."""
  # The requested action is passed through unchanged.
  NOMINAL = 0
  # Close to the altitude limit: DOWN commands are replaced with STAY.
  LOW = 1
  # Below the altitude limit: every command is replaced with UP.
  VERY_LOW = 2
44 |
45 |
# Note: Transitions are applied in the order of the first match.
# '*' is a catch-all, and applies to any state.
_ALTITUDE_SAFETY_TRANSITIONS = (
    {'trigger': 'very_low', 'source': '*', 'dest': _AltitudeState.VERY_LOW},
    {'trigger': 'low', 'source': '*', 'dest': _AltitudeState.LOW},
    # 'low_nominal' provides the hysteresis: a machine that was already
    # LOW or VERY_LOW ends up LOW, while a NOMINAL machine stays NOMINAL.
    {
        'trigger': 'low_nominal',
        'source': (_AltitudeState.VERY_LOW, _AltitudeState.LOW),
        'dest': _AltitudeState.LOW,
    },
    {
        'trigger': 'low_nominal',
        'source': _AltitudeState.NOMINAL,
        'dest': _AltitudeState.NOMINAL,
    },
    {'trigger': 'nominal', 'source': '*', 'dest': _AltitudeState.NOMINAL},
)
61 |
62 |
class AltitudeSafetyLayer:
  """Overrides altitude commands that would take a balloon too low."""

  def __init__(self):
    machine = transitions.Machine(
        states=_AltitudeState,
        transitions=_ALTITUDE_SAFETY_TRANSITIONS,
        initial=_AltitudeState.NOMINAL)
    self._state_machine = machine
    # Only let the transitions library log at WARNING level and above.
    transitions_logger = logging.getLogger('transitions')
    transitions_logger.setLevel(logging.WARNING)

  def get_action(self, action: control.AltitudeControlCommand,
                 atmosphere: standard_atmosphere.Atmosphere,
                 pressure: float) -> control.AltitudeControlCommand:
    """Returns the command the balloon should follow after safety checks.

    The internal state machine is updated from the balloon's current
    altitude before the command is chosen.

    Args:
      action: The action the controller has supplied to the balloon.
      atmosphere: The atmospheric conditions the balloon is flying in.
      pressure: The current pressure of the balloon.

    Returns:
      UP when the balloon is below the altitude limit, STAY in place of DOWN
      when it is close to the limit, and the supplied action otherwise.
    """
    self._transition_state(atmosphere.at_pressure(pressure).height)
    state = self._state_machine.state

    if state == _AltitudeState.VERY_LOW:
      # Too low: force an ascent regardless of the requested action.
      return control.AltitudeControlCommand.UP
    if (state == _AltitudeState.LOW and
        action == control.AltitudeControlCommand.DOWN):
      # Almost too low: refuse to descend any further.
      return control.AltitudeControlCommand.STAY
    return action

  @property
  def navigation_is_paused(self):
    """Whether the state machine is in any state other than NOMINAL."""
    return self._state_machine.state != _AltitudeState.NOMINAL

  def _transition_state(self, altitude: units.Distance):
    """Fires the state machine trigger matching the given altitude."""
    machine = self._state_machine
    low_ceiling = MIN_ALTITUDE + BUFFER

    if altitude < MIN_ALTITUDE:
      trigger = machine.very_low
    elif altitude < low_ceiling:
      trigger = machine.low
    elif altitude < low_ceiling + RESTART_HYSTERESIS:
      trigger = machine.low_nominal
    else:
      trigger = machine.nominal
    trigger()
112 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/control.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Control related functionality."""
17 |
18 | import enum
19 |
20 |
class AltitudeControlCommand(enum.IntEnum):
  """Specifies the command/action for balloon navigation."""
  # As an IntEnum, each member also compares equal to its integer value.
  DOWN = 0
  STAY = 1
  UP = 2
26 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/power_table.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A lookup table from pressure ratio, state of charge -> power to use."""
17 |
18 | import bisect
19 |
20 |
# Upper bounds of the pressure ratio intervals. The last row of _SOC_MAPPINGS
# covers every pressure ratio from the final bound upwards.
_PRESSURE_RATIO_INTERVALS = (1.08, 1.11, 1.14, 1.17, 1.2, 1.23, 1.26)

# One (state of charge thresholds, powers) entry per pressure ratio interval.
# Each powers tuple has one more element than its thresholds tuple: the first
# power applies below the lowest threshold.
_SOC_MAPPINGS = (
    ((0.3, 0.4, 0.5), (0, 150, 175, 200)),  # 0.99 -> 1.08
    ((0.3, 0.4, 0.7), (0, 200, 200, 225)),  # 1.08 -> 1.11
    ((0.3, 0.4, 0.6), (0, 225, 225, 250)),  # 1.11 -> 1.14
    ((0.3, 0.4, 0.5), (0, 200, 225, 250)),  # 1.14 -> 1.17
    ((0.3, 0.4, 0.5), (0, 225, 250, 275)),  # 1.17 -> 1.2
    ((0.4, 0.5), (0, 275, 300)),  # 1.2 -> 1.23
    ((0.5, 0.6), (0, 300, 325)),  # 1.23 -> 1.26
    ((0.5, 0.6), (0, 325, 350)),  # 1.26 -> 5.0
)


def lookup(pressure_ratio: float,
           state_of_charge: float) -> float:
  """Map pressure_ratio x state_of_charge to power to use.

  Args:
    pressure_ratio: The pressure ratio; must lie within [0.99, 5].
    state_of_charge: The state of charge used to pick the power level within
      the row selected by pressure_ratio.

  Returns:
    The power to use, read from the static table.

  Raises:
    AssertionError: If pressure_ratio is outside [0.99, 5].
  """
  # Raise explicitly instead of using an `assert` statement so the range
  # check is not stripped when Python runs with -O. The exception type is
  # kept as AssertionError because existing callers and tests expect it.
  if not 0.99 <= pressure_ratio <= 5:
    raise AssertionError(
        f'pressure_ratio must be within [0.99, 5], got {pressure_ratio}.')

  # bisect.bisect is bisect_right, so a value equal to a bound falls into the
  # interval that starts at that bound.
  thresholds, powers = _SOC_MAPPINGS[
      bisect.bisect(_PRESSURE_RATIO_INTERVALS, pressure_ratio)]
  return powers[bisect.bisect(thresholds, state_of_charge)]
39 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/power_table_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for power_table."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from balloon_learning_environment.env.balloon import power_table
21 |
22 |
class PowerTableTest(parameterized.TestCase):
  """Checks the power lookup table and its pressure ratio validation."""

  @parameterized.named_parameters(
      dict(testcase_name='low_pressure_ratio', pressure_ratio=0.98),
      dict(testcase_name='high_pressure_ratio', pressure_ratio=5.01))
  def test_invalid_pressure_ratios(self, pressure_ratio: float):
    # Pressure ratios outside the supported range must be rejected.
    self.assertRaises(
        AssertionError, power_table.lookup, pressure_ratio, 1.0)

  # Each case is (pressure_ratio, state_of_charge, expected_power_to_use).
  @parameterized.parameters(
      (1.0, 0.2, 0),
      (1.0, 0.3, 150),
      (1.0, 0.4, 175),
      (1.0, 0.5, 200),
      (1.0, 0.6, 200),
      (1.08, 0.2, 0),
      (1.08, 0.3, 200),
      (1.08, 0.4, 200),
      (1.08, 0.7, 225),
      (1.08, 0.8, 225),
      (1.11, 0.2, 0),
      (1.11, 0.3, 225),
      (1.11, 0.4, 225),
      (1.11, 0.6, 250),
      (1.11, 0.7, 250),
      (1.14, 0.2, 0),
      (1.14, 0.3, 200),
      (1.14, 0.4, 225),
      (1.14, 0.5, 250),
      (1.14, 0.6, 250),
      (1.17, 0.2, 0),
      (1.17, 0.3, 225),
      (1.17, 0.4, 250),
      (1.17, 0.5, 275),
      (1.17, 0.6, 275),
      (1.2, 0.3, 0),
      (1.2, 0.4, 275),
      (1.2, 0.5, 300),
      (1.2, 0.6, 300),
      (1.23, 0.4, 0),
      (1.23, 0.5, 300),
      (1.23, 0.6, 325),
      (1.23, 0.7, 325),
      (1.26, 0.4, 0),
      (1.26, 0.5, 325),
      (1.26, 0.6, 350),
      (1.26, 0.7, 350))
  def test_table_lookup(self, pressure_ratio, state_of_charge,
                        expected_power_to_use):
    power_to_use = power_table.lookup(pressure_ratio, state_of_charge)
    self.assertEqual(power_to_use, expected_power_to_use)
74 |
75 |
# Run this module's tests when it is executed as a script.
if __name__ == '__main__':
  absltest.main()
78 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/pressure_range_builder_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for pressure_range_builder."""
17 |
18 | import functools
19 |
20 | from absl.testing import absltest
21 | from balloon_learning_environment.env.balloon import altitude_safety
22 | from balloon_learning_environment.env.balloon import pressure_range_builder
23 | from balloon_learning_environment.env.balloon import standard_atmosphere
24 | from balloon_learning_environment.utils import test_helpers
25 | import jax
26 |
27 |
class AltitudeRangeBuilderTest(absltest.TestCase):
  """Checks the pressure range computed by pressure_range_builder."""

  def setUp(self):
    super().setUp()
    self.atmosphere = standard_atmosphere.Atmosphere(jax.random.PRNGKey(0))
    self.create_balloon = functools.partial(
        test_helpers.create_balloon, atmosphere=self.atmosphere)

  def _pressure_range_of_new_balloon(self):
    """Creates a balloon and returns its accessible pressure range."""
    balloon = self.create_balloon()
    return pressure_range_builder.get_pressure_range(
        balloon.state, self.atmosphere)

  def test_get_pressure_range_returns_valid_range(self):
    accessible = self._pressure_range_of_new_balloon()

    self.assertIsInstance(accessible,
                          pressure_range_builder.AccessiblePressureRange)
    for bound in (accessible.min_pressure, accessible.max_pressure):
      self.assertBetween(bound, 1000.0, 100_000.0)

  def test_get_pressure_range_returns_min_pressure_below_max_pressure(self):
    accessible = self._pressure_range_of_new_balloon()

    self.assertLess(accessible.min_pressure, accessible.max_pressure)

  def test_get_pressure_range_returns_max_pressure_above_min_altitude(self):
    accessible = self._pressure_range_of_new_balloon()

    # The range's max pressure must not exceed the pressure found at the
    # minimum safe altitude.
    pressure_at_min_altitude = self.atmosphere.at_height(
        altitude_safety.MIN_ALTITUDE).pressure
    self.assertLessEqual(accessible.max_pressure, pressure_at_min_altitude)

  # TODO(joshgreaves): Add more tests when the pressure ranges are as expected.
66 |
67 |
# Run this module's tests when it is executed as a script.
if __name__ == '__main__':
  absltest.main()
70 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon/stable_init_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
"""Tests for stable_init."""
17 |
18 | import datetime as dt
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | from balloon_learning_environment.env import wind_field
23 | from balloon_learning_environment.env.balloon import control
24 | from balloon_learning_environment.env.balloon import solar
25 | from balloon_learning_environment.env.balloon import standard_atmosphere
26 | from balloon_learning_environment.env.balloon import thermal
27 | from balloon_learning_environment.utils import test_helpers
28 | from balloon_learning_environment.utils import units
29 | import jax
30 |
31 |
class StableParamsTest(parameterized.TestCase):
  """Checks that freshly created balloons start out in a stable state."""

  def setUp(self):
    super().setUp()
    self.atmosphere = standard_atmosphere.Atmosphere(jax.random.PRNGKey(38))

  @parameterized.named_parameters(
      dict(testcase_name='middle_pressure', init_pressure=9_500.0),
      dict(testcase_name='high_pressure', init_pressure=11_500.0),
      dict(testcase_name='low_pressure', init_pressure=6_500.0))
  def test_cold_start_to_stable_params_initializes_mols_air_correctly(
      self, init_pressure: float):
    # The superpressure is very sensitive to temperature, and hence time of
    # day, so create the balloon at midnight.
    # create_balloon runs cold_start_to_stable_params by default.
    b = test_helpers.create_balloon(
        pressure=init_pressure,
        date_time=units.datetime(2020, 6, 1, 0, 0, 0),
        atmosphere=self.atmosphere)

    # Simulate the balloon for a while. If the internal parameters are
    # correctly set, it should stay at roughly the correct pressure level.
    for _ in range(100):
      b.simulate_step(
          wind_field.WindVector(
              units.Velocity(mps=3.0), units.Velocity(mps=-4.0)),
          self.atmosphere, control.AltitudeControlCommand.STAY,
          dt.timedelta(seconds=10.0))

    self.assertLess(abs(b.state.pressure - init_pressure), 100.0)

  @parameterized.named_parameters(
      dict(testcase_name='middle_pressure', init_pressure=9_500.0),
      dict(testcase_name='high_pressure', init_pressure=11_500.0),
      dict(testcase_name='low_pressure', init_pressure=5_000.0))
  def test_cold_start_to_stable_params_initializes_temperature_correctly(
      self, init_pressure: float):
    # create_balloon runs cold_start_to_stable_params by default.
    b = test_helpers.create_balloon(
        pressure=init_pressure, atmosphere=self.atmosphere)

    solar_elevation, _, solar_flux = solar.solar_calculator(
        b.state.latlng, b.state.date_time)
    d_internal_temp = thermal.d_balloon_temperature_dt(
        b.state.envelope_volume, b.state.envelope_mass,
        b.state.internal_temperature, b.state.ambient_temperature,
        b.state.pressure, solar_elevation, solar_flux,
        b.state.upwelling_infrared)

    # If the rate of change of temperature is low, the temperature is
    # initialized to a stable value. Compare the magnitude: without abs() a
    # balloon that is cooling rapidly (large negative rate) would also pass.
    self.assertLess(abs(d_internal_temp), 1e-3)
84 |
85 |
# Run this module's tests when it is executed as a script.
if __name__ == '__main__':
  absltest.main()
88 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/balloon_arena_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.env.balloon_arena."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from balloon_learning_environment.utils import constants
21 | from balloon_learning_environment.utils import test_helpers
22 | from balloon_learning_environment.utils import units
23 | import jax
24 |
25 |
class BalloonArenaTest(parameterized.TestCase):
  """Checks how the arena seeds and initializes its balloon on reset."""

  # TODO(joshgreaves): Patch PerciatelliFeatureConstructor, it's too slow.

  @staticmethod
  def _balloon_state_of(arena):
    """Returns the balloon state held in the arena's simulator state."""
    return arena.get_simulator_state().balloon_state

  def test_int_seeding_gives_determinisic_balloon_initialization(self):
    first_arena = test_helpers.create_arena()
    second_arena = test_helpers.create_arena()

    first_arena.reset(201)
    second_arena.reset(201)
    first_state = self._balloon_state_of(first_arena)
    second_state = self._balloon_state_of(second_arena)

    test_helpers.compare_balloon_states(first_state, second_state)

  def test_array_seeding_gives_determinisic_balloon_initialization(self):
    first_arena = test_helpers.create_arena()
    second_arena = test_helpers.create_arena()

    first_arena.reset(jax.random.PRNGKey(201))
    second_arena.reset(jax.random.PRNGKey(201))
    first_state = self._balloon_state_of(first_arena)
    second_state = self._balloon_state_of(second_arena)

    test_helpers.compare_balloon_states(first_state, second_state)

  def test_different_seeds_gives_different_initialization(self):
    first_arena = test_helpers.create_arena()
    second_arena = test_helpers.create_arena()

    first_arena.reset(201)
    second_arena.reset(202)
    first_state = self._balloon_state_of(first_arena)
    second_state = self._balloon_state_of(second_arena)

    # Different seeds are expected to place the balloon somewhere else.
    test_helpers.compare_balloon_states(
        first_state, second_state, check_not_equal=['x', 'y'])

  def test_random_seeding_doesnt_throw_exception(self):
    arena = test_helpers.create_arena()

    # Succeeds if resetting without a seed raises nothing.
    arena.reset()

  @parameterized.named_parameters(
      (str(seed), seed) for seed in (1, 5, 28, 90, 106, 378))
  def test_balloon_is_initialized_within_200km(self, seed: int):
    arena = test_helpers.create_arena()

    arena.reset(seed)
    state = self._balloon_state_of(arena)

    distance_from_origin = units.relative_distance(state.x, state.y)
    self.assertLessEqual(distance_from_origin.km, 200.0)

  @parameterized.named_parameters(
      (str(seed), seed) for seed in (1, 5, 28, 90, 106, 378))
  def test_balloon_is_initialized_within_valid_pressure_range(self, seed: int):
    arena = test_helpers.create_arena()

    arena.reset(seed)
    state = self._balloon_state_of(arena)

    self.assertBetween(state.pressure,
                       constants.PERCIATELLI_PRESSURE_RANGE_MIN,
                       constants.PERCIATELLI_PRESSURE_RANGE_MAX)
88 |
# Run this module's tests when it is executed as a script.
if __name__ == '__main__':
  absltest.main()
91 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/grid_wind_field_sampler.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """An interface for sampling grid wind fields."""
17 |
18 | import abc
19 | import datetime as dt
20 |
21 | from balloon_learning_environment.generative import vae
22 | from jax import numpy as jnp
23 | import numpy as np
24 |
25 |
class GridWindFieldSampler(abc.ABC):
  """Abstract interface for objects that sample grid wind fields."""

  @property
  @abc.abstractmethod
  def field_shape(self) -> vae.FieldShape:
    """Gets the field shape of wind fields sampled by this class."""

  @abc.abstractmethod
  def sample_field(self,
                   key: jnp.ndarray,
                   date_time: dt.datetime) -> np.ndarray:
    """Samples a wind field and returns it as a numpy array.

    Args:
      key: A PRNGKey to use for sampling.
      date_time: The date_time of the beginning of the wind field.

    Returns:
      The sampled wind field as a numpy array. NOTE(review): presumably its
      layout matches field_shape -- confirm against the implementations.
    """
43 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/gym.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Balloon Learning Environment gym utilities."""
17 | import contextlib
18 |
19 |
def register_env() -> None:
  """Registers the Balloon Learning Environment with Gym.

  The environment is registered under the id 'BalloonLearningEnvironment-v0'.
  Calling this function again once the id is registered does nothing.
  """
  # We need to import Gym's registration module inline or else we'll
  # get a circular dependency that will result in an error when importing gym
  from gym.envs import registration  # pylint: disable=g-import-not-at-top

  env_id = 'BalloonLearningEnvironment-v0'
  env_entry_point = 'balloon_learning_environment.env.balloon_env:BalloonEnv'

  # Older Gym releases expose the registered specs as `registry.env_specs`,
  # while Gym 0.26+ makes `registry` itself the id -> spec dict. Fall back to
  # the registry object so the membership check works on both.
  registry = registration.registry
  env_specs = getattr(registry, 'env_specs', registry)

  # We guard registration by checking if our env is already registered.
  # This is necessary because the plugin system will load our module
  # which also calls this function. If multiple `register()` calls are
  # made this will result in a warning to the user.
  if env_id in env_specs:
    return

  with contextlib.ExitStack() as stack:
    # This is a workaround for Gym 0.21 which didn't support
    # registering into the root namespace with the plugin system.
    if hasattr(registration, 'namespace'):
      stack.enter_context(registration.namespace(None))
    registration.register(id=env_id, entry_point=env_entry_point)
41 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/rendering/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/rendering/renderer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Abstract base class for renderers."""
17 |
18 | import abc
19 | from typing import Iterable, Optional, Text, Union
20 |
21 | from balloon_learning_environment.env import simulator_data
22 | from flax.metrics import tensorboard
23 | import numpy as np
24 |
25 |
class Renderer(abc.ABC):
  """Interface implemented by every object that renders the simulator state."""

  @abc.abstractmethod
  def reset(self) -> None:
    """Resets the renderer; what is cleared is defined by the subclass."""

  @abc.abstractmethod
  def step(self, state: simulator_data.SimulatorState) -> None:
    """Hands one simulator state to the renderer.

    Args:
      state: The simulator state to take into account.
    """

  @abc.abstractmethod
  def render(
      self,
      mode: Text,
      summary_writer: Optional[tensorboard.SummaryWriter] = None,
      iteration: Optional[int] = None,
  ) -> Union[None, np.ndarray, Text]:
    """Renders a frame.

    Args:
      mode: A string specifying the mode. The default gym render modes are
        `human` (render directly to the screen), `rgb_array` (render to a
        numpy array and return it) and `ansi` (render to a string or StringIO
        object). A renderer may support additional modes beyond these.
      summary_writer: If not None, the image is also rendered to the
        tensorboard summary.
      iteration: Iteration number used for writing to tensorboard.

    Returns:
      None, a numpy array of rgb data, or a Text object, depending on the mode.
    """

  @property
  @abc.abstractmethod
  def render_modes(self) -> Iterable[Text]:
    """The render modes accepted by `render`."""
63 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/simulator_data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Model classes for simulator state and simulator observations."""
17 |
18 | import dataclasses
19 |
20 | from balloon_learning_environment.env import wind_field
21 | from balloon_learning_environment.env.balloon import balloon
22 | from balloon_learning_environment.env.balloon import standard_atmosphere
23 |
24 |
@dataclasses.dataclass
class SimulatorState:
  """The complete state of the simulator.

  Because nothing is left out, an instance should be sufficient for
  checkpointing and later restoring the simulator.

  Attributes:
    balloon_state: The state of the balloon.
    wind_field: The wind field used by the simulator.
    atmosphere: The atmosphere used by the simulator.
  """
  balloon_state: balloon.BalloonState
  wind_field: wind_field.WindField
  atmosphere: standard_atmosphere.Atmosphere
35 |
36 |
@dataclasses.dataclass
class SimulatorObservation:
  """A single observation emitted by the simulator.

  Unlike SimulatorState, the contents are noisy observations of the
  environment rather than the ground truth state.

  Attributes:
    balloon_observation: The observed balloon state.
    wind_at_balloon: The observed wind vector at the balloon.
  """
  balloon_observation: balloon.BalloonState
  wind_at_balloon: wind_field.WindVector
47 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/wind_field_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.env.wind_field."""
17 |
18 | import datetime as dt
19 |
20 | from absl.testing import absltest
21 | from balloon_learning_environment.env import wind_field
22 | from balloon_learning_environment.utils import units
23 |
24 |
class WindFieldTest(absltest.TestCase):
  """Checks the forecasts produced by SimpleStaticWindField."""

  def setUp(self):
    super().setUp()
    self.x = units.Distance(km=2.1)
    self.y = units.Distance(km=2.2)
    self.delta = dt.timedelta(minutes=3)

  def _assert_forecast(self, level, u_mps, v_mps):
    """Asserts the forecast at `level` equals the wind vector (u_mps, v_mps).

    Args:
      level: Third positional argument forwarded to `get_forecast`.
      u_mps: First component of the expected wind vector, in m/s.
      v_mps: Second component of the expected wind vector, in m/s.
    """
    expected = wind_field.WindVector(
        units.Velocity(mps=u_mps), units.Velocity(mps=v_mps))
    field = wind_field.SimpleStaticWindField()
    actual = field.get_forecast(self.x, self.y, level, self.delta)
    self.assertEqual(expected, actual)

  def test_some_altitude_goes_north(self):
    self._assert_forecast(9323.0, 0.0, 10.0)

  def test_some_altitude_goes_south(self):
    self._assert_forecast(13999.0, 0.0, -10.0)

  def test_some_altitude_goes_east(self):
    self._assert_forecast(5523.0, 10.0, 0.0)

  def test_some_altitude_goes_west(self):
    self._assert_forecast(11212.0, -10.0, 0.0)

  def test_get_forecast_column_gives_same_result_as_get_forecast(self):
    field = wind_field.SimpleStaticWindField()
    levels = [10_000.0, 11_000.0]
    # Query each level on its own first, then all of them as one column.
    single_forecasts = [
        field.get_forecast(self.x, self.y, level, self.delta)
        for level in levels
    ]
    column = field.get_forecast_column(self.x, self.y, levels, self.delta)

    self.assertEqual(single_forecasts[0], column[0])
    self.assertEqual(single_forecasts[1], column[1])


if __name__ == '__main__':
  absltest.main()
74 |
--------------------------------------------------------------------------------
/balloon_learning_environment/env/wind_gp_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.env.wind_gp."""
17 |
18 | import datetime as dt
19 |
20 | from absl.testing import absltest
21 | from balloon_learning_environment.env import wind_field
22 | from balloon_learning_environment.env import wind_gp
23 | from balloon_learning_environment.utils import units
24 |
25 |
class WindGpTest(absltest.TestCase):
  """Checks WindGP queries before and after a single observation."""

  def setUp(self):
    super().setUp()
    # TODO(bellemare): If multiple wind fields are available, use the
    # vanilla (4 sheets) one.
    # The model is set up with a dummy forecast.
    self.model = wind_gp.WindGP(wind_field.SimpleStaticWindField())
    self.x = units.Distance(m=0.0)
    self.y = units.Distance(m=0.0)
    self.pressure = 0.0
    self.delta = dt.timedelta(seconds=0)
    self.wind_vector = wind_field.WindVector(
        units.Velocity(mps=1.0), units.Velocity(mps=1.0))

  def test_measurement_has_almost_no_variance(self):
    point = (self.x, self.y, self.pressure, self.delta)
    self.model.observe(*point, self.wind_vector)
    post_measurement = self.model.query(*point)

    # Variance is measurement noise at a measured point.
    # This constant was computed analytically. It is given by (see wind_gp.py):
    #   SIGMA_NOISE_SQUARED / (SIGMA_NOISE_SQUARED + SIGMA_EXP_SQUARED) .
    variance = post_measurement[1].item()
    self.assertAlmostEqual(variance, 0.003843, places=3)

  def test_observations_affect_forecast_continuously(self):
    point = (self.x, self.y, self.pressure, self.delta)
    pre_measurement = self.model.query(*point)
    self.model.observe(*point, self.wind_vector)
    # Query at a different x from where the observation was made.
    post_measurement = self.model.query(
        units.Distance(km=0.05), self.y, self.pressure, self.delta)

    changed = pre_measurement[0] != post_measurement[0]
    self.assertTrue(changed.all())


if __name__ == '__main__':
  absltest.main()
67 |
--------------------------------------------------------------------------------
/balloon_learning_environment/eval/README.md:
--------------------------------------------------------------------------------
1 | # Balloon Learning Environment Evaluation
2 | This directory includes scripts for evaluating trained agents in the
3 | Balloon Learning Environment.
4 |
5 | ## 1. Run Eval
6 | The following example code will run eval with the random agent on one seed.
7 | For more configurations, see the flags at [eval.py](https://github.com/google/balloon-learning-environment/blob/master/balloon_learning_environment/eval/eval.py).
8 |
9 | ```
10 | python -m balloon_learning_environment.eval.eval \
11 | --output_dir=/tmp/ble/eval \
12 | --agent=random \
13 | --suite=micro_eval
14 | ```
15 | An [evaluation suite](https://github.com/google/balloon-learning-environment/blob/master/balloon_learning_environment/eval/suites.py)
16 | can be split into shards (if you want to parallelize
17 | the work) using the `--num_shards` and `--shard_idx` flags.
18 |
19 |
20 | ## 2. Combine Json Files From Shards
21 | If you didn't use shards in step 1, skip to step 3. If you used shards, each
22 | of them produced a separate JSON file. These files need to be combined with
23 | `combine_eval_shards.py`. For example, if you ran eval with both the
24 | station_seeker agent and the random agent:
25 |
26 | ```
27 | python -m balloon_learning_environment.eval.combine_eval_shards \
28 | --path=/tmp/ble/eval \
29 | --models=station_seeker --models=random
30 | ```
31 |
32 |
33 |
34 | ## 3. Visualize Your Results With Colab
35 | Open `balloon_learning_environment/colab/summarize_eval.ipynb` and upload your
36 | JSON file.
37 |
--------------------------------------------------------------------------------
/balloon_learning_environment/eval/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/eval/combine_eval_shards.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Combines sharded eval results into a single json file.
17 |
18 | """
19 |
20 | import glob
21 | import json
22 | import os
23 | from typing import Sequence
24 |
25 | from absl import app
26 | from absl import flags
27 |
28 |
# TODO(joshgreaves): Rename models to agents (including README).
# Command-line flags; `path` and `models` are mandatory.
flags.DEFINE_string(
    name='path',
    default=None,
    help='The path containing the shard results.')
flags.DEFINE_multi_string(
    name='models',
    default=None,
    help='The names of the methods in the directory.')
flags.DEFINE_boolean(
    name='pretty_json',
    default=False,
    help='If true, it will write json files with an indent of 2.')
flags.DEFINE_boolean(
    name='flight_paths',
    default=False,
    help='If True, will include flight paths.')
flags.mark_flags_as_required(['path', 'models'])
FLAGS = flags.FLAGS
39 |
40 |
def main(argv: Sequence[str]) -> None:
  """Merges the shard files of every requested model into one json file.

  For each name in --models, every file matching `<name>_*.json` under --path
  is loaded, the loaded lists are concatenated and sorted by their `seed` key,
  and the result is written to `<path>/<name>.json`. Unless --flight_paths is
  set, each entry's `flight_path` is replaced by an empty list.

  Args:
    argv: Command-line arguments; unused.
  """
  del argv  # Unused.

  json_indent = 2 if FLAGS.pretty_json else None

  for model in FLAGS.models:
    shard_pattern = os.path.join(FLAGS.path, f'{model}_*.json')
    records = []
    for shard_file in glob.glob(shard_pattern):
      with open(shard_file, 'r') as f:
        records.extend(json.load(f))

    records.sort(key=lambda record: record['seed'])

    if not FLAGS.flight_paths:
      for record in records:
        record['flight_path'] = []

    combined_file = os.path.join(FLAGS.path, f'{model}.json')
    with open(combined_file, 'w') as f:
      json.dump(records, f, indent=json_indent)


if __name__ == '__main__':
  app.run(main)
64 |
--------------------------------------------------------------------------------
/balloon_learning_environment/eval/suites.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A collection of evaluation suites."""
17 |
18 | import dataclasses
19 | from typing import List, Sequence
20 | from balloon_learning_environment.eval import strata_seeds
21 |
22 |
@dataclasses.dataclass
class EvaluationSuite:
  """Specification of one evaluation suite.

  Attributes:
    seeds: The seeds on which the agent is evaluated.
    max_episode_length: Upper bound on the number of steps the agent is
      evaluated for on a single seed. Must be greater than 0.
  """
  seeds: Sequence[int]
  max_episode_length: int
34 |
35 |
# Registry of named suites. The fixed-size suites evaluate on the seeds
# 0..N-1; every suite uses a maximum episode length of 960 steps.
_eval_suites = {
    'big_eval': EvaluationSuite(list(range(10_000)), 960),
    'medium_eval': EvaluationSuite(list(range(1_000)), 960),
    'small_eval': EvaluationSuite(list(range(100)), 960),
    'tiny_eval': EvaluationSuite(list(range(10)), 960),
    'micro_eval': EvaluationSuite([0], 960),
}

# One suite per stratum, plus `all_strata`, which chains the seeds of every
# stratum in the order listed here.
all_strata = []
for strata in ['hardest', 'hard', 'mid', 'easy', 'easiest']:
  _eval_suites[f'{strata}_strata'] = EvaluationSuite(
      strata_seeds.STRATA_SEEDS[strata], 960)
  all_strata.extend(strata_seeds.STRATA_SEEDS[strata])
_eval_suites['all_strata'] = EvaluationSuite(all_strata, 960)
50 |
51 |
def available_suites() -> List[str]:
  """Returns the names of all registered evaluation suites."""
  return list(_eval_suites)
54 |
55 |
def get_eval_suite(name: str) -> EvaluationSuite:
  """Gets a named evaluation suite.

  Args:
    name: The name the suite was registered under.

  Returns:
    A new EvaluationSuite whose seeds are a fresh list, so the registered
    suite is not exposed as a mutable object.

  Raises:
    ValueError: If no suite is registered under `name`.
  """
  registered = _eval_suites.get(name)
  if registered is None:
    raise ValueError(f'Unknown eval suite {name}')

  return EvaluationSuite(
      seeds=list(registered.seeds),
      max_episode_length=registered.max_episode_length)
64 |
--------------------------------------------------------------------------------
/balloon_learning_environment/eval/suites_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for suites."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.eval import suites
20 |
21 |
class SuitesTest(absltest.TestCase):
  """Tests looking up evaluation suites by name."""

  def test_get_eval_suite_is_successful_for_valid_name(self):
    big_suite = suites.get_eval_suite('big_eval')

    self.assertLen(big_suite.seeds, 10_000)

  def test_get_eval_suite_raises_error_for_invalid_name(self):
    self.assertRaises(ValueError, suites.get_eval_suite, 'invalid name')


if __name__ == '__main__':
  absltest.main()
36 |
--------------------------------------------------------------------------------
/balloon_learning_environment/generated/multi_balloon.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/balloon_learning_environment/generated/multi_balloon.mp4
--------------------------------------------------------------------------------
/balloon_learning_environment/generated/wind_field.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/balloon_learning_environment/generated/wind_field.mp4
--------------------------------------------------------------------------------
/balloon_learning_environment/generative/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/generative/dataset_wind_field_reservoir.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """A wind field reservoir that loads a dataset from a file."""
17 |
18 | import pickle
19 | from typing import Union
20 |
21 | from absl import logging
22 | from balloon_learning_environment.generative import wind_field_reservoir
23 | import jax
24 | import jax.numpy as jnp
25 | import tensorflow as tf
26 |
27 |
class DatasetWindFieldReservoir(wind_field_reservoir.WindFieldReservoir):
  """Retrieves wind fields from an in-memory datastore.

  The last `eval_batch_size` entries along the first axis of the dataset are
  returned by `get_eval_batch`; `get_batch` samples from the entries before
  them.
  """

  def __init__(self,
               data: Union[str, jnp.ndarray],
               eval_batch_size: int = 10,
               rng_seed: int = 0,
               num_shards: int = 200):
    """Creates the reservoir.

    Args:
      data: Either the dataset itself, or a directory containing the pickled
        shards `batch0000.pickle`, `batch0001.pickle`, ... which are loaded
        and concatenated along the first axis.
      eval_batch_size: Number of trailing entries reserved for evaluation.
      rng_seed: Seed of the PRNG key used to sample training batches.
      num_shards: Number of shard files to load when `data` is a directory.
        Ignored when `data` is already an array.
    """
    self.eval_batch_size = eval_batch_size

    if isinstance(data, str):
      # TODO(scandido): We need to update this to load a single file, with no
      # assumed directory/file structure hardcoded.
      data = self._load_shards(data, num_shards)

    self.dataset = data
    self._rng = jax.random.PRNGKey(rng_seed)

  @staticmethod
  def _load_shards(directory: str, num_shards: int) -> jnp.ndarray:
    """Loads `num_shards` pickled shards and concatenates them on axis 0."""
    shards = []
    for i in range(num_shards):
      shard_path = f'{directory}/batch{i:04d}.pickle'
      with tf.io.gfile.GFile(shard_path, 'rb') as f:
        # NOTE(security): unpickling can execute arbitrary code; only point
        # this at shard files from a trusted source.
        shards.append(pickle.load(f))
      logging.info('Loaded shard %d', i)
    return jnp.concatenate(shards, axis=0)

  def get_batch(self, batch_size: int) -> jnp.ndarray:
    """Returns fields used for training.

    Indices are drawn without replacement from the entries that precede the
    eval entries, so `batch_size` must not exceed their number.

    Args:
      batch_size: The number of fields.

    Returns:
      A jax.numpy array that is batch_size x wind field dimensions (see vae.py).
    """
    # Advance the stored key so that successive calls draw different samples.
    self._rng, key = jax.random.split(self._rng)
    num_train_fields = self.dataset.shape[0] - self.eval_batch_size
    samples = jax.random.choice(
        key,
        num_train_fields,
        shape=(batch_size,),
        replace=False)
    return self.dataset[samples, ...]

  def get_eval_batch(self) -> jnp.ndarray:
    """Returns fields used for eval.

    Returns:
      A jax.numpy array that is eval_batch_size x wind field dimensions (see
      vae.py).
    """
    # Slice from an explicit start index: `dataset[-0:]` would select the
    # whole dataset instead of nothing when eval_batch_size is 0.
    eval_start = max(self.dataset.shape[0] - self.eval_batch_size, 0)
    return self.dataset[eval_start:, ...]
82 |
--------------------------------------------------------------------------------
/balloon_learning_environment/generative/dataset_wind_field_reservoir_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for dataset_wind_field_reservoir."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.generative import dataset_wind_field_reservoir
20 | import jax.numpy as jnp
21 |
22 |
def _make_dataset() -> jnp.ndarray:
  """Builds a dataset of 33 fields filled with 0.5, then 10 filled with 0.2."""
  field_shape = (2, 3, 4)
  train_fields = 0.5 * jnp.ones(shape=(33,) + field_shape)
  eval_fields = 0.2 * jnp.ones(shape=(10,) + field_shape)
  return jnp.concatenate([train_fields, eval_fields], axis=0)
27 |
28 |
class DatasetWindFieldReservoirTest(absltest.TestCase):
  """Checks that train and eval batches come from the right dataset part."""

  def test_reservoir_returns_eval_for_eval(self):
    reservoir = dataset_wind_field_reservoir.DatasetWindFieldReservoir(
        data=_make_dataset(), eval_batch_size=10)
    expected_eval = 0.2 * jnp.ones(shape=(10, 2, 3, 4))

    self.assertTrue(jnp.allclose(reservoir.get_eval_batch(), expected_eval))

  def test_reservoir_returns_train_for_train(self):
    reservoir = dataset_wind_field_reservoir.DatasetWindFieldReservoir(
        data=_make_dataset(), eval_batch_size=10)
    expected_train = 0.5 * jnp.ones(shape=(33, 2, 3, 4))

    # Requesting all 33 non-eval entries must yield only train values.
    train_batch = reservoir.get_batch(batch_size=33)
    self.assertTrue(jnp.allclose(train_batch, expected_train))


if __name__ == '__main__':
  absltest.main()
46 |
--------------------------------------------------------------------------------
/balloon_learning_environment/generative/vae_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for vae."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.generative import vae
20 | import jax
21 |
22 |
class VaeTest(absltest.TestCase):
  """Shape checks for the VAE encoder, decoder and combined model."""

  def setUp(self):
    super().setUp()
    self.key = jax.random.PRNGKey(0)
    self.input_shape = (8,)
    self.sample_input = jax.random.normal(self.key, self.input_shape)

  def test_encoder_computes_valid_mean_and_logvar(self):
    latent_dim = 10
    encoder = vae.Encoder(num_latents=latent_dim)
    variables = encoder.init(self.key, self.sample_input)

    mean, logvar = encoder.apply(variables, self.sample_input)

    # Since mean and logvar are valid in (-inf, inf), just check the shape.
    expected_shape = (latent_dim,)
    self.assertEqual(mean.shape, expected_shape)
    self.assertEqual(logvar.shape, expected_shape)

  def test_decoder_computes_valid_wind_field(self):
    latent_dim = 10
    latents = jax.random.normal(self.key, (latent_dim,))
    decoder = vae.Decoder()
    variables = decoder.init(self.key, latents)

    reconstruction = decoder.apply(variables, latents)

    # Since reconstruction is valid in (-inf, inf), just check the shape.
    self.assertEqual(reconstruction.shape, vae.FieldShape().grid_shape())

  def test_vae_computes_valid_wind_field(self):
    latent_dim = 10
    model = vae.WindFieldVAE(num_latents=latent_dim)
    variables = model.init(self.key, self.sample_input, self.key)

    output = model.apply(variables, self.sample_input, self.key)

    self.assertEqual(output.reconstruction.shape,
                     vae.FieldShape().grid_shape())
    self.assertEqual(output.encoder_output.mean.shape, (latent_dim,))
    self.assertEqual(output.encoder_output.logvar.shape, (latent_dim,))


if __name__ == '__main__':
  absltest.main()
70 |
--------------------------------------------------------------------------------
/balloon_learning_environment/generative/wind_field_reservoir.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """An interface to datasets of wind fields used for training a VAE model."""
17 |
18 | import abc
19 | import jax.numpy as jnp
20 |
21 |
class WindFieldReservoir(abc.ABC):
  """Abstract source of wind field batches for training and evaluation."""

  @abc.abstractmethod
  def get_batch(self, batch_size: int) -> jnp.ndarray:
    """Returns fields used for training.

    Args:
      batch_size: The number of fields.

    Returns:
      A jax.numpy array that is batch_size x wind field dimensions (see vae.py).
    """

  @abc.abstractmethod
  def get_eval_batch(self) -> jnp.ndarray:
    """Returns fields used for eval.

    Returns:
      A jax.numpy array whose first axis indexes the eval fields and whose
      remaining axes are the wind field dimensions (see vae.py). The number of
      eval fields is chosen by the implementation.
    """
43 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/collector.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Base class for metric collectors.
17 |
18 | Each Collector should subclass this base class, as the CollectorDispatcher
19 | object expects objects of type Collector.
20 |
21 | The methods to implement are:
22 | - `get_name`: a unique identifier for subdirectory creation.
23 | - `pre_training`: called once before training begins.
24 | - `begin_episode`: takes no arguments.
25 | - `step`: called once for each training step. The parameter is an object of
26 | type `StatisticsInstance` which contains the statistics of the current
27 | training step. `end_episode` also receives a `StatisticsInstance`.
28 | - `end_training`: called once at the end of training; takes no arguments.
29 | """
30 |
31 | import abc
32 | import os.path as osp
33 | from typing import Optional
34 |
35 | from balloon_learning_environment.metrics import statistics_instance
36 | import tensorflow as tf
37 |
38 |
class Collector(abc.ABC):
  """Interface that every metric collector has to implement.

  When a base directory is supplied, the constructor derives the collector's
  own output directory, `<base_dir>/metrics/<get_name()>`, and tries to create
  it.
  """

  def __init__(self,
               base_dir: Optional[str],
               num_actions: int,
               current_episode: int):
    """Stores the state shared by all collectors.

    Args:
      base_dir: Root output directory, or None if no directory is needed.
      num_actions: Kept as `self._num_actions` for use by subclasses.
      current_episode: Kept as `self.current_episode` for use by subclasses.
    """
    if base_dir is None:
      self._base_dir = None
    else:
      self._base_dir = osp.join(base_dir, 'metrics', self.get_name())
      try:
        tf.io.gfile.makedirs(self._base_dir)
      except tf.errors.PermissionDeniedError:
        # Directory creation is best effort: a permission failure is ignored.
        pass
    self._num_actions = num_actions
    self.current_episode = current_episode
    # Subclasses that own a summary writer are expected to overwrite this.
    self.summary_writer = None

  @abc.abstractmethod
  def get_name(self) -> str:
    """Returns the name used for this collector's output subdirectory."""

  @abc.abstractmethod
  def pre_training(self) -> None:
    """Hook invoked before training."""

  @abc.abstractmethod
  def begin_episode(self) -> None:
    """Hook invoked at the beginning of an episode."""

  @abc.abstractmethod
  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Hook invoked with the statistics of a single step."""

  @abc.abstractmethod
  def end_episode(
      self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Hook invoked with the statistics passed at the end of an episode."""

  @abc.abstractmethod
  def end_training(self) -> None:
    """Hook invoked after training."""

  def has_summary_writer(self) -> bool:
    """Returns True iff `summary_writer` has been set to a non-None value."""
    writer = self.summary_writer
    return writer is not None
87 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/collector_dispatcher.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Class that runs a list of Collectors for metrics reporting.
17 |
18 | This class is what should be called from the main binary and will call each of
19 | the specified collectors for metrics reporting.
20 |
21 | Each metric collector can be further configured via gin bindings. The
22 | constructor for each desired collector should be passed in as a list when
23 | creating this object. All of the collectors are expected to be subclasses of the
24 | `Collector` base class (defined in `collector.py`).
25 |
26 | Example configuration:
27 | ```
28 | metrics = CollectorDispatcher(
29 | base_dir, num_actions, list_of_constructors, current_episode)
30 | metrics.pre_training()
31 | for i in range(training_steps):
32 | ...
33 | metrics.step(statistics)
34 | metrics.end_training()
34 | ```
35 |
36 | The statistics parameter is of type `statistics_instance.StatisticsInstance`,
37 | and contains the raw performance statistics for the current iteration. All
38 | processing (such as averaging) will be handled by each of the individual
39 | collectors.
40 | """
41 |
42 | from typing import Callable, Optional, Sequence
43 |
44 | from balloon_learning_environment.metrics import collector
45 | from balloon_learning_environment.metrics import console_collector
46 | from balloon_learning_environment.metrics import pickle_collector
47 | from balloon_learning_environment.metrics import statistics_instance
48 | from balloon_learning_environment.metrics import tensorboard_collector
49 | from flax.metrics import tensorboard
50 |
# NOTE(review): presumably the directory holding the collectors' gin configs;
# it is not referenced anywhere else in this module.
BASE_CONFIG_PATH = 'balloon_learning_environment/metrics/configs'
# Maps a short collector name to the Collector subclass implementing it.
AVAILABLE_COLLECTORS = {
    'console': console_collector.ConsoleCollector,
    'pickle': pickle_collector.PickleCollector,
    'tensorboard': tensorboard_collector.TensorboardCollector,
}


# Signature of a collector constructor as invoked by CollectorDispatcher:
# (base_dir, num_actions, current_episode) -> Collector.
CollectorConstructorType = Callable[[str, int, int], collector.Collector]
60 |
61 |
class CollectorDispatcher(object):
  """Fans every metrics call out to a list of Collector instances."""

  def __init__(self, base_dir: Optional[str], num_actions: int,
               collectors: Sequence[CollectorConstructorType],
               current_episode: int):
    """Instantiates each collector, in the order given.

    Args:
      base_dir: Forwarded to every collector constructor.
      num_actions: Forwarded to every collector constructor.
      collectors: Constructors of the collectors to create.
      current_episode: Forwarded to every collector constructor.
    """
    self._collectors = []
    for make_collector in collectors:
      self._collectors.append(
          make_collector(base_dir, num_actions, current_episode))

  def pre_training(self) -> None:
    """Calls `pre_training` on every collector."""
    for instance in self._collectors:
      instance.pre_training()

  def begin_episode(self) -> None:
    """Calls `begin_episode` on every collector."""
    for instance in self._collectors:
      instance.begin_episode()

  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Passes `statistics` to the `step` method of every collector."""
    for instance in self._collectors:
      instance.step(statistics)

  def end_episode(self,
                  statistics: statistics_instance.StatisticsInstance) -> None:
    """Passes `statistics` to the `end_episode` method of every collector."""
    for instance in self._collectors:
      instance.end_episode(statistics)

  def end_training(self) -> None:
    """Calls `end_training` on every collector."""
    for instance in self._collectors:
      instance.end_training()

  def get_summary_writer(self) -> Optional[tensorboard.SummaryWriter]:
    """Returns the first found instance of a summary_writer, or None."""
    writers = (instance.summary_writer
               for instance in self._collectors
               if instance.has_summary_writer())
    return next(writers, None)
100 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/collector_dispatcher_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.metrics.collector_dispatcher."""
17 |
18 | from absl import flags
19 | from absl.testing import absltest
20 | from balloon_learning_environment.metrics import collector
21 | from balloon_learning_environment.metrics import collector_dispatcher
22 | from balloon_learning_environment.metrics import statistics_instance
23 |
24 |
class CollectorDispatcherTest(absltest.TestCase):
  """Runs CollectorDispatcher with no collectors and with two collectors."""

  def setUp(self):
    super().setUp()
    self._tmpdir = flags.FLAGS.test_tmpdir
    self._na = 5

  def test_with_no_collectors(self):
    # The complete call sequence must work even with an empty collector list.
    dispatcher = collector_dispatcher.CollectorDispatcher(
        self._tmpdir, self._na, [], 0)
    dispatcher.pre_training()
    for _ in range(4):
      dispatcher.begin_episode()
      for _ in range(10):
        dispatcher.step(
            statistics_instance.StatisticsInstance(0, 0, 0, False))
      dispatcher.end_episode(
          statistics_instance.StatisticsInstance(0, 0, 0, False))
    dispatcher.end_training()

  def test_with_simple_collector(self):
    # Collector that records every statistics object it receives, grouped
    # into one list per episode.
    logged_stats = []

    class SimpleCollector(collector.Collector):

      def get_name(self) -> str:
        return 'simple'

      def pre_training(self) -> None:
        pass

      def begin_episode(self) -> None:
        logged_stats.append([])

      def step(self, statistics) -> None:
        logged_stats[-1].append(statistics)

      def end_episode(self, statistics) -> None:
        logged_stats[-1].append(statistics)

      def end_training(self) -> None:
        pass

    # Collector that counts how many times each hook is invoked.
    counts = dict.fromkeys(
        ['pre_training', 'begin_episode', 'step', 'end_episode',
         'end_training'], 0)

    class CountCollector(collector.Collector):

      def get_name(self) -> str:
        return 'count'

      def pre_training(self) -> None:
        counts['pre_training'] += 1

      def begin_episode(self) -> None:
        counts['begin_episode'] += 1

      def step(self, statistics) -> None:
        counts['step'] += 1

      def end_episode(self, unused_statistics) -> None:
        counts['end_episode'] += 1

      def end_training(self) -> None:
        counts['end_training'] += 1

    # Drive both collectors through a full collection loop.
    num_episodes = 4
    num_steps = 10
    dispatcher = collector_dispatcher.CollectorDispatcher(
        self._tmpdir, self._na, [SimpleCollector, CountCollector], 0)
    dispatcher.pre_training()
    expected_stats = []
    for _ in range(num_episodes):
      dispatcher.begin_episode()
      episode_stats = []
      for j in range(num_steps):
        step_stat = statistics_instance.StatisticsInstance(
            step=j, action=num_steps-j, reward=j, terminal=False)
        dispatcher.step(step_stat)
        episode_stats.append(step_stat)
      final_stat = statistics_instance.StatisticsInstance(
          step=num_steps, action=0, reward=num_steps, terminal=True)
      dispatcher.end_episode(final_stat)
      episode_stats.append(final_stat)
      expected_stats.append(episode_stats)
    dispatcher.end_training()

    expected_counts = {
        'pre_training': 1,
        'begin_episode': num_episodes,
        'step': num_episodes * num_steps,
        'end_episode': num_episodes,
        'end_training': 1,
    }
    self.assertEqual(counts, expected_counts)
    self.assertEqual(expected_stats, logged_stats)
124 |
125 |
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  absltest.main()
128 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/collector_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.metrics.collector."""
17 |
18 | import os.path as osp
19 |
20 | from absl import flags
21 | from absl.testing import absltest
22 | from balloon_learning_environment.metrics import collector
23 |
24 |
class SimpleCollector(collector.Collector):
  """Minimal concrete Collector: every hook is a no-op."""

  def get_name(self) -> str:
    return 'simple'

  def pre_training(self) -> None:
    return None

  def begin_episode(self) -> None:
    return None

  def step(self, unused_statistics) -> None:
    return None

  def end_episode(self, unused_statistics) -> None:
    return None

  def end_training(self) -> None:
    return None
45 |
46 |
class CollectorTest(absltest.TestCase):
  """Checks construction of the Collector base class and a simple subclass."""

  def setUp(self):
    super().setUp()
    self._tmpdir = flags.FLAGS.test_tmpdir
    self._na = 5

  def test_instantiate_abstract_class(self):
    # Collector has abstract methods, so instantiating it directly must fail.
    with self.assertRaises(TypeError):
      collector.Collector(self._tmpdir, self._na, 'fail')

  def test_valid_subclass(self):
    instance = SimpleCollector(self._tmpdir, self._na, 0)
    expected_dir = osp.join(self._tmpdir, 'metrics/simple')
    self.assertEqual(instance._base_dir, expected_dir)
    self.assertEqual(self._na, instance._num_actions)
    # The constructor is expected to have created the directory.
    self.assertTrue(osp.exists(instance._base_dir))

  def test_valid_subclass_with_no_basedir(self):
    instance = SimpleCollector(None, self._na, 0)
    self.assertIsNone(instance._base_dir)
    self.assertEqual(self._na, instance._num_actions)
70 |
71 |
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  absltest.main()
74 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/console_collector.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Collector class for reporting statistics to the console."""
17 |
18 | import os.path as osp
19 | from typing import Union
20 |
21 | from absl import logging
22 | from balloon_learning_environment.metrics import collector
23 | from balloon_learning_environment.metrics import statistics_instance
24 | import gin
25 | import numpy as np
26 | import tensorflow as tf
27 |
28 |
@gin.configurable(allowlist=['fine_grained_logging',
                             'fine_grained_frequency',
                             'save_to_file'])
class ConsoleCollector(collector.Collector):
  """Collector class for reporting statistics to the console.

  Per-episode summaries (and, optionally, per-step lines) are emitted through
  `absl.logging`. When a base directory is available and `save_to_file` is
  true, the same text is also written to `console.log` in the collector's
  directory.
  """

  def __init__(self,
               base_dir: Union[str, None],
               num_actions: int,
               current_episode: int,
               fine_grained_logging: bool = False,
               fine_grained_frequency: int = 1,
               save_to_file: bool = True):
    """Initializes the collector.

    Args:
      base_dir: Root output directory, or None to log to the console only.
      num_actions: Number of valid actions; sizes the action histogram.
      current_episode: Index reported for the first episode.
      fine_grained_logging: Whether to also log individual steps.
      fine_grained_frequency: A step is logged when `statistics.step` is a
        multiple of this value.
      save_to_file: Whether to mirror the output to `console.log`.
    """
    super().__init__(base_dir, num_actions, current_episode)
    if self._base_dir is not None and save_to_file:
      self._log_file = osp.join(self._base_dir, 'console.log')
    else:
      self._log_file = None
    # The writer is only opened in `pre_training`; None means "no file output".
    self._log_file_writer = None
    self._fine_grained_logging = fine_grained_logging
    self._fine_grained_frequency = fine_grained_frequency

  def get_name(self) -> str:
    return 'console'

  def pre_training(self) -> None:
    """Opens the log file for writing, if file output is enabled."""
    if self._log_file is not None:
      self._log_file_writer = tf.io.gfile.GFile(self._log_file, 'w')

  def begin_episode(self) -> None:
    """Resets the per-episode action histogram and reward total."""
    self._action_counts = np.zeros(self._num_actions)
    self._current_episode_reward = 0.0

  def _accumulate(
      self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Adds one statistics record to the episode reward and action counts.

    Raises:
      ValueError: if the action is outside `[0, num_actions)`. Without this
        check a negative action would silently index from the end of the
        histogram.
    """
    self._current_episode_reward += statistics.reward
    if statistics.action < 0 or statistics.action >= self._num_actions:
      raise ValueError(f'Invalid action: {statistics.action}')
    self._action_counts[statistics.action] += 1

  def _write_to_log_file(self, text: str) -> None:
    """Writes `text` to the log file if one has been opened."""
    if self._log_file_writer is not None:
      self._log_file_writer.write(text)

  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Records a step and optionally logs it.

    Raises:
      ValueError: if `statistics.action` is not a valid action index.
    """
    self._accumulate(statistics)
    if (self._fine_grained_logging
        and statistics.step % self._fine_grained_frequency == 0):
      step_string = (
          f'Step: {statistics.step}, action: {statistics.action}, '
          f'reward: {statistics.reward}, terminal: {statistics.terminal}\n')
      logging.info(step_string)
      self._write_to_log_file(step_string)

  def end_episode(self,
                  statistics: statistics_instance.StatisticsInstance) -> None:
    """Records the final statistics and reports the episode summary.

    Raises:
      ValueError: if `statistics.action` is not a valid action index.
    """
    self._accumulate(statistics)
    action_distribution = self._action_counts / np.sum(self._action_counts)

    episode_string = (
        f'Episode {self.current_episode}: '
        f'reward: {self._current_episode_reward:07.2f}, '  # format: 0000.00
        f'episode length: {statistics.step}, '
        f'action distribution: {action_distribution}')
    logging.info(episode_string)

    # Terminate the record so consecutive entries in the file do not run
    # together on one line.
    self._write_to_log_file(episode_string + '\n')

    self.current_episode += 1

  def end_training(self) -> None:
    """Closes the log file, if one was opened."""
    if self._log_file_writer is not None:
      self._log_file_writer.close()
      self._log_file_writer = None
96 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/pickle_collector.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Collector class for saving episode statistics to a pickle file."""
17 |
18 | import os.path as osp
19 | import pickle
20 |
21 | from balloon_learning_environment.metrics import collector
22 | from balloon_learning_environment.metrics import statistics_instance
23 | import tensorflow as tf
24 |
25 |
class PickleCollector(collector.Collector):
  """Collector class for saving episode statistics to pickle files.

  All statistics received during an episode are buffered in memory and dumped
  to `pickle_<episode>.pkl` in the collector's directory when the episode ends.
  """

  def __init__(self,
               base_dir: str,
               num_actions: int,
               current_episode: int):
    """Initializes the collector.

    Args:
      base_dir: Root output directory; required.
      num_actions: Forwarded to the base class.
      current_episode: Index used in the file name of the first episode.

    Raises:
      ValueError: if `base_dir` is None.
    """
    if base_dir is None:
      raise ValueError('Must specify a base directory for PickleCollector.')
    super().__init__(base_dir, num_actions, current_episode)

  def get_name(self) -> str:
    return 'pickle'

  def pre_training(self) -> None:
    pass

  def begin_episode(self) -> None:
    """Starts a fresh statistics buffer for the new episode."""
    self._statistics = []

  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Buffers the statistics of one step."""
    self._statistics.append(statistics)

  def end_episode(self,
                  statistics: statistics_instance.StatisticsInstance) -> None:
    """Buffers the final statistics and writes the episode to disk."""
    self._statistics.append(statistics)
    pickle_file = osp.join(self._base_dir,
                           f'pickle_{self.current_episode}.pkl')
    # Pickle produces a byte stream, so open the file in binary mode.
    with tf.io.gfile.GFile(pickle_file, 'wb') as f:
      pickle.dump(self._statistics, f, protocol=pickle.HIGHEST_PROTOCOL)
    self.current_episode += 1

  def end_training(self) -> None:
    pass
60 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/statistics_instance.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Class containing statistics from training steps."""
17 |
18 | import dataclasses
19 |
20 |
@dataclasses.dataclass
class StatisticsInstance:
  """Performance statistics to be passed to each of the collectors.

  Attributes:
    step: Step counter for this record; the collectors in this package report
      the value received at the end of an episode as the episode length.
    action: Integer action; the console collector uses it as an index into its
      per-action count array.
    reward: Reward for this record; collectors add it to the episode total.
    terminal: Terminal flag for this record.
  """
  step: int
  action: int
  reward: float
  terminal: bool
28 |
--------------------------------------------------------------------------------
/balloon_learning_environment/metrics/tensorboard_collector.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Collector class for exporting statistics to Tensorboard."""
17 |
18 | from balloon_learning_environment.metrics import collector
19 | from balloon_learning_environment.metrics import statistics_instance
20 | from flax.metrics import tensorboard
21 | import gin
22 |
23 |
@gin.configurable(allowlist=['fine_grained_logging',
                             'fine_grained_frequency'])
class TensorboardCollector(collector.Collector):
  """Collector that writes training statistics as Tensorboard scalars."""

  def __init__(self,
               base_dir: str,
               num_actions: int,
               current_episode: int,
               fine_grained_logging: bool = False,
               fine_grained_frequency: int = 1):
    """Creates the summary writer in the collector's directory.

    Args:
      base_dir: Root output directory; must be a string.
      num_actions: Forwarded to the base class.
      current_episode: Episode index used as the x-value of episode scalars.
      fine_grained_logging: Whether to also write per-step rewards.
      fine_grained_frequency: Per-step rewards are written when the global
        step counter is a multiple of this value.

    Raises:
      ValueError: if `base_dir` is not a string.
    """
    if not isinstance(base_dir, str):
      raise ValueError(
          'Must specify a base directory for TensorboardCollector.')
    super().__init__(base_dir, num_actions, current_episode)
    self._fine_grained_logging = fine_grained_logging
    self._fine_grained_frequency = fine_grained_frequency
    self.summary_writer = tensorboard.SummaryWriter(self._base_dir)

  def get_name(self) -> str:
    return 'tensorboard'

  def pre_training(self) -> None:
    # TODO(joshgreaves): This is wrong if we are starting from a checkpoint.
    self._global_step = 0

  def begin_episode(self) -> None:
    """Resets the running reward total for the new episode."""
    self._episode_reward = 0.0

  def _log_fine_grained_statistics(
      self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Writes the reward of one record at the current global step."""
    writer = self.summary_writer
    writer.scalar('Train/FineGrainedReward', statistics.reward,
                  self._global_step)  # pytype: disable=attribute-error  # trace-all-classes
    writer.flush()

  def step(self, statistics: statistics_instance.StatisticsInstance) -> None:
    """Optionally logs the step, then advances the counters."""
    should_log = (
        self._fine_grained_logging
        and self._global_step % self._fine_grained_frequency == 0)
    if should_log:
      self._log_fine_grained_statistics(statistics)
    self._global_step += 1
    self._episode_reward += statistics.reward

  def end_episode(self,
                  statistics: statistics_instance.StatisticsInstance) -> None:
    """Writes the episode reward and length, then advances the counters."""
    # The final record is logged whenever fine-grained logging is on; the
    # frequency is not consulted here.
    if self._fine_grained_logging:
      self._log_fine_grained_statistics(statistics)
    self._episode_reward += statistics.reward

    writer = self.summary_writer
    episode = self.current_episode
    writer.scalar('Train/EpisodeReward', self._episode_reward, episode)
    writer.scalar('Train/EpisodeLength', statistics.step, episode)
    writer.flush()

    self._global_step += 1  # pytype: disable=attribute-error  # trace-all-classes
    self.current_episode += 1

  def end_training(self) -> None:
    pass
81 |
--------------------------------------------------------------------------------
/balloon_learning_environment/models/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/models/models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Convenience functions for loading models."""
17 |
18 | from importlib import resources
19 | import os
20 | from typing import Optional
21 |
22 | import gin
23 | import tensorflow as tf
24 |
# Directory containing the bundled model files, relative to the project root.
_MODEL_ROOT = 'balloon_learning_environment/models/'
# Project-root-relative locations tried as the last resort by the loaders
# in this module.
_OFFLINE_SKIES22_RELATIVE_PATH = os.path.join(
    _MODEL_ROOT, 'offlineskies22_decoder.msgpack')
_PERCIATELLI44_RELATIVE_PATH = os.path.join(
    _MODEL_ROOT, 'perciatelli44.pb')
30 |
31 |
@gin.configurable
def load_offlineskies22(path: Optional[str] = None) -> bytes:
  """Returns the serialized offlineskies22 wind VAE parameters.

  The lookup order is:
    1. `path`, when given. No other location is tried in that case.
    2. The `balloon_learning_environment.models` package, through
       importlib.resources (the expected location after a pip install).
    3. The path relative to the project root (the expected location when
       running from a freshly cloned repo).

  Args:
    path: An optional path to load the VAE weights from.

  Returns:
    The serialized VAE weights as bytes.

  Raises:
    ValueError: if a path is specified but the weights can't be loaded.
    RuntimeError: if the weights couldn't be found in any of the
      default locations.
  """
  # An explicitly supplied path is expected to be valid, so a miss is an error
  # rather than a reason to fall through to the defaults.
  if path is not None:
    try:
      with tf.io.gfile.GFile(path, 'rb') as checkpoint:
        return checkpoint.read()
    except tf.errors.NotFoundError:
      raise ValueError(f'offlineskies22 checkpoint not found at {path}')

  # Default 1: the resource shipped inside the built wheel.
  package_name = 'balloon_learning_environment.models'
  resource_name = 'offlineskies22_decoder.msgpack'
  try:
    with resources.open_binary(package_name, resource_name) as packaged:
      return packaged.read()
  except FileNotFoundError:
    pass

  # Default 2: the file in the source tree, relative to the source root.
  try:
    with tf.io.gfile.GFile(_OFFLINE_SKIES22_RELATIVE_PATH, 'rb') as local:
      return local.read()
  except tf.errors.NotFoundError:
    pass

  raise RuntimeError(
      'Unable to load wind VAE checkpoint from the expected locations.')
80 |
81 |
@gin.configurable
def load_perciatelli44(path: Optional[str] = None) -> bytes:
  """Loads Perciatelli44.pb as bytes.

  There are three places this function looks:
    1. At the path specified, if one is specified.
    2. Under the models package using importlib.resources. It should be
      found there if the code was installed with pip.
    3. Relative to the project root. It should be found there if running
      from a freshly cloned repo.

  Args:
    path: An optional path to load the Perciatelli44 checkpoint from.

  Returns:
    The contents of the Perciatelli44 checkpoint file as bytes.

  Raises:
    ValueError: if a path is specified but the checkpoint can't be loaded.
    RuntimeError: if the checkpoint couldn't be found in any of the
      specified locations.
  """
  # Attempt 1: Load from path, if specified.
  # If a path is specified, we expect it is a good path.
  if path is not None:
    try:
      with tf.io.gfile.GFile(path, 'rb') as f:
        return f.read()
    except tf.errors.NotFoundError as err:
      raise ValueError(
          f'perciatelli44 checkpoint not found at {path}') from err

  # Attempt 2: Load from location expected in the built wheel.
  try:
    with resources.open_binary('balloon_learning_environment.models',
                               'perciatelli44.pb') as f:
      return f.read()
  except FileNotFoundError:
    pass

  # Attempt 3: Load from the path relative to the source root.
  try:
    with tf.io.gfile.GFile(_PERCIATELLI44_RELATIVE_PATH, 'rb') as f:
      return f.read()
  except tf.errors.NotFoundError:
    # GFile reports a missing file with tf.errors.NotFoundError (as handled in
    # attempt 1), not the builtin FileNotFoundError; catching the latter here
    # would let the error escape instead of raising the RuntimeError below.
    pass

  raise RuntimeError(
      'Unable to load Perciatelli44 checkpoint from the expected locations.')
130 |
--------------------------------------------------------------------------------
/balloon_learning_environment/models/models_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for models."""
17 |
18 | from importlib import resources
19 | import io
20 | from unittest import mock
21 |
22 | from absl.testing import absltest
23 | from absl.testing import parameterized
24 | from balloon_learning_environment.models import models
25 | import tensorflow as tf
26 |
27 |
class ModelsTest(parameterized.TestCase):
  """Tests the lookup order of both model loaders.

  Every test is parameterized over `load_offlineskies22` and
  `load_perciatelli44`, which are expected to behave identically.
  """

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  def test_load_with_specified_path_loads_data(self, load_fn):
    # Write some fake data in a tmpdir.
    tmpfile = self.create_tempfile()
    fake_content = b'fake content from specified path'
    with open(tmpfile, 'wb') as f:
      f.write(fake_content)

    result = load_fn(tmpfile.full_path)

    self.assertEqual(result, fake_content)

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  def test_load_with_wrong_path_fails(self, load_fn):
    with self.assertRaises(ValueError):
      load_fn('this_is_not_a_valid_path')

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  @mock.patch.object(resources, 'open_binary', autospec=True)
  def test_load_uses_importlib_if_no_path_is_specified(self,
                                                       mock_open,
                                                       load_fn):
    fake_content = b'fake content from importlib.resources'
    mock_open.return_value = io.BytesIO(fake_content)

    result = load_fn()

    mock_open.assert_called_once()
    self.assertEqual(result, fake_content)

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  @mock.patch.object(tf.io.gfile, 'GFile', autospec=True)
  def test_load_uses_default_path_as_last_resort(self,
                                                 mock_gfile,
                                                 load_fn):
    fake_content = b'fake content from default path'
    mock_gfile.return_value = io.BytesIO(fake_content)

    result = load_fn()

    mock_gfile.assert_called_once()
    self.assertEqual(result, fake_content)

  @parameterized.named_parameters(
      dict(testcase_name='openskies22', load_fn=models.load_offlineskies22),
      dict(testcase_name='perciatelli44', load_fn=models.load_perciatelli44))
  def test_load_raises_runtime_error_if_not_found(self,
                                                  load_fn):
    with self.assertRaises(RuntimeError):
      load_fn()
88 |
89 | if __name__ == '__main__':
90 | absltest.main()
91 |
--------------------------------------------------------------------------------
/balloon_learning_environment/models/offlineskies22_decoder.msgpack:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/balloon_learning_environment/models/offlineskies22_decoder.msgpack
--------------------------------------------------------------------------------
/balloon_learning_environment/models/perciatelli44.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/balloon_learning_environment/models/perciatelli44.pb
--------------------------------------------------------------------------------
/balloon_learning_environment/train.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Main entry point for the Balloon Learning Environment.
17 |
18 | """
19 |
20 | import os.path as osp
21 |
22 | from absl import app
23 | from absl import flags
24 | from balloon_learning_environment import train_lib
25 | from balloon_learning_environment.env.rendering import matplotlib_renderer
26 | from balloon_learning_environment.utils import run_helpers
27 | import gym
28 | import matplotlib
29 | import numpy as np
30 |
31 |
# Command-line flags configuring a training run.
flags.DEFINE_string('agent', 'dqn', 'Type of agent to create.')
flags.DEFINE_string('env_name', 'BalloonLearningEnvironment-v0',
                    'Name of environment to create.')
# Each iteration runs --episodes_per_iteration episodes, so this counts
# iterations rather than individual episodes.
flags.DEFINE_integer('num_iterations', 200,
                     'Number of iterations to train for.')
flags.DEFINE_integer('max_episode_length', 960,
                     'Maximum number of steps per episode. Assuming 2 days, '
                     'with each step lasting 3 minutes.')
flags.DEFINE_string('base_dir', None,
                    'Directory where to store statistics/images.')
flags.DEFINE_integer(
    'run_number', 1,
    'When running multiple agents in parallel, this number '
    'differentiates between the runs. It is appended to base_dir.')
flags.DEFINE_string(
    'wind_field', 'generative',
    'The wind field type to use. See run_helpers.get_wind_field_factory '
    'for options.')
flags.DEFINE_string('agent_gin_file', None,
                    'Gin file for agent configuration.')
flags.DEFINE_multi_string('collectors', ['console'],
                          'Collectors to include in metrics collection.')
flags.DEFINE_multi_string('gin_bindings', [],
                          'Gin bindings to override default values.')
flags.DEFINE_string(
    'renderer', None,
    'The renderer to use. Note that it is fastest to have this set to None.')
flags.DEFINE_integer(
    'render_period', 10,
    'The period to render with. Only has an effect if renderer is not None.')
flags.DEFINE_integer(
    'episodes_per_iteration', 50,
    'The number of episodes to run in one iteration. Checkpointing occurs '
    'at the end of each iteration.')
flags.mark_flag_as_required('base_dir')
FLAGS = flags.FLAGS


# Maps valid --renderer flag values to renderer constructors.
_RENDERERS = {
    'matplotlib': matplotlib_renderer.MatplotlibRenderer,
}
71 |
72 |
def main(_) -> None:
  """Builds the environment and agent from flags and runs the training loop.

  After training, the environment is rendered once and, if that yields an
  image array, the image is written to `balloon_path.png` under --base_dir.
  """
  # Prepare metric collector gin files and constructors.
  collector_constructors = train_lib.get_collector_data(FLAGS.collectors)
  run_helpers.bind_gin_variables(
      FLAGS.agent, FLAGS.agent_gin_file, FLAGS.gin_bindings)

  # Rendering is optional; no renderer is constructed unless one is requested.
  renderer_name = FLAGS.renderer
  renderer = None if renderer_name is None else _RENDERERS[renderer_name]()

  wind_field_factory = run_helpers.get_wind_field_factory(FLAGS.wind_field)
  env = gym.make(
      FLAGS.env_name,
      wind_field_factory=wind_field_factory,
      renderer=renderer)

  agent = run_helpers.create_agent(
      FLAGS.agent,
      env.action_space.n,
      observation_shape=env.observation_space.shape)

  # Results are grouped per agent and per run under the base directory.
  run_dir = osp.join(FLAGS.base_dir, FLAGS.agent, str(FLAGS.run_number))
  train_lib.run_training_loop(
      run_dir,
      env,
      agent,
      FLAGS.num_iterations,
      FLAGS.max_episode_length,
      collector_constructors,
      render_period=FLAGS.render_period,
      episodes_per_iteration=FLAGS.episodes_per_iteration)

  if FLAGS.base_dir is None:
    return

  image_save_path = osp.join(FLAGS.base_dir, 'balloon_path.png')
  final_frame = env.render(mode='rgb_array')
  # Only save when the render call actually produced an image array.
  if isinstance(final_frame, np.ndarray):
    matplotlib.image.imsave(image_save_path, final_frame)
110 |
111 |
112 | if __name__ == '__main__':
113 | app.run(main)
114 |
--------------------------------------------------------------------------------
/balloon_learning_environment/train_acme_qrdqn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | r"""Entry point for Acme QrDQN training on the BLE."""
17 |
18 |
19 | from absl import app
20 | from absl import flags
21 | import acme
22 | from acme.jax.deprecated import local_layout
23 | from acme.utils import counting
24 | from acme.utils import loggers
25 | from balloon_learning_environment import acme_utils
26 | import jax
27 |
# Command-line flags configuring the Acme QrDQN training run.
flags.DEFINE_integer('num_episodes', 1000, 'Number of episodes to train for.')
flags.DEFINE_integer(
    'max_episode_length', 960,
    'Maximum number of steps per episode. Assuming 2 days, '
    'with each step lasting 3 minutes.')
flags.DEFINE_integer('seed', 0, 'Random seed.')

FLAGS = flags.FLAGS
36 |
37 |
def main(_) -> None:
  """Trains a QrDQN agent, alternating one training and one eval episode.

  The loop at the end runs --num_episodes rounds; each round runs a single
  episode of the training loop followed by a single episode of the eval loop.
  """

  # NOTE(review): the first argument presumably selects evaluation mode
  # (False for training, True for eval) - confirm against acme_utils.create_env.
  env = acme_utils.create_env(False, FLAGS.max_episode_length)
  eval_env = acme_utils.create_env(True, FLAGS.max_episode_length)
  env_spec = acme.make_environment_spec(env)

  # NOTE(review): the empty dict is presumably a set of parameter overrides,
  # i.e. defaults are used - confirm against acme_utils.create_dqn.
  (rl_agent, config, dqn_network_fn, behavior_policy_fn,
   eval_policy_fn) = acme_utils.create_dqn({})
  dqn_network = dqn_network_fn(env_spec)

  # Parent counter passed to the learner, actor and evaluator counters below.
  counter = counting.Counter(time_delta=0.)

  agent = local_layout.LocalLayout(
      seed=FLAGS.seed,
      environment_spec=env_spec,
      builder=rl_agent,
      networks=dqn_network,
      policy_network=behavior_policy_fn(dqn_network),
      batch_size=config.batch_size,
      prefetch_size=4,
      counter=counting.Counter(counter, 'learner'))

  # NOTE(review): the eval actor uses a fixed PRNGKey(0) rather than
  # FLAGS.seed - confirm this is intentional.
  # The training agent is handed to the eval actor as its variable source.
  eval_actor = rl_agent.make_actor(
      jax.random.PRNGKey(0),
      policy=eval_policy_fn(dqn_network),
      environment_spec=env_spec,
      variable_source=agent)

  actor_logger = loggers.make_default_logger('actor')
  evaluator_logger = loggers.make_default_logger('evaluator')

  loop = acme.EnvironmentLoop(
      env,
      agent,
      logger=actor_logger,
      counter=counting.Counter(counter, 'actor', time_delta=0.))
  eval_loop = acme.EnvironmentLoop(
      eval_env,
      eval_actor,
      logger=evaluator_logger,
      counter=counting.Counter(counter, 'evaluator', time_delta=0.))
  # Interleave training and evaluation, one episode each per round.
  for _ in range(FLAGS.num_episodes):
    loop.run(1)
    eval_loop.run(1)
82 |
83 |
84 | if __name__ == '__main__':
85 | app.run(main)
86 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/constants.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Common constants used throughout the codebase."""
17 |
18 | import datetime as dt
19 |
20 |
# --- Physics Constants ---

GRAVITY: float = 9.80665  # [m/s^2]
NUM_SECONDS_PER_HOUR: int = 3_600  # 60 * 60
NUM_SECONDS_PER_DAY: int = 86_400  # 24 * NUM_SECONDS_PER_HOUR
UNIVERSAL_GAS_CONSTANT: float = 8.3144621  # [J/(mol.K)]
DRY_AIR_MOLAR_MASS: float = 0.028964922481160  # Dry Air. [kg/mol]
HE_MOLAR_MASS: float = 0.004002602  # Helium. [kg/mol]
# Derived constant: the universal gas constant divided by the molar mass of
# dry air.
DRY_AIR_SPECIFIC_GAS_CONSTANT: float = (
    UNIVERSAL_GAS_CONSTANT / DRY_AIR_MOLAR_MASS)  # [J/(kg.K)]


# --- RL constants ---
# Amount of time that elapses between agent steps.
AGENT_TIME_STEP: dt.timedelta = dt.timedelta(minutes=3)
# Pressure limits for the Perciatelli features.
# NOTE(review): units are presumably Pascals - confirm against the feature code.
PERCIATELLI_PRESSURE_RANGE_MIN: int = 5000
PERCIATELLI_PRESSURE_RANGE_MAX: int = 14000
39 |
40 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/run_helpers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Helper functions for running agents in train/eval."""
17 |
18 | import os
19 | from typing import Callable, Optional, Sequence
20 |
21 | import balloon_learning_environment
22 | from balloon_learning_environment.agents import agent as base_agent
23 | from balloon_learning_environment.agents import agent_registry
24 | from balloon_learning_environment.env import generative_wind_field
25 | from balloon_learning_environment.env import wind_field
26 | import gin
27 |
28 |
29 |
def get_agent_gin_file(agent_name: str,
                       gin_file: Optional[str]) -> Optional[str]:
  """Resolves which gin file should configure the given agent.

  An explicitly supplied gin_file always wins. Otherwise the agent's default
  gin file is looked up in the agent registry, and that lookup may yield None
  if the agent has no default configuration.

  Args:
    agent_name: The name of the agent to retrieve the gin file for.
    gin_file: An optional gin file that overrides the agent's default.

  Returns:
    A path to a gin file, or None.
  """
  if gin_file is not None:
    return gin_file
  return agent_registry.get_default_gin_config(agent_name)
47 |
48 |
def create_agent(agent_name: str, num_actions: int,
                 observation_shape: Sequence[int]) -> base_agent.Agent:
  """Constructs the agent registered under agent_name.

  Args:
    agent_name: The name used to look up the agent constructor in the registry.
    num_actions: Passed positionally to the agent constructor.
    observation_shape: Passed to the agent constructor by keyword.

  Returns:
    The newly constructed agent.
  """
  constructor = agent_registry.agent_constructor(agent_name)
  return constructor(num_actions, observation_shape=observation_shape)
53 |
54 |
def get_wind_field_factory(
    wind_field_name: str) -> Callable[[], wind_field.WindField]:
  """Looks up a wind field factory by name.

  Supported names are 'simple' and 'generative'.

  Args:
    wind_field_name: The name of the wind field to create.

  Returns:
    A callable that returns a WindField object.

  Raises:
    ValueError: If wind_field_name is not a known wind field.
  """
  if wind_field_name == 'simple':
    return wind_field.SimpleStaticWindField
  if wind_field_name == 'generative':
    return generative_wind_field.generative_wind_field_factory
  raise ValueError(f'Unknown wind field {wind_field_name}')
73 |
74 |
def bind_gin_variables(
    agent: str,
    agent_gin_file: Optional[str] = None,
    gin_bindings: Sequence[str] = (),
    additional_gin_files: Sequence[str] = ()
) -> None:
  """Parses and binds all gin configuration for an experiment.

  The agent's gin file (explicit or default) is parsed first, followed by any
  additional gin files, and then the individual bindings are applied.

  Args:
    agent: The agent being used in the experiment.
    agent_gin_file: An optional path to a gin file to override the agent's
      default gin file.
    gin_bindings: An optional list of gin bindings passed in on the command
      line.
    additional_gin_files: Any other additional paths to gin files that should be
      parsed and bound.
  """
  # Gin file paths are written relative to the directory that contains the
  # balloon_learning_environment package, so that directory must be searchable.
  package_dir = os.path.dirname(balloon_learning_environment.__file__)
  gin.add_config_file_search_path(os.path.dirname(package_dir))

  resolved_agent_file = get_agent_gin_file(agent, agent_gin_file)
  leading_files = [] if resolved_agent_file is None else [resolved_agent_file]
  config_files = [*leading_files, *additional_gin_files]

  gin.parse_config_files_and_bindings(
      config_files, bindings=gin_bindings, skip_unknown=False)
107 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/sampling_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.utils.sampling."""
17 |
18 | from absl.testing import absltest
19 | from absl.testing import parameterized
20 | from balloon_learning_environment.env.balloon import standard_atmosphere
21 | from balloon_learning_environment.utils import sampling
22 | from balloon_learning_environment.utils import units
23 | import jax
24 |
25 |
class SamplingTest(parameterized.TestCase):
  """Checks determinism and output ranges of the sampling helpers."""

  def setUp(self):
    super().setUp()
    # Deterministic PRNG state. Tests MUST NOT rely on a specific seed.
    self.prng_key = jax.random.PRNGKey(123)
    self.atmosphere = standard_atmosphere.Atmosphere(jax.random.PRNGKey(0))

  def test_sample_location_with_seed_gives_deterministic_lat_lng(self):
    first = sampling.sample_location(self.prng_key)
    second = sampling.sample_location(self.prng_key)

    self.assertEqual(first, second)

  def test_sample_location_gives_valid_lat_lng(self):
    sampled = sampling.sample_location(self.prng_key)

    # We only allow locations near the equator
    self.assertBetween(sampled.lat().degrees, -10.0, 10.0)
    # We don't allow locations near the international date line
    self.assertBetween(sampled.lng().degrees, -175.0, 175.0)

  def test_sample_time_with_seed_gives_deterministic_time(self):
    first = sampling.sample_time(self.prng_key)
    second = sampling.sample_time(self.prng_key)

    self.assertEqual(first, second)

  def test_sample_time_gives_time_within_range(self):
    # Pick a 1 hour segment to give a small valid range for testing
    window_start = units.datetime(year=2020, month=1, day=1, hour=1)
    window_end = units.datetime(year=2020, month=1, day=1, hour=2)

    sampled = sampling.sample_time(
        self.prng_key, begin_range=window_start, end_range=window_end)

    self.assertBetween(sampled, window_start, window_end)

  def test_sample_pressure_with_seed_gives_deterministic_pressure(self):
    first = sampling.sample_pressure(self.prng_key, self.atmosphere)
    second = sampling.sample_pressure(self.prng_key, self.atmosphere)

    self.assertEqual(first, second)

  def test_sample_pressure_gives_pressure_within_range(self):
    pressure = sampling.sample_pressure(self.prng_key, self.atmosphere)

    self.assertBetween(pressure, 5000, 14000)

  def test_sample_upwelling_infrared_is_within_range(self):
    infrared = sampling.sample_upwelling_infrared(self.prng_key)

    self.assertBetween(infrared, 100.0, 350.0)

  @parameterized.named_parameters(
      dict(testcase_name='logit_normal', distribution_type='logit_normal'),
      dict(testcase_name='inverse_lognormal',
           distribution_type='inverse_lognormal'))
  def test_sample_upwelling_infrared_is_within_range_nondefault(
      self, distribution_type):
    infrared = sampling.sample_upwelling_infrared(
        self.prng_key, distribution_type=distribution_type)

    self.assertBetween(infrared, 100.0, 350.0)

  def test_sample_upwelling_infrared_invalid_distribution_type(self):
    with self.assertRaises(ValueError):
      sampling.sample_upwelling_infrared(
          self.prng_key, distribution_type='invalid')
91 |
92 |
93 | if __name__ == '__main__':
94 | absltest.main()
95 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/spherical_geometry.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Spherical geometry functions."""
17 |
18 | import copyreg
19 | import math
20 | from typing import Any
21 |
22 | from balloon_learning_environment.utils import units
23 |
24 | import s2sphere as s2
25 |
26 |
27 | # We use the spherical Earth approximation rather than WGS-84, which simplifies
28 | # things and is appropriate for the use case of our simulator.
29 | _EARTH_RADIUS = units.Distance(km=6371)
30 |
31 |
32 | # Note: We _must_ register a pickle function for LatLng, otherwise
33 | # they break when used with gin or dataclasses.astuple. Basically anywhere
34 | # the value may be copied.
35 | # This module should be included most places s2LatLng is used eventually,
36 | # but it may be beneficial to enforce it is included at some later date.
37 | # These type hints are bad, but it's what copyreg wants 😬.
def pickle_latlng(obj: Any) -> tuple:  # pylint: disable=g-bare-generic
  """Reduces a LatLng to a (constructor, args) pair for copyreg.

  Args:
    obj: The object to reduce; its lat() and lng() are read in degrees.

  Returns:
    A tuple of s2.LatLng.from_degrees and the (lat, lng) degrees to call it
    with.
  """
  lat_lng_degrees = (obj.lat().degrees, obj.lng().degrees)
  return s2.LatLng.from_degrees, lat_lng_degrees
40 |
41 | copyreg.pickle(s2.LatLng, pickle_latlng)
42 |
43 |
def calculate_latlng_from_offset(center_latlng: s2.LatLng,
                                 x: units.Distance,
                                 y: units.Distance) -> s2.LatLng:
  """Calculates a new lat lng given an origin and x y offsets.

  The offset is converted to a heading and an angular distance on a spherical
  Earth, and the destination point is computed from those.

  Args:
    center_latlng: The starting latitude and longitude.
    x: An offset from center_latlng parallel to longitude.
    y: An offset from center_latlng parallel to latitude.

  Returns:
    A new, normalized latlng that is the specified distance from the start
    latlng.
  """
  # x and y are swapped to give heading with 0 degrees = North.
  # This is equivalent to pi / 2 - atan2(y, x).
  heading = math.atan2(x.km, y.km)  # In radians.
  angle = units.relative_distance(x, y) / _EARTH_RADIUS  # In radians.

  cos_angle = math.cos(angle)
  sin_angle = math.sin(angle)
  from_lat = center_latlng.lat().radians
  sin_from_lat = math.sin(from_lat)
  cos_from_lat = math.cos(from_lat)

  sin_lat = (cos_angle * sin_from_lat +
             sin_angle * cos_from_lat * math.cos(heading))
  # Floating point rounding can push sin_lat marginally outside [-1, 1] when
  # the destination is at a pole, which would make math.asin raise a
  # ValueError. Clamp before taking the arcsine; clamping the result of asin
  # afterwards has no effect since asin already returns a value in
  # [-pi/2, pi/2].
  sin_lat = min(max(sin_lat, -1.0), 1.0)
  d_lng = math.atan2(sin_angle * cos_from_lat * math.sin(heading),
                     cos_angle - sin_from_lat * sin_lat)

  new_lat = math.asin(sin_lat)
  new_lng = center_latlng.lng().radians + d_lng

  return s2.LatLng.from_radians(new_lat, new_lng).normalized()
77 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/spherical_geometry_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for spherical_geometry."""
17 |
18 | import math
19 |
20 | from absl.testing import absltest
21 | from absl.testing import parameterized
22 | from balloon_learning_environment.utils import spherical_geometry
23 | from balloon_learning_environment.utils import units
24 |
25 | import s2sphere as s2
26 |
27 |
class SphericalGeometryTest(parameterized.TestCase):
  """Tests for calculate_latlng_from_offset."""

  @parameterized.named_parameters(
      dict(testcase_name='equator', lat=0.0, lng=45.0),
      dict(testcase_name='north_pole', lat=80.0, lng=-135.0),
      dict(testcase_name='south_pole', lat=-85.0, lng=92.0),)
  def test_offset_latlng_gives_1_degree_latitude_per_111km_everywhere(
      self, lat: float, lng: float):
    origin = s2.LatLng.from_degrees(lat, lng)
    no_east_offset = units.Distance(km=0.0)
    north_offset = units.Distance(km=111.0)

    moved = spherical_geometry.calculate_latlng_from_offset(
        origin, no_east_offset, north_offset)

    self.assertAlmostEqual(moved.lat().degrees, lat + 1.0, places=2)

  def test_offset_latlng_gives_1_degree_longitude_per_111km_at_equator(self):
    origin = s2.LatLng.from_degrees(0.0, 0.0)
    # About one degree longitude at equator.
    east_offset = units.Distance(km=111.0)
    no_north_offset = units.Distance(km=0.0)

    moved = spherical_geometry.calculate_latlng_from_offset(
        origin, east_offset, no_north_offset)

    self.assertAlmostEqual(moved.lng().degrees, 1.0, places=2)

  def test_offset_latlng_gives_larger_longitude_change_away_from_equator(self):
    origin = s2.LatLng.from_degrees(45.0, 0.0)
    # > 1 degree longitude away from equator.
    east_offset = units.Distance(km=111.0)
    no_north_offset = units.Distance(km=0.0)

    moved = spherical_geometry.calculate_latlng_from_offset(
        origin, east_offset, no_north_offset)

    # The change in longitude should be greater than at the equator, but not
    # too much greater. These numbers are somewhat arbitrary - they are more
    # of a sanity check. The only alternative to this is to directly
    # re-write the formula.
    self.assertBetween(moved.lng().degrees, 1.25, 1.75)

  def test_offset_latlng_wraps_around_north_pole(self):
    origin = s2.LatLng.from_degrees(89.0, -90.0)
    no_east_offset = units.Distance(km=0.0)
    # About 2 degrees latitude.
    north_offset = units.Distance(km=222.0)

    moved = spherical_geometry.calculate_latlng_from_offset(
        origin, no_east_offset, north_offset)

    # We have gone over the North pole, so latitude should be 89 again,
    # but we have gone half way around the world longitudinally.
    self.assertAlmostEqual(moved.lat().degrees, 89.0, places=2)
    self.assertAlmostEqual(moved.lng().degrees, 90.0, places=2)

  @parameterized.named_parameters(
      dict(
          testcase_name='west_to_east', lng=179.0, degrees=2.0,
          expected=-179.0),
      dict(
          testcase_name='east_to_west', lng=-179.0, degrees=-2.0,
          expected=179.0),
      dict(
          testcase_name='multiple_times',
          lng=0.0,
          degrees=1440,
          expected=0.0))
  def test_offset_latlng_wraps_around_longitude(self, lng: float,
                                                degrees: float,
                                                expected: float):
    origin = s2.LatLng.from_degrees(0.0, lng)
    # arc_length = radius * angle_radians.
    east_offset = spherical_geometry._EARTH_RADIUS * math.radians(degrees)
    no_north_offset = units.Distance(km=0.0)

    moved = spherical_geometry.calculate_latlng_from_offset(
        origin, east_offset, no_north_offset)

    self.assertAlmostEqual(moved.lng().degrees, expected)
106 |
107 |
108 | if __name__ == '__main__':
109 | absltest.main()
110 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Common transforms."""
17 |
18 | import typing
19 | from typing import Union
20 |
21 | import numpy as np
22 |
23 |
24 | def _contains_negative_values(x: Union[float, np.ndarray]) -> bool:
25 | if isinstance(x, np.ndarray):
26 | return (x < 0).any()
27 | else:
28 | return x < 0
29 |
30 |
@typing.overload
def linear_rescale_with_extrapolation(x: np.ndarray,
                                      vmin: float,
                                      vmax: float) -> np.ndarray:
  ...


@typing.overload
def linear_rescale_with_extrapolation(x: float,
                                      vmin: float,
                                      vmax: float) -> float:
  ...


def linear_rescale_with_extrapolation(x,
                                      vmin,
                                      vmax):
  """Linearly maps x so that vmin goes to 0 and vmax goes to 1.

  Values outside [vmin, vmax] are extrapolated rather than clipped, so the
  result may fall outside [0, 1].

  Args:
    x: The value (or array of values) to rescale.
    vmin: The input value that maps to 0.
    vmax: The input value that maps to 1.

  Returns:
    The rescaled value(s).

  Raises:
    ValueError: If vmax is not strictly greater than vmin.
  """
  if vmax <= vmin:
    raise ValueError('Interval must be such that vmax > vmin.')
  interval_width = vmax - vmin
  return (x - vmin) / interval_width
53 |
54 |
def undo_linear_rescale_with_extrapolation(x: float, vmin: float,
                                           vmax: float) -> float:
  """Inverts linear_rescale_with_extrapolation.

  Args:
    x: A value produced by linear_rescale_with_extrapolation.
    vmin: The interval minimum used when rescaling.
    vmax: The interval maximum used when rescaling.

  Returns:
    The value that would have been rescaled to x.

  Raises:
    ValueError: If vmax is not strictly greater than vmin.
  """
  if vmax <= vmin:
    raise ValueError('Interval must be such that vmax > vmin.')
  else:
    interval_width = vmax - vmin
    return vmin + x * interval_width
61 |
62 |
def linear_rescale_with_saturation(x: float, vmin: float, vmax: float) -> float:
  """Linearly rescales x from [vmin, vmax] to [0, 1], clipping the result.

  Args:
    x: The value to rescale.
    vmin: The input value that maps to 0.
    vmax: The input value that maps to 1.

  Returns:
    The rescaled value, clipped to lie in [0, 1].
  """
  unclipped = linear_rescale_with_extrapolation(x, vmin, vmax)
  clipped = np.clip(unclipped, 0.0, 1.0)
  # .item() converts the numpy scalar back into a plain Python value.
  return clipped.item()
67 |
68 |
@typing.overload
def squash_to_unit_interval(x: np.ndarray, constant: float) -> np.ndarray:
  ...


@typing.overload
def squash_to_unit_interval(x: float, constant: float) -> float:
  ...


def squash_to_unit_interval(x, constant):
  """Squashes non-negative x into the unit interval via x / (x + constant).

  Args:
    x: A non-negative value, or an array of non-negative values.
    constant: A positive constant; it is the value of x that maps to 0.5.

  Returns:
    The squashed value(s).

  Raises:
    ValueError: If constant is not positive, or if x contains negative values.
  """
  if constant <= 0:
    raise ValueError('Squash constant must be greater than zero.')
  if _contains_negative_values(x):
    raise ValueError('Squash can only be performed on non-negative values.')
  denominator = x + constant
  return x / denominator
86 |
87 |
def undo_squash_to_unit_interval(x: float, constant: float) -> float:
  """Computes the input value of squash_to_unit_interval given the output.

  This inverts y = v / (v + constant), giving v = (y * constant) / (1 - y).

  Args:
    x: A squashed value, which must lie in [0, 1).
    constant: The positive squash constant that was used to produce x.

  Returns:
    The non-negative value that squashes to x.

  Raises:
    ValueError: If constant is not positive, or if x is outside [0, 1).
  """
  if constant <= 0:
    raise ValueError('Squash constant must be greater than zero.')
  # The previous check, `0 > x >= 1`, could never be true, so out-of-range
  # inputs slipped through (and x == 1 divided by zero). Written as a negated
  # range test so that NaN is rejected as well.
  if not 0 <= x < 1:
    raise ValueError('Undo squash can only be performed on a value in [0, 1).')
  return (x * constant) / (1 - x)
95 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/wind.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Utility functions for evaluating a wind field."""
17 |
18 | import jax
19 | import jax.numpy as jnp
20 | import numpy as np
21 | import scipy.spatial
22 |
23 |
def is_station_keeping_winds(wind_column: np.ndarray) -> bool:
  """Checks whether a wind column offers winds suitable for station keeping.

  Station keeping needs winds blowing in several directions, so that a balloon
  can change altitude and drift back towards its target. Geometrically, that
  means the origin lies inside the convex hull of the column's wind vectors.

  Args:
    wind_column: A column of (u, v) wind vectors.

  Returns:
    True if the origin falls within the convex hull of the wind vectors.
  """
  # Keep only the vectors that form the boundary of the convex hull.
  hull_vertices = scipy.spatial.ConvexHull(wind_column).vertices
  support = wind_column[hull_vertices]

  # Triangulate the hull; find_simplex reports -1 for points outside it.
  triangulation = scipy.spatial.Delaunay(support)
  origin = np.zeros(2)
  return triangulation.find_simplex(origin) >= 0
42 |
43 |
@jax.jit
def wind_field_speeds(wind_field: jnp.ndarray) -> jnp.ndarray:
  """Computes the wind speed at every grid point of a wind field.

  Args:
    wind_field: A wind field over a 4D grid. The array has four grid axes
      followed by a trailing axis whose entries 0 and 1 are the u and v
      wind components.

  Returns:
    A 4D array holding the speed sqrt(u^2 + v^2) at each grid point.
  """
  u_component = wind_field[:, :, :, :, 0]
  v_component = wind_field[:, :, :, :, 1]
  squared_speed = jnp.square(u_component) + jnp.square(v_component)
  return jnp.sqrt(squared_speed)
58 |
59 |
@jax.jit
def mean_speed_in_wind_field(wind_field: jnp.ndarray) -> float:
  """Computes the average wind speed over an entire wind field.

  Args:
    wind_field: A wind field over a 4D grid with u, v components, in the
      layout accepted by wind_field_speeds.

  Returns:
    The mean of the per-grid-point wind speeds.
  """
  speeds = wind_field_speeds(wind_field)
  return jnp.mean(speeds)
72 |
--------------------------------------------------------------------------------
/balloon_learning_environment/utils/wind_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tests for balloon_learning_environment.utils.wind."""
17 |
18 | from absl.testing import absltest
19 | from balloon_learning_environment.utils import wind
20 | import jax.numpy as jnp
21 | import numpy as np
22 |
23 |
class WindTest(absltest.TestCase):
  """Unit tests for the helpers in balloon_learning_environment.utils.wind."""

  def test_is_station_keeping_winds_postive_case(self):
    # Winds along +/-u and +/-v surround the origin.
    column = np.array([[1, 0], [-1, 0], [0, -1], [0, 1]])
    self.assertTrue(wind.is_station_keeping_winds(column))

  def test_is_station_keeping_winds_negative_case(self):
    # All winds have positive u and v, so the origin is outside their hull.
    column = np.array([[1, 1.2], [2, 3.2], [10.2, 33.4]])
    self.assertFalse(wind.is_station_keeping_winds(column))

  def test_wind_field_speeds_works(self):
    grid_shape = (5, 6, 7, 8)
    field = jnp.stack(
        [2.1 * jnp.ones(grid_shape), 3.2 * jnp.ones(grid_shape)], axis=-1)
    expected_speed = jnp.sqrt(2.1 * 2.1 + 3.2 * 3.2).item()

    speeds = wind.wind_field_speeds(field)

    # A uniform field has the same speed everywhere, so the mean equals it.
    self.assertAlmostEqual(speeds.mean().item(), expected_speed, places=3)
    self.assertEqual(speeds.shape, grid_shape)

  def test_mean_speed_in_wind_field_works(self):
    grid_shape = (5, 6, 7, 8)
    field = jnp.stack(
        [2.1 * jnp.ones(grid_shape), 3.2 * jnp.ones(grid_shape)], axis=-1)
    expected_speed = jnp.sqrt(2.1 * 2.1 + 3.2 * 3.2).item()

    mean_speed = wind.mean_speed_in_wind_field(field).item()

    self.assertAlmostEqual(mean_speed, expected_speed, places=3)
50 |
51 |
# Run the test cases through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
54 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS    ?=
SPHINXBUILD   ?= sphinx-build
SOURCEDIR     = .
BUILDDIR      = _build

# Put it first so that "make" without argument is like "make help".
help:
	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

# "help" produces no file, and "Makefile" is listed so that the catch-all
# rule below is never used to try to rebuild this Makefile itself.
.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option.  $(O) is meant as a shortcut for $(SPHINXOPTS).
# For example, "make html" runs: sphinx-build -M html . _build
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # BLE Documentation
2 |
3 | To build the documentation, make sure you have the dependencies installed:
4 |
5 | ```
6 | pip install -r requirements.txt
7 | ```
8 |
9 | If you are building locally you'll also need the sphinx_rtd_theme:
10 |
11 | ```
12 | pip install sphinx_rtd_theme
13 | ```
14 |
15 | Then:
16 |
17 | ```
18 | make html
19 | ```
20 |
21 | To clean the documentation:
22 |
23 | ```
24 | make clean
25 | ```
26 |
--------------------------------------------------------------------------------
/docs/about.rst:
--------------------------------------------------------------------------------
1 | About the Balloon Learning Environment
2 | ======================================
3 |
4 | This page gives a simple background into the problem of flying
5 | stratospheric balloons and station keeping. This is intended to
6 | be an introduction, and doesn't discuss all the nuances of the simulator
7 | or real-world complexities.
8 |
9 | Stratospheric Balloons
10 | ######################
11 |
12 | The type of balloon that we consider in the BLE is pictured below.
13 | These balloons have an outer envelope and an inner envelope (ballonet).
14 | The ballonet is filled with a buoyant gas (for example, helium).
15 | The outer envelope contains air which acts as ballast.
16 | An Altitude Control System (ACS) is able to pump air in or out of the envelope.
17 | When air is pumped into the balloon the average density
18 | of the gases in the envelope increases, which in turn makes the balloon
19 | descend. On the other hand, if air is pumped out of the balloon then the
20 | average gas density decreases, resulting in the balloon ascending.
21 |
22 | The balloons are also equipped with a battery to power the ACS, and solar
23 | panels to recharge the batteries. Power is used by communication systems
24 | on the balloon, so battery power is constantly draining when the solar panels
25 | aren't in use. This means that the balloons in the BLE are constrained at
26 | night—they cannot constantly use the ACS without running out of power.
27 | The energy required for running the ACS is asymmetric;
28 | the energy needed to release air from the envelope (ascend) is negligible.
29 |
30 | .. image:: imgs/balloon_schematic.jpg
31 |
32 | Navigating a Windfield
33 | ######################
34 |
35 | The balloons in the BLE have no way of moving themselves laterally. They
36 | are only capable of moving up, down, or staying at the current altitude.
37 | To navigate they must "surf" the wind to get where they need to go. The
38 | balloons are flown in a 4d windfield, where each different x, y, altitude, time
39 | position gives a different wind vector. A balloon must learn how to navigate
40 | this windfield to achieve its goal.
41 |
42 | .. image:: imgs/wind_field.gif
43 |
44 | Station-Keeping
45 | ###############
46 |
47 | The goal of station-keeping is to remain within a fixed distance of a
48 | ground station. This distance is only measured in the x, y plane, i.e. the
49 | altitude of the balloon is not taken into account. In the BLE, the task
50 | is to keep a balloon within 50km of the ground station. We call the proportion
51 | of time that a balloon remains within 50km of the ground station TWR50 (for
52 | Time Within a Radius of 50km). In the BLE, each episode lasts two days.
53 |
54 | .. image:: imgs/station_keeping.gif
55 |
56 | Failure Modes
57 | #############
58 |
59 | A balloon can have a critical failure by running out of power, flying too low,
60 | or having a superpressure that is too high or too low.
61 | Each of these is partially guarded against by a
62 | safety layer, but in extreme conditions there can still be critical failures.
63 |
--------------------------------------------------------------------------------
/docs/benchmarks.rst:
--------------------------------------------------------------------------------
1 | Benchmark Results
2 | =================
3 |
4 | This page will be updated soon with even more data 📈
5 |
6 | The following graph shows evaluation results on the "small eval" evaluation
7 | suite throughout training for the DQN, quantile, and finetuned Perciatelli44
8 | agents. The horizontal lines are evaluation results for "big eval" for
9 | Perciatelli44 and station seeker.
10 |
11 | .. image:: imgs/training_curve.jpg
--------------------------------------------------------------------------------
/docs/changelist.rst:
--------------------------------------------------------------------------------
1 | Change List
2 | ===========
3 |
4 | v1.0.2
5 | ######
6 |
7 | - Added GridBasedWindField and GenerativeWindFieldSampler.
8 | - Deprecated GenerativeWindField in favor of GridBasedWindField and
9 | GenerativeWindFieldSampler.
10 | - Bug fix: previously the train script would load the latest checkpoint
11 | when restarting but then resume from the previous iteration. It is now
12 | fixed, so if reloading checkpoint i, the agent will continue working
13 | on iteration i + 1.
14 | - Vectorized wind column feature calculations. This gives about a 16% speed
15 | increase.
16 | - Cleanups in balloon.py, by removing static functions.
17 | - Moved wind field creation to run_helpers, since it is common for
18 | training and evaluation.
19 | - Added a flag for calculating the flight path to eval_lib, to allow for
20 | faster evaluation if you don't need flight paths.
21 | - Improvements to AcmeEvalAgent by making it more configurable.
22 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | # Configuration file for the Sphinx documentation builder.
17 | #
18 | # This file only contains a selection of the most common options. For a full
19 | # list see the documentation:
20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
21 |
22 | # -- Path setup --------------------------------------------------------------
23 |
24 | # If extensions (or modules to document with autodoc) are in another directory,
25 | # add these directories to sys.path here. If the directory is relative to the
26 | # documentation root, use os.path.abspath to make it absolute, like shown here.
27 |
import os
import sys
# The docs live one level below the repository root; adding the parent
# directory lets autodoc import the balloon_learning_environment package.
sys.path.insert(0, os.path.abspath('..'))


# -- Project information -----------------------------------------------------

project = 'Balloon Learning Environment'
copyright = '2021, The Balloon Learning Environment Authors'
author = 'The Balloon Learning Environment Authors'


# -- General configuration ---------------------------------------------------
# Name of the root document (index.rst), which holds the top-level toctree.
master_doc = 'index'

# Logo image, given relative to this directory.
html_logo = 'imgs/ble_logo_small.png'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.autodoc.typehints',
    'sphinx.ext.autosummary',
    'sphinx.ext.napoleon',
    'sphinx.ext.viewcode',
    'myst_parser',
]

# This simplifies the documentation by only displaying the class name
# without the module path. For example:
# balloon_learning_environment.env.balloon_env.BalloonEnv -> BalloonEnv.
add_module_names = False

# This orders the members of a class by the source code order.
autodoc_member_order = 'bysource'

# This removes type hints from the function definition and moves them
# to the description below.
autodoc_typehints = 'description'

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [
    '_build',
    'Thumbs.db',
    '.DS_Store',
    'README.md',
]


# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages.  See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
# Left empty: no custom static files are configured.
html_static_path = []
94 |
--------------------------------------------------------------------------------
/docs/environment.rst:
--------------------------------------------------------------------------------
1 | Using and Configuring the Environment
2 | =====================================
3 |
4 | Basic Usage
5 | ###########
6 |
7 | The main entrypoint to the BLE for most users is the gym environment. To use
8 | the environment, import the balloon environment and use gym to create it:
9 |
10 | .. code-block:: python
11 |
12 | import balloon_learning_environment.env.balloon_env # Registers the environment.
13 | import gym
14 |
15 | env = gym.make('BalloonLearningEnvironment-v0')
16 |
17 |
18 | This will give you a new
19 | `BalloonEnv `_
20 | object that follows the gym environment interface. Before we run the
21 | environment we can inspect the observation and action spaces:
22 |
23 | .. code-block:: python
24 |
25 | >>> print(env.observation_space)
26 | Box([0. 0. 0. ... 0. 0. 0.], [1. 1. 1. ... 1. 1. 1.], (1099,), float32)
27 | >>> print(env.action_space)
28 | Discrete(3)
29 |
30 |
31 | Here we can see that the observation space is a 1099 dimensional array,
32 | and the action space has 3 discrete actions. We can use the environment
33 | as follows:
34 |
35 | .. code-block:: python
36 |
37 | env.seed(0)
38 | observation_0 = env.reset()
39 | observation_1, reward, is_terminal, info = env.step(0)
40 |
41 |
42 | In this snippet, we seeded the environment to give it a deterministic
43 | initialization. This is useful for replicating results (for example, in
44 | evaluation), but most of the time you'll want to skip this line to have
45 | a random initialization. After seeding the environment we reset it and
46 | stepped once with action 0.
47 |
48 | We expect the observations to have the shape specified by observation_space:
49 |
50 | .. code-block:: python
51 |
52 | >>> print(type(observation_0), observation_0.shape)
53 | <class 'numpy.ndarray'> (1099,)
54 |
55 |
56 | The reward, is_terminal, and info objects are as follows:
57 |
58 | .. code-block:: python
59 |
60 | >>> print(reward, is_terminal, info, sep='\n')
61 | 0.26891435801077535
62 | False
63 | {'out_of_power': False, 'envelope_burst': False, 'zeropressure': False, 'time_elapsed': datetime.timedelta(seconds=180)}
64 |
65 |
66 | These should be enough to start training an RL agent.
67 |
68 | Configuring the Environment
69 | ###########################
70 |
71 | The environment may be configured to give custom behavior. To see all
72 | options for configuring the environment, see the
73 | `BalloonEnv `_
74 | constructor. Here, we highlight important options.
75 |
76 | First, the
77 | `FeatureConstructor `_
78 | class may be swapped out. The feature constructor receives observations
79 | from the simulator at each step, and returns features when required. This
80 | setup allows a feature constructor to maintain its own state, and use the
81 | simulator history to create a feature vector. The default feature constructor
82 | is the
83 | `PerciatelliFeatureConstructor `_.
84 |
85 | The reward function can also be swapped out. The default reward function,
86 | `perciatelli_reward_function `_
87 | gives a reward of 1.0 as long as the agent is in the station-keeping radius.
88 | The reward decays exponentially outside of this radius.
89 |
90 | .. image:: imgs/reward_function.png
91 |
92 |
--------------------------------------------------------------------------------
/docs/imgs/balloon_schematic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/balloon_schematic.jpg
--------------------------------------------------------------------------------
/docs/imgs/ble_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/ble_logo.png
--------------------------------------------------------------------------------
/docs/imgs/ble_logo_small.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/ble_logo_small.png
--------------------------------------------------------------------------------
/docs/imgs/reward_function.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/reward_function.png
--------------------------------------------------------------------------------
/docs/imgs/station_keeping.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/station_keeping.gif
--------------------------------------------------------------------------------
/docs/imgs/training_curve.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/training_curve.jpg
--------------------------------------------------------------------------------
/docs/imgs/wind_field.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/balloon-learning-environment/72082feccf404e5bf946e513e4f6c0ae8fb279ad/docs/imgs/wind_field.gif
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Balloon Learning Environment documentation master file, created by
2 | sphinx-quickstart on Wed Dec 1 18:26:08 2021.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Balloon Learning Environment
7 | ============================
8 |
9 | .. note::
10 | We are working hard to improve the Balloon Learning Environment
11 | documentation. Please let us know if there is a page you'd like to see here!
12 |
13 | The Balloon Learning Environment (BLE) is a simulator for stratospheric
14 | balloons. It is designed as a challenge environment for deep reinforcement
15 | learning algorithms.
16 |
17 | .. toctree::
18 | :maxdepth: 1
19 | :caption: Getting Started:
20 |
21 | about
22 | getting_started
23 | new_agent
24 | environment
25 | benchmarks
26 | changelist
27 |
28 |
29 | Indices and tables
30 | ==================
31 |
32 | * :ref:`genindex`
33 | * :ref:`modindex`
34 | * :ref:`search`
35 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
@ECHO OFF

REM Run from the directory that contains this script.
pushd %~dp0

REM Command file for Sphinx documentation

REM Default to the sphinx-build on PATH unless SPHINXBUILD is already set.
if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build

REM With no target argument, show the Sphinx help text.
if "%1" == "" goto help

REM Probe for sphinx-build with all output discarded; errorlevel 9009 means
REM the command could not be found.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

REM Forward the requested target (e.g. html, clean) to Sphinx make mode.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | dopamine-rl==4.0.1
3 | flax==0.3.6
4 | gin-config==0.5.0
5 | jax==0.2.25
6 | jaxlib==0.1.74
7 | gym==0.21.0
8 | myst-parser==0.15.2
9 | numpy==1.21.4
10 | opensimplex==0.3
11 | s2sphere==0.2.5
12 | sklearn==0.0
13 | tensorflow==2.7.0
14 | tensorflow-probability==0.15.0
15 | transitions==0.8.10
16 |
--------------------------------------------------------------------------------
/docs/src/agents.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.agents
2 | ===================================
3 |
4 | The following agents are available in the Balloon Learning Environment.
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 | :caption: balloon_learning_environment.agents:
9 |
10 | agents/agent
11 | agents/dqn_agent
12 | agents/perciatelli44
13 | agents/quantile_agent
14 | agents/station_seeker_agent
--------------------------------------------------------------------------------
/docs/src/agents/agent.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.agents.agent
2 | =========================================
3 |
4 | .. currentmodule:: balloon_learning_environment.agents.agent
5 |
6 | .. autoclass:: balloon_learning_environment.agents.agent.Agent
7 | :members:
8 |
9 | .. automethod:: __init__
10 |
11 | .. autoclass:: balloon_learning_environment.agents.agent.RandomAgent
12 | :members:
13 |
14 | .. autoclass:: balloon_learning_environment.agents.agent.AgentMode
15 | :members:
16 |
--------------------------------------------------------------------------------
/docs/src/agents/dqn_agent.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.agents.dqn_agent
2 | =============================================
3 |
4 | .. currentmodule:: balloon_learning_environment.agents.dqn_agent
5 |
6 | .. autoclass:: balloon_learning_environment.agents.dqn_agent.DQNAgent
7 | :members:
8 |
9 | .. automethod:: __init__
--------------------------------------------------------------------------------
/docs/src/agents/perciatelli44.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.agents.perciatelli44
2 | =================================================
3 |
4 | .. currentmodule:: balloon_learning_environment.agents.perciatelli44
5 |
6 | .. autoclass:: balloon_learning_environment.agents.perciatelli44.Perciatelli44
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/src/agents/quantile_agent.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.agents.quantile_agent
2 | ==================================================
3 |
4 | .. currentmodule:: balloon_learning_environment.agents.quantile_agent
5 |
6 | .. autoclass:: balloon_learning_environment.agents.quantile_agent.QuantileAgent
7 | :members:
8 |
9 | .. automethod:: __init__
--------------------------------------------------------------------------------
/docs/src/agents/station_seeker_agent.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.agents.station_seeker_agent
2 | ========================================================
3 |
4 | .. currentmodule:: balloon_learning_environment.agents.station_seeker_agent
5 |
6 | .. autoclass:: balloon_learning_environment.agents.station_seeker_agent.StationSeekerAgent
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/src/balloon_env.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.env.balloon_env
2 | ============================================
3 |
4 | .. currentmodule:: balloon_learning_environment.env.balloon_env
5 |
6 | .. autoclass:: balloon_learning_environment.env.balloon_env.BalloonEnv
7 | :members:
8 |
9 | .. automethod:: __init__
10 |
11 | .. autofunction:: balloon_learning_environment.env.balloon_env.perciatelli_reward_function
12 |
13 | .. image:: ../imgs/reward_function.png
14 |
15 |
--------------------------------------------------------------------------------
/docs/src/env.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.env
2 | ================================
3 |
4 | .. toctree::
5 | :maxdepth: 1
6 | :caption: balloon_learning_environment.env:
7 |
8 | balloon_env
9 | features
--------------------------------------------------------------------------------
/docs/src/eval_lib.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.eval.eval_lib
2 | ==========================================
3 |
4 | .. autofunction:: balloon_learning_environment.eval.eval_lib.eval_agent
5 |
--------------------------------------------------------------------------------
/docs/src/features.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.env.features
2 | =========================================
3 |
4 | .. currentmodule:: balloon_learning_environment.env.features
5 |
6 | .. autoclass:: balloon_learning_environment.env.features.FeatureConstructor
7 | :members:
8 |
9 | .. automethod:: __init__
10 |
11 |
12 | .. autoclass:: balloon_learning_environment.env.features.PerciatelliFeatureConstructor
13 | :members:
14 |
15 | .. automethod:: __init__
16 |
17 |
18 | .. autoclass:: balloon_learning_environment.env.features.NamedPerciatelliFeatures
19 | :members:
20 |
--------------------------------------------------------------------------------
/docs/src/metrics.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.metrics
2 | ====================================
3 |
4 | The following collectors are available in the Balloon Learning Environment.
5 |
6 | .. toctree::
7 | :maxdepth: 1
8 | :caption: balloon_learning_environment.metrics:
9 |
10 | metrics/collector
11 | metrics/collector_dispatcher
12 | metrics/statistics_instance
13 | metrics/console_collector
14 | metrics/pickle_collector
15 | metrics/tensorboard_collector
16 |
--------------------------------------------------------------------------------
/docs/src/metrics/collector.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.metrics.collector
2 | ==============================================
3 |
4 | .. currentmodule:: balloon_learning_environment.metrics.collector
5 |
6 | .. autoclass:: balloon_learning_environment.metrics.collector.Collector
7 | :members:
8 |
9 | .. automethod:: __init__
10 |
--------------------------------------------------------------------------------
/docs/src/metrics/collector_dispatcher.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.metrics.collector_dispatcher
2 | =========================================================
3 |
4 | .. currentmodule:: balloon_learning_environment.metrics.collector_dispatcher
5 |
6 | .. autoclass:: balloon_learning_environment.metrics.collector_dispatcher.CollectorDispatcher
7 | :members:
8 |
9 | .. automethod:: __init__
10 |
--------------------------------------------------------------------------------
/docs/src/metrics/console_collector.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.metrics.console_collector
2 | ======================================================
3 |
4 | .. currentmodule:: balloon_learning_environment.metrics.console_collector
5 |
6 | .. autoclass:: balloon_learning_environment.metrics.console_collector.ConsoleCollector
7 | :members:
8 |
9 | .. automethod:: __init__
10 |
--------------------------------------------------------------------------------
/docs/src/metrics/pickle_collector.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.metrics.pickle_collector
2 | =====================================================
3 |
4 | .. currentmodule:: balloon_learning_environment.metrics.pickle_collector
5 |
6 | .. autoclass:: balloon_learning_environment.metrics.pickle_collector.PickleCollector
7 | :members:
8 |
9 | .. automethod:: __init__
10 |
--------------------------------------------------------------------------------
/docs/src/metrics/statistics_instance.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.metrics.statistics_instance
2 | =========================================================
3 |
4 | .. currentmodule:: balloon_learning_environment.metrics.statistics_instance
5 |
6 | .. autoclass:: balloon_learning_environment.metrics.statistics_instance.StatisticsInstance
7 | :members:
8 |
--------------------------------------------------------------------------------
/docs/src/metrics/tensorboard_collector.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.metrics.tensorboard_collector
2 | ==========================================================
3 |
4 | .. currentmodule:: balloon_learning_environment.metrics.tensorboard_collector
5 |
6 | .. autoclass:: balloon_learning_environment.metrics.tensorboard_collector.TensorboardCollector
7 | :members:
8 |
9 | .. automethod:: __init__
10 |
--------------------------------------------------------------------------------
/docs/src/train_lib.rst:
--------------------------------------------------------------------------------
1 | balloon_learning_environment.train_lib
2 | ======================================
3 |
4 | .. autofunction:: balloon_learning_environment.train_lib.run_training_loop
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools", "wheel"]
3 | build-backend = "setuptools.build_meta"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.15.0
2 | ale-py==0.7.3
3 | astunparse==1.6.3
4 | cached-property==1.5.2
5 | cachetools==4.2.4
6 | certifi==2021.10.8
7 | charset-normalizer==2.0.7
8 | chex==0.0.8
9 | clang==5.0
10 | cloudpickle==2.0.0
11 | cycler==0.11.0
12 | decorator==5.1.0
13 | dm-tree==0.1.6
14 | dopamine-rl==4.0.0
15 | flatbuffers==1.12
16 | flax==0.3.6
17 | future==0.18.2
18 | gast==0.4.0
19 | gin==0.1.6
20 | gin-config==0.5.0
21 | google-auth==1.35.0
22 | google-auth-oauthlib==0.4.6
23 | google-pasta==0.2.0
24 | grpcio==1.41.1
25 | gym==0.21.0
26 | h5py==3.1.0
27 | idna==3.3
28 | importlib-metadata==4.8.1
29 | importlib-resources==5.4.0
30 | jax==0.3.0
31 | jaxlib==0.3.0
32 | joblib==1.1.0
33 | keras==2.7.0
34 | Keras-Preprocessing==1.1.2
35 | kiwisolver==1.3.2
36 | libclang==12.0.0
37 | Markdown==3.3.4
38 | matplotlib==3.4.3
39 | msgpack==1.0.2
40 | numpy==1.19.5
41 | oauthlib==3.1.1
42 | opencv-python==4.5.4.58
43 | opensimplex==0.3
44 | opt-einsum==3.3.0
45 | optax==0.0.9
46 | pandas==1.3.4
47 | Pillow==8.4.0
48 | protobuf==3.19.1
49 | pyasn1==0.4.8
50 | pyasn1-modules==0.2.8
51 | pygame==2.0.3
52 | pyparsing==3.0.4
53 | python-dateutil==2.8.2
54 | pytz==2021.3
55 | requests==2.26.0
56 | requests-oauthlib==1.3.0
57 | rsa==4.7.2
58 | s2sphere==0.2.5
59 | scikit-learn==1.0.1
60 | scipy==1.7.1
61 | six==1.15.0
62 | sklearn==0.0
63 | tensorboard==2.6.0
64 | tensorboard-data-server==0.6.1
65 | tensorboard-plugin-wit==1.8.0
66 | tensorflow==2.7.0rc1
67 | tensorflow-estimator==2.7.0
68 | tensorflow-io-gcs-filesystem==0.21.0
69 | tensorflow-probability==0.16.0
70 | termcolor==1.1.0
71 | tf-slim==1.1.0
72 | tfp-nightly==0.15.0.dev20211104
73 | threadpoolctl==3.0.0
74 | toolz==0.11.1
75 | transitions==0.8.10
76 | typing-extensions==3.7.4.3
77 | urllib3==1.26.7
78 | Werkzeug==2.0.2
79 | wrapt==1.12.1
80 | zipp==3.6.0
81 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2022 The Balloon Learning Environment Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Setup file for installing the BLE."""
17 | import os
18 | import pathlib
19 | import setuptools
20 | from setuptools.command import build_py
21 | from setuptools.command import develop
22 |
# Directory that holds this setup script; files shipped next to it (such as
# the README) are resolved relative to it.
current_directory = pathlib.Path(__file__).parent
# The README is used as the package's long description. Decode it explicitly
# as UTF-8 so that installing does not depend on the platform's default locale
# encoding.
description = (current_directory / 'README.md').read_text(encoding='utf-8')
25 |
# Packages that are always installed with the BLE (passed to setuptools as
# `install_requires`). Order is kept as-is; entries carry a version bound only
# where one is required.
core_requirements = [
    'absl-py',
    'dopamine-rl >= 4.0.0',
    'flax',
    'gin-config',
    'gym',
    'jax >= 0.2.28',
    'jaxlib >= 0.1.76',
    'opensimplex <= 0.3.0',
    's2sphere',
    'scikit-learn',
    'tensorflow',
    'tensorflow-probability',
    'transitions',
]
41 |
# Optional packages needed only for the Acme-based agents; exposed through the
# `acme` extra.
acme_requirements = [
    'dm-acme',
    'dm-haiku',
    'dm-reverb',
    'dm-sonnet',
    'rlax',
    'xmanager',
]
50 |
51 |
def generate_requirements_file(path=None):
  """Generates requirements.txt file needed for running Acme.

  It is used by Launchpad GCP runtime to generate Acme requirements to be
  installed inside the docker image. Acme itself is not installed from pypi,
  but instead sources are copied over to reflect any local changes made to
  the codebase.

  The file lists the union of the core and Acme requirements, one per line,
  in sorted order so that repeated runs produce identical output.

  Args:
    path: path to the requirements.txt file to generate. If falsy, defaults
      to `acme_requirements.txt` next to this setup script.
  """
  if not path:
    path = os.path.join(os.path.dirname(__file__), 'acme_requirements.txt')
  # Deduplicate, then sort: iterating a bare set of strings yields an order
  # that can change from one interpreter run to the next.
  packages = sorted(set(core_requirements + acme_requirements))
  with open(path, 'w', encoding='utf-8') as f:
    f.writelines(f'{package}\n' for package in packages)
67 |
68 |
class BuildPy(build_py.build_py):
  """`build_py` command that first regenerates the Acme requirements file."""

  def run(self):
    """Writes the requirements file, then runs the standard build step."""
    generate_requirements_file()
    super().run()
74 |
75 |
class Develop(develop.develop):
  """`develop` command that first regenerates the Acme requirements file."""

  def run(self):
    """Writes the requirements file, then runs the standard develop step."""
    generate_requirements_file()
    super().run()
81 |
# Custom setuptools commands, keyed by the command name they override.
cmdclass = dict(build_py=BuildPy, develop=Develop)
86 |
# Entry point under the `gym.envs` group that points at the BLE's
# `register_env` function.
_gym_register_entry = (
    '__root__=balloon_learning_environment.env.gym:register_env')
entry_points = {'gym.envs': [_gym_register_entry]}
92 |
93 |
setuptools.setup(
    # Package identity.
    name='balloon_learning_environment',
    version='1.0.2',
    long_description=description,
    long_description_content_type='text/markdown',
    python_requires='>=3.7',
    # What gets installed.
    packages=setuptools.find_packages(),
    package_data={
        # Non-Python files matched in every package.
        '': ['*.msgpack', '*.pb', '*.gin'],
    },
    # Dependencies: core ones always, Acme ones via the `acme` extra.
    install_requires=core_requirements,
    extras_require={
        'acme': acme_requirements,
    },
    # Hooks: custom build/develop commands and the gym entry point.
    cmdclass=cmdclass,
    entry_points=entry_points,
)
111 |
--------------------------------------------------------------------------------
/style_guidelines.md:
--------------------------------------------------------------------------------
1 | # Style Guide
2 | This document outlines guidelines for coding style in this repo.
3 |
4 | ## Directory structure
5 | The top-level directory should only contain Python files necessary for executing
6 | the main binary (`train.py`).
7 |
8 | The **env** directory contains all files necessary for running the simulator,
9 | including the balloon simulator, the wind vector model, the feature vector
10 | constructor, and the gym wrapper.
11 |
12 | The **agents** directory contains all files necessary for defining and training
13 | agents which will control the balloon. Note that this may include code necessary
14 | for checkpointing.
15 |
16 | The **metrics** directory contains all files necessary for logging and reporting
17 | performance metrics.
18 |
19 | ## Typing
20 | All variables and functions should be properly typed. We adhere to:
21 | https://docs.python.org/3/library/typing.html
22 |
23 | ## Member variables and methods
24 | In classes, member variables are either _public_ or _private_. We do not make
25 | use of setters and getters. All private members will have their name prefixed by
26 | `_` and should not be accessed from outside the class.
27 |
28 | ## Use of `@property` decorator
29 | We discourage the use of `@property` to avoid confusion, unless required to
30 | conform to an external API specification (for example, for Gym). In these cases,
31 | the reason for its use should be documented above the method.
32 |
33 | ## Abstract classes
34 | We encourage the use of abstract classes to avoid code duplication. Any abstract
35 | class must subclass `abc.ABC` and decorate required methods with
36 | `@abc.abstractmethod`.
37 |
38 | ## Static methods
39 | Functions that are only called within a class but do not access any class
40 | members (e.g. `self.`) must be made static by decorating them with
41 | `@staticmethod`.
42 |
43 | ## Data Classes
44 | The `@dataclasses.dataclass` decorator should be used whenever the `__init__`
45 | method would only be setting values based on the constructor parameters.
46 |
47 | ## Floats
48 | Prefer to use `0.0` over `0.`, as the former is more obviously a float.
49 |
50 | ## Gin config
51 | We make use of gin config for parameter injection. Rather than passing
52 | parameters from flags all the way to the method/class where it will be used, the
53 | gin-configurable parameters are specified via gin config files (or gin
54 | bindings).
55 |
56 | The guidelines for gin-configurable parameters are:
57 | 1. Only set variables in a gin config which have a default value of
58 | gin.REQUIRED.
59 | 1. Only keyword-only args can be set with a gin config.
60 |
61 | For example in the signature below:
62 | ```
63 | def f(x: float, y: float = 0.0, *, z: float = 0.0, alpha: float = gin.REQUIRED)
64 | ```
65 |
66 | the only variable that can (and must) be set via a gin config is `alpha`.
67 |
68 | We accept two gin-config files: one for specifying _environment_
69 | parameters (via the `--environment_gin_file` flag), and one for specifying
70 | _agent_ parameters (via the `--agent_gin_file` flag). Any other parameters (or
71 | variations from those specified in the config files) can be specified via
72 | the `--gin_bindings` flags.
73 |
74 | For more information see:
75 | https://github.com/google/gin-config
76 |
--------------------------------------------------------------------------------