├── .gitignore ├── LICENCE ├── README.md ├── requirements.txt ├── setup.py └── understanding_rl_vision ├── __init__.py ├── rl_clarity ├── .gitignore ├── __init__.py ├── compiling.py ├── example.py ├── interface.py ├── loading.py ├── svelte │ ├── Makefile │ ├── attribution_selector.svelte │ ├── attribution_viewer.svelte │ ├── chart.svelte │ ├── css_manipulate.js │ ├── feature_viewer.svelte │ ├── graph.svelte │ ├── interface.svelte │ ├── json_load.js │ ├── legend.svelte │ ├── navigator.svelte │ ├── query.svelte │ ├── screen.svelte │ ├── scrubber.svelte │ └── trajectory_display.svelte └── training.py └── svelte3 ├── __init__.py ├── compiling.py ├── json_encoding.py ├── package-lock.json └── package.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Node 42 | 43 | # Dependency directories 44 | node_modules/ 45 | 46 | # MacOS 47 | 48 | # General 49 | .DS_Store 50 | .AppleDouble 51 | .LSOverride 52 | 53 | # Icon must end with two \r 54 | Icon 55 | 56 | # Thumbnails 57 | ._* 58 | 59 | # Files that might appear in the root of a volume 60 | .DocumentRevisions-V100 61 | .fseventsd 62 | .Spotlight-V100 63 | .TemporaryItems 64 | .Trashes 65 | .VolumeIcon.icns 66 | .com.apple.timemachine.donotpresent 67 | 68 | # Directories potentially created on remote AFP share 69 | .AppleDB 70 | .AppleDesktop 71 | Network Trash Folder 72 | Temporary Items 73 | .apdisk 74 | 75 | # Windows 76 | 77 | # Windows thumbnail cache files 78 | Thumbs.db 79 | Thumbs.db:encryptable 80 | ehthumbs.db 81 | ehthumbs_vista.db 82 | 83 | # Dump file 84 | *.stackdump 85 | 86 | # Folder config file 87 | [Dd]esktop.ini 88 | 89 | # Recycle Bin used on file shares 90 | $RECYCLE.BIN/ 91 | 92 | # Windows Installer files 93 | *.cab 94 | *.msi 95 | *.msix 96 | *.msm 97 | *.msp 98 | 99 | # Windows shortcuts 100 | *.lnk -------------------------------------------------------------------------------- /LICENCE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Archive (code is provided as-is, no updates expected) 2 | 3 | # Understanding RL Vision 4 | 5 | #### [ [Paper] ](https://distill.pub/2020/understanding-rl-vision) [ [Demo] ](https://openaipublic.blob.core.windows.net/rl-clarity/attribution/demo/interface.html) 6 | 7 | Generate interfaces for interpreting vision models trained using RL. 8 | 9 | The core utilities used to compute feature visualization, attribution and dimensionality reduction can be found in `lucid.scratch.rl_util`, a submodule of [Lucid](https://github.com/tensorflow/lucid/). These are demonstrated in [this notebook](https://colab.research.google.com/github/tensorflow/lucid/blob/master/notebooks/misc/rl_util.ipynb). The code here leverages these utilities to build HTML interfaces similar to the above demo. 10 | 11 | ![](https://openaipublic.blob.core.windows.net/rl-clarity/attribution/demo.gif) 12 | 13 | ## Installation 14 | 15 | Supported platforms: MacOS and Ubuntu, Python 3.7, TensorFlow <= 1.14 16 | 17 | - Install [Baselines](https://github.com/openai/baselines) and its dependencies, including TensorFlow 1. 
18 | - Clone the repo: 19 | ``` 20 | git clone https://github.com/openai/understanding-rl-vision.git 21 | ``` 22 | - Install the repo and its dependencies, among which is a pinned version of [Lucid](https://github.com/tensorflow/lucid): 23 | ``` 24 | pip install -e understanding-rl-vision 25 | ``` 26 | - Install an RL environment of your choice. Supported environments: 27 | - [CoinRun](https://github.com/openai/coinrun) (the original version used in the paper): follow the instructions. Note: due to CoinRun's requirements, you should re-install Baselines after installing CoinRun. 28 | - [Procgen](https://github.com/openai/procgen): `pip install procgen` 29 | - [Atari](https://github.com/openai/atari-py): `pip install atari-py` 30 | 31 | ## Generating interfaces 32 | 33 | The main script processes checkpoint files saved by RL code: 34 | ``` 35 | from understanding_rl_vision import rl_clarity 36 | 37 | rl_clarity.run('path/to/checkpoint/file', output_dir='path/to/directory') 38 | ``` 39 | 40 | An example checkpoint file can be downloaded [here](https://openaipublic.blob.core.windows.net/rl-clarity/attribution/models/coinrun.jd), or can be generated using the [example script](understanding_rl_vision/rl_clarity/example.py). Checkpoint files for a number of pre-trained models are indexed [here](https://openaipublic.blob.core.windows.net/rl-clarity/attribution/models/index.html). 41 | 42 | The precise format required of the checkpoint file, along with a full list of keyword arguments, can be found in the function's [docstring](understanding_rl_vision/rl_clarity/__init__.py). 43 | 44 | The script will create an `interface.html` file, along with directories containing images (which can take up several GB), at the location specified by `output_dir`. 45 | 46 | By default, the script will also create some files in the directory of the checkpoint file, in an `rl-clarity` subdirectory. 
These contain all the necessary information extracted from the model and environment for re-creating the same interface. To create these files in a temporary location instead, set `load_kwargs={'temp_files': True}`. To re-create an interface using existing files, set `load_kwargs={'resample': False}`. 47 | 48 | ### Speed issues 49 | 50 | The slowest part of the script is computing the attribution in all the required combinations. If you set `trajectories_kwargs={'num_envs': num_envs, 'num_steps': num_steps}`, then `num_envs` trajectories will be collected, each of length `num_steps`, and the script will distribute the trajectories among the MPI workers for computing the attribution. The memory requirements of each worker scale with `num_steps`, which defaults to 512 (about as large as a machine with 34 GB of memory can typically handle). The default `num_envs` is 8, so it is best to use 8 MPI workers by default to save time, if you have 8 GPUs available. 51 | 52 | The script should take a few hours to run, but if it is taking too long, then you can tell the script to ignore the first couple of non-input layers by setting `layer_kwargs={'discard_first_n': 2}`, for example. These layers take the longest to compute attribution for since they have the highest spatial resolution, and are usually not that informative anyway. 53 | 54 | By default, attribution is only computed for the value function, since computing attribution for every logit of the policy amounts to a large multiplier on the time taken by the script to run. To compute attribution for the policy, set `attr_policy=True`. To offset the increased computational load when doing this, you may wish to choose a single layer to compute attribution for by setting `layer_kwargs={'name_contains_one_of': ['2b']}`, for example. 55 | 56 | To save disk space, the hover effect for isolating single attribution channels can be disabled by setting `attr_single_channels=False`, though this will not have much effect on speed. 
57 | 58 | ## Guide to interfaces 59 | 60 | As shown in [this demo](https://openaipublic.blob.core.windows.net/rl-clarity/attribution/demo/interface.html), interfaces are divided into a number of sections: 61 | 62 | - **Trajectories** - Each trajectory is a separate rollout of the agent interacting with the environment. Here you can select one of them. 63 | - **Bookmarks** - Advantages have been computed using [generalized advantage estimation](https://arxiv.org/abs/1506.02438) (GAE). These provide a measure of how successful each choice made by the agent turned out relative to its expectations, and would usually be used to improve the agent's policy during training. The links here allow you to skip to specific frames from the trajectories with the highest and lowest advantages (with at most one link per episode). 64 | - **Layers** - Here you can select a layer for which attribution (explained below) has been computed. For the input layer, if included, attribution makes less sense, so gradients have been computed instead. 65 | - **Timeline** - Here you can navigate through the frames in each trajectory, either using the buttons or by scrubbing. At the top, information about the current frame is displayed, including the last reward received, the agent's policy, and the action that was chosen next. There are graphs of advantages (as used by the Bookmarks section) and of each network output that has been selected in the Attribution section. 66 | - **Attribution** - Here you can view the observations processed by the agent, and attribution from network outputs (just the value function by default) to the selected layer. Below the observation is a chart of the attribution summed over spatial positions. If attribution has been computed for the policy, you can add and remove rows from this section, and select a different network output for each row, such as the value function, or the policy's logit for a particular action. 
Attribution has been computed using the method of [integrated gradients](https://arxiv.org/abs/1703.01365): the gradient of the network output with respect to the selected layer has been numerically integrated along the straight line from zero to the layer's output given the current observation. This effectively decomposes (or "attributes") the network output across the spatial positions and channels of the selected layer. Dimensionality reduction ([non-negative matrix factorization](https://en.wikipedia.org/wiki/Non-negative_matrix_factorization)) has been applied to the channels using a large batch of varied observations, and the resulting channels are represented using different colors. Additional normalization and smoothing has been applied, with strong attribution bleeding into nearby spatial positions. 67 | - **Attribution legend** - For each of the channels produced by dimensionality reduction (explained above), there are small visualizations here of the feature picked out by that channel. These consist of patches taken from observations at the spatial positions where the selected layer was most strongly activated in the direction of the channel. Hovering over these isolates the channel for the displayed attribution, and clicking opens the Feature visualization popup, where the feature can be further analyzed. 68 | - **Feature visualization** (in popup) - This is displayed after a feature from the Attribution legend section has been selected, and shows a larger visualization of the feature. This also consists of patches taken from observations where the selected layer was most strongly activated in the appropriate direction, but here the location of a patch determines a specific spatial position that must be activated. This means that there is a spatial correspondence between the visualization and observations. Patches with weaker activations are displayed with greater transparency, except when hovering over the image. 
There are sliders that can be used to set the zoom level of the patches (which can also be controlled by scrolling over the image) and the number of patches (which initially equals the number of spatial positions of the selected layer). Clicking on a patch reveals the full observation from which the patch was extracted. 69 | - **Hotkeys** - Here is a list of available keyboard shortcuts. Toggling between play and pause also toggles between whether the arrow keys change the play direction or take a single step in one direction. 70 | 71 | ## Training models 72 | 73 | There is also a script for training a model using [PPO2](https://github.com/openai/baselines/tree/master/baselines/ppo2) from [Baselines](https://github.com/openai/baselines), and saving a checkpoint file in the required format: 74 | ``` 75 | from understanding_rl_vision import rl_clarity 76 | 77 | rl_clarity.train(env_name='coinrun_old', save_dir='path/to/directory') 78 | ``` 79 | 80 | This script is intended to explain checkpoint files, and has not been well-tested. The [example script](understanding_rl_vision/rl_clarity/example.py) demonstrates how to train a model and then generate an interface for it. 81 | 82 | ## Svelte compilation 83 | 84 | To generate interfaces, the Svelte source must be compiled to JavaScript. At installation, the module will automatically attempt to download the pre-compiled JavaScript from a remote copy, though this copy is not guaranteed to be kept up-to-date. 85 | 86 | To obtain an up-to-date copy, or for development, you may wish to re-compile the JavaScript locally. To do this, first install [Node.js](https://nodejs.org/) if you have not already. 
On Mac: 87 | ``` 88 | brew install node 89 | ``` 90 | You will then be able to re-compile the JavaScript: 91 | ``` 92 | python -c 'from understanding_rl_vision import rl_clarity; rl_clarity.recompile_js()' 93 | ``` 94 | 95 | ### Standalone compiler 96 | 97 | The `svelte3` package provides generic functions for compiling version 3 of Svelte to JavaScript or HTML. These can be used to create an easy-to-use command-line tool: 98 | ``` 99 | python -c 'from understanding_rl_vision import svelte3; svelte3.compile_html("path/to/svelte/file", "path/to/html/file")' 100 | ``` 101 | 102 | Detailed usage instructions can be found in the functions' [docstrings](svelte3/compiling.py). 103 | 104 | ## Citation 105 | 106 | Please cite using the following BibTeX entry: 107 | ``` 108 | @article{hilton2020understanding, 109 | author = {Hilton, Jacob and Cammarata, Nick and Carter, Shan and Goh, Gabriel and Olah, Chris}, 110 | title = {Understanding RL Vision}, 111 | journal = {Distill}, 112 | year = {2020}, 113 | note = {https://distill.pub/2020/understanding-rl-vision}, 114 | doi = {10.23915/distill.00029} 115 | } 116 | ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # installs dependencies from ./setup.py, and the package itself, 2 | # in editable mode 3 | -e . 
4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | from setuptools import setup, find_packages 4 | 5 | REMOTE_JS_PATH = ( 6 | "https://openaipublic.blob.core.windows.net/rl-clarity/attribution/js/interface.js" 7 | ) 8 | 9 | 10 | def download_js(): 11 | dir_ = os.path.dirname(os.path.realpath(__file__)) 12 | js_dir_path = os.path.join(dir_, "understanding_rl_vision/rl_clarity/js") 13 | js_path = os.path.join(js_dir_path, "interface.js") 14 | if not os.path.isfile(js_path): 15 | if not os.path.exists(js_dir_path): 16 | os.mkdir(js_dir_path) 17 | try: 18 | urllib.request.urlretrieve(REMOTE_JS_PATH, js_path) 19 | except: 20 | if os.path.exists(js_path): 21 | os.remove(js_path) 22 | 23 | 24 | setup( 25 | name="understanding-rl-vision", 26 | packages=find_packages(), 27 | version="0.0.1", 28 | install_requires=[ 29 | "mpi4py", 30 | "baselines", 31 | "lucid @ git+https://github.com/tensorflow/lucid.git@16a03dee8f99af4cdd89d6b7c1cc913817174c83", 32 | ], 33 | extras_require={"envs": ["coinrun", "procgen", "atari-py"]}, 34 | ) 35 | 36 | download_js() 37 | -------------------------------------------------------------------------------- /understanding_rl_vision/__init__.py: -------------------------------------------------------------------------------- 1 | from . import rl_clarity 2 | from . 
def run(
    checkpoint_path,
    *,
    output_dir,
    load_kwargs=None,
    trajectories_kwargs=None,
    observations_kwargs=None,
    **generate_kwargs
):
    """Generate an interface from a checkpoint file.

    Arguments:
        checkpoint_path: path to checkpoint file, a joblib file containing a
            dictionary with these keys
            - params: saved model parameters as a dictionary mapping tensor
                names to numpy arrays
            - args: dictionary of metadata: env_name (Procgen, required if
                env_kind is 'procgen') or env_id (Atari, required if env_kind
                is 'atari'), env_kind ('procgen' or 'atari', defaults to
                'procgen'), gamma and lambda (GAE hyperparameters used to
                train the model, default None), cnn (model architecture, one
                of 'clear', 'impala' or 'nature', defaults to 'clear'), plus
                any other optional arguments used to create the environment
                or get the architecture
        output_dir: path to directory where the interface is to be saved
            (required)
        load_kwargs: dictionary with keys for any of the following
            - resample: whether to process the checkpoint file from scratch,
                rather than reusing samples previously saved to a
                non-temporary location (defaults to True)
            - model_path: lucid model save location
            - metadata_path: metadata dictionary save location
            - trajectories_path: trajectories save location
            - observations_path: additional observations save location
            - full_resolution: whether to also save observations in
                human-scale resolution (significant performance cost,
                defaults to False)
            - temp_files: if any of the above paths is not specified, whether
                to default to a temporary location rather than a subdirectory
                of the checkpoint file's directory (defaults to False)
        trajectories_kwargs: dictionary, only used if resampling, with keys
            num_envs (number of trajectories to collect, defaults to 8) and
            num_steps (length of each trajectory, defaults to 512)
        observations_kwargs: dictionary, only used if resampling, with keys
            num_envs (defaults to 32), num_obs (defaults to 128) and
            obs_every (defaults to 128), controlling how additional
            observations are collected
        **generate_kwargs: every other keyword argument is forwarded verbatim
            to generate() — layer selection, attribution, NMF and
            visualization options, css sizing, colors, etc.

    Under MPI, rank 0 resamples (when requested) while the other ranks wait
    at a barrier, then all ranks load; if only some ranks fail to load,
    rank 0's result is broadcast to them.
    """
    import tensorflow as tf
    from mpi4py import MPI
    from baselines.common.mpi_util import setup_mpi_gpus

    # Copy the option dicts: `resample` is overwritten below, and mutating a
    # caller's dict (or, previously, a shared `{}` default argument) would
    # leak state across calls.
    load_kwargs = dict(load_kwargs) if load_kwargs else {}
    trajectories_kwargs = dict(trajectories_kwargs) if trajectories_kwargs else {}
    observations_kwargs = dict(observations_kwargs) if observations_kwargs else {}

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()
    setup_mpi_gpus()

    exn = None
    if rank == 0 and load_kwargs.get("resample", True):
        kwargs = load(
            checkpoint_path,
            trajectories_kwargs=trajectories_kwargs,
            observations_kwargs=observations_kwargs,
            **load_kwargs
        )
        comm.barrier()
    else:
        comm.barrier()
        # non-root ranks (and no-resample runs) reuse rank 0's saved samples
        load_kwargs["resample"] = False
        try:
            kwargs = load(
                checkpoint_path,
                trajectories_kwargs=trajectories_kwargs,
                observations_kwargs=observations_kwargs,
                **load_kwargs
            )
        except tf.errors.NotFoundError as e:
            exn = e
            kwargs = None
    errors = comm.allreduce(0 if exn is None else 1, op=MPI.SUM)
    if errors == size:
        # every rank failed to find saved samples
        raise FileNotFoundError(checkpoint_path) from exn
    elif errors > 0:
        kwargs = comm.bcast(kwargs, root=0)
    kwargs["output_dir"] = output_dir
    kwargs.update(generate_kwargs)

    generate(**kwargs)
def construct_path(relpath):
    """Resolve `relpath` relative to this package's directory."""
    dir_ = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(dir_, relpath)


SVELTE_PATH = construct_path("svelte/interface.svelte")
JS_DIR_PATH = construct_path("js")
JS_PATH = construct_path("js/interface.js")


def recompile_js():
    """Compile the Svelte interface to JavaScript, creating js/ if needed.

    Prints the compiler's command output.
    """
    # makedirs+exist_ok avoids the exists()/mkdir() race and also succeeds
    # if intermediate directories are missing
    os.makedirs(JS_DIR_PATH, exist_ok=True)
    print(compile_js(SVELTE_PATH, js_path=JS_PATH)["command_output"])


def get_compiled_js():
    """Return the path to the compiled JS, compiling it on first use."""
    if not os.path.isfile(JS_PATH):
        recompile_js()
    return JS_PATH
def train_and_run(env_name_or_id, *, base_path=None):
    """Train a tiny model on one environment, then build its interface.

    env_name_or_id: a Procgen env name (or 'coinrun_old') or an Atari env id.
    base_path: root directory for outputs; a fresh temporary directory is
        created when omitted.
    """
    if base_path is None:
        base_path = tempfile.mkdtemp(prefix="rl_clarity_example_")
    train_dir = os.path.join(base_path, "training")
    ui_dir = os.path.join(base_path, "interface")
    # remote paths (containing '://') cannot be created locally
    if "://" not in base_path:
        os.makedirs(train_dir, exist_ok=True)
        os.makedirs(ui_dir, exist_ok=True)

    if env_name_or_id in PROCGEN_ENV_NAMES + ["coinrun_old"]:
        env_kwargs = {"env_name": env_name_or_id}
    elif env_name_or_id in ATARI_ENV_DICT:
        env_kwargs = {"env_id": env_name_or_id, "env_kind": "atari"}
    else:
        raise ValueError(f"Unsupported env {env_name_or_id}")

    # train for very few timesteps, to demonstrate
    # note: training code has not been well-tested
    rl_clarity.train(
        num_envs=8,
        num_steps=16,
        timesteps_per_proc=8 * 16 * 2,
        save_interval=2,
        save_dir=train_dir,
        **env_kwargs,
    )
    checkpoint = os.path.join(train_dir, "checkpoint.jd")
    print(f"Checkpoint saved to: {checkpoint}")

    print("Generating interface...")
    # generate a small interface, to demonstrate
    rl_clarity.run(
        checkpoint,
        output_dir=ui_dir,
        trajectories_kwargs={"num_envs": 8, "num_steps": 16},
        observations_kwargs={"num_envs": 8, "num_obs": 4, "obs_every": 4},
        layer_kwargs={"name_contains_one_of": ["2b"]},
    )

    page = os.path.join(ui_dir, "interface.html")
    prefix = "" if "://" in page else "file://"
    print(f"Interface URL: {prefix}{page}")


def main():
    """Command-line entry point: `example.py [env] [-p PATH]`."""
    parser = argparse.ArgumentParser()
    parser.add_argument("env", nargs="?", default="coinrun_old")
    parser.add_argument("-p", "--path")
    options = parser.parse_args()
    train_and_run(options.env, base_path=options.path)


if __name__ == "__main__":
    main()
def exists_path(model, *, from_names, to_names, without_names=[]):
    """Whether the model's graph has a directed path from any tensor in
    `from_names` back (via node inputs) to any tensor in `to_names`,
    avoiding nodes in `without_names`."""
    for source in from_names:
        if source in without_names:
            continue
        if source in to_names:
            return True
        predecessors = [
            inp.rsplit(":", 1)[0]
            for node in model.graph_def.node
            if node.name == source
            for inp in node.input
        ]
        if exists_path(
            model,
            from_names=predecessors,
            to_names=to_names,
            without_names=without_names,
        ):
            return True
    return False


def longest_common_prefix(l):
    """Return the longest string that every string in `l` starts with."""
    candidates = set(s[: min(len(x) for x in l)] for s in l)
    while len(candidates) > 1:
        candidates = set(s[:-1] for s in candidates)
    return list(candidates)[0]


def longest_common_suffix(l):
    """Return the longest string that every string in `l` ends with."""
    candidates = set(s[-min(len(x) for x in l):] for s in l)
    while len(candidates) > 1:
        candidates = set(s[1:] for s in candidates)
    return list(candidates)[0]


def get_abbreviator(names):
    """Return a slice removing the common path prefix and suffix from the
    given layer names (identity slice when there is at most one name)."""
    if len(names) <= 1:
        return slice(None, None)
    prefix = longest_common_prefix(names)
    prefix = prefix.rsplit("/", 1)[0] + "/" if "/" in prefix else ""
    suffix = longest_common_suffix(names)
    suffix = "/" + suffix.split("/", 1)[-1] if "/" in suffix else ""
    start = len(prefix)
    stop = None if len(suffix) == 0 else -len(suffix)
    return slice(start, stop)
def get_layer_names(
    model,
    output_names,
    *,
    name_contains_one_of,
    op_is_one_of,
    bottleneck_only,
    discard_first_n,
):
    """Select layers of interest from a lucid model's graph.

    model: lucid model whose graph_def is searched
    output_names: names of the model's output tensors, used for the
        bottleneck reachability test
    name_contains_one_of: list of substrings a layer name must contain one
        of (a single string is accepted; None disables name filtering)
    op_is_one_of: list of op types, case-insensitive, a layer's op must
        match (a single string is accepted)
    bottleneck_only: only keep layers such that every path to an earlier
        convolutional layer passes through a bottleneck of the network
    discard_first_n: number of initial matching layers to drop

    Returns an OrderedDict mapping abbreviated layer names (common path
    prefix/suffix stripped) to full layer names.
    """
    if isinstance(name_contains_one_of, str):
        name_contains_one_of = [name_contains_one_of]
    if isinstance(op_is_one_of, str):
        # bug fix: this branch previously assigned to name_contains_one_of,
        # clobbering the name filter and leaving op_is_one_of a bare string
        op_is_one_of = [op_is_one_of]

    nodes = model.graph_def.node

    # only spatial (rank >= 4) activations qualify
    shape_condition = lambda node: len(get_shape(model, node.name)) >= 4
    op_condition = lambda node: any(node.op.lower() == s.lower() for s in op_is_one_of)
    if bottleneck_only:
        # bottlenecks: nodes whose removal disconnects the outputs from the input
        bottleneck_names = [
            node.name
            for node in nodes
            if not exists_path(
                model,
                from_names=output_names,
                to_names=[model.input_name],
                without_names=[node.name],
            )
        ]
        conv_names = [node.name for node in nodes if node.op.lower()[:4] == "conv"]
        bottleneck_condition = lambda node: not exists_path(
            model,
            from_names=[node.name],
            to_names=conv_names,
            without_names=bottleneck_names,
        )
    else:
        bottleneck_condition = lambda node: True

    layer_names = [
        node.name
        for node in nodes
        if shape_condition(node) and op_condition(node) and bottleneck_condition(node)
    ]
    abbreviator = get_abbreviator(layer_names)

    if name_contains_one_of is None:
        name_condition = lambda name: True
    else:
        name_condition = lambda name: any(s in name for s in name_contains_one_of)

    return OrderedDict(
        [(name[abbreviator], name) for name in layer_names if name_condition(name)][
            discard_first_n:
        ]
    )


def batched_get(data, batch_size, process_minibatch):
    """Apply process_minibatch to `data` in minibatches along axis 0 and
    concatenate the results in order."""
    n_points = data.shape[0]
    n_minibatches = -((-n_points) // batch_size)  # ceiling division
    return np.concatenate(
        [
            process_minibatch(data[i * batch_size : (i + 1) * batch_size])
            for i in range(n_minibatches)
        ],
        axis=0,
    )


def compute_gae(trajectories, *, gae_gamma, gae_lambda):
    """Compute advantages with Generalized Advantage Estimation.

    trajectories: dict with 'values', 'rewards' and either 'dones' or
        'firsts' (firsts shifted by one step stand in for dones), each an
        array whose first two dimensions are batch and timestep.

    Returns an array shaped like 'values'; the final timestep's advantage is
    left at zero since no successor value is available.
    """
    values = trajectories["values"]
    next_values = values[:, 1:]
    rewards = trajectories["rewards"][:, :-1]
    try:
        dones = trajectories["dones"][:, :-1]
    except KeyError:
        dones = trajectories["firsts"][:, 1:]
    assert next_values.shape == rewards.shape == dones.shape
    # one-step TD errors, masked at episode boundaries
    deltas = rewards + (1 - dones) * gae_gamma * next_values - values[:, :-1]
    result = np.zeros(values.shape, values.dtype)
    # backward recursion: A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
    for step in reversed(range(values.shape[1] - 1)):
        result[:, step] = (
            deltas[:, step]
            + (1 - dones[:, step]) * gae_gamma * gae_lambda * result[:, step + 1]
        )
    return result
def number_to_string(x):
    """Encode a number as a filename-safe string: strips a trailing '.0' and
    keeps only digits and 'e' (so 0.25 -> '025', 1.0 -> '1')."""
    text = str(x)
    if text.endswith(".0"):
        text = text[:-2]
    return "".join(c for c in text if c.isdigit() or c == "e")


def get_html_colors(n, grayscale=False, mix_with=None, mix_weight=0.5, **kwargs):
    """Return `n` HTML hex color strings.

    grayscale: evenly-spaced grays instead of distinct hues (extra kwargs go
        to channels_to_rgb in the non-grayscale case)
    mix_with: optional RGB array to blend every color towards
    mix_weight: blend fraction taken from `mix_with`
    """
    if grayscale:
        base = np.linspace(0, 1, n)[..., None].repeat(3, axis=1)
    else:
        base = channels_to_rgb(np.eye(n), **kwargs)
        base = base / base.max(axis=-1, keepdims=True)
    if mix_with is not None:
        base = base * (1 - mix_weight) + mix_with[None] * mix_weight
    quantized = np.round(base * 255).astype(int)
    to_hex = np.vectorize(lambda v: hex(v)[2:].zfill(2))
    return ["#" + "".join(triple) for triple in to_hex(quantized)]


def removeprefix(s, prefix):
    """Backport of str.removeprefix (added in Python 3.9)."""
    return s[len(prefix):] if s.startswith(prefix) else s
vis_expand_mult_default=4, 255 | vis_thumbnail_num_mult=4, 256 | vis_thumbnail_expand_mult=4, 257 | scrub_range=(42 / 64, 44 / 64), 258 | attr_integrate_steps=10, 259 | attr_max_paths=None, 260 | attr_policy=False, 261 | attr_single_channels=True, 262 | observations_subdir="observations/", 263 | trajectories_subdir="trajectories/", 264 | trajectories_scrub_subdir="trajectories_scrub/", 265 | features_subdir="features/", 266 | thumbnails_subdir="thumbnails/", 267 | attribution_subdir="attribution/", 268 | attribution_scrub_subdir="attribution_scrub/", 269 | features_grids_subdir="features_grids/", 270 | attribution_totals_subdir="attribution_totals/", 271 | video_height="16em", 272 | video_width="16em", 273 | video_speed=12, 274 | policy_display_height="2em", 275 | policy_display_width="40em", 276 | navigator_width="24em", 277 | scrubber_height="4em", 278 | scrubber_width="48em", 279 | scrubber_visible_duration=256, 280 | legend_item_height="6em", 281 | legend_item_width="6em", 282 | feature_viewer_height="40em", 283 | feature_viewer_width="40em", 284 | attribution_weight=0.9, 285 | graph_colors={ 286 | "v": "green", 287 | "action": "red", 288 | "action_group": "orange", 289 | "advantage": "blue", 290 | }, 291 | trajectory_color="blue", 292 | ): 293 | from mpi4py import MPI 294 | 295 | comm = MPI.COMM_WORLD 296 | rank = comm.Get_rank() 297 | 298 | model = get_model(model_bytes) 299 | if rank == 0: 300 | js_source_path = get_compiled_js() 301 | 302 | if env_name is None: 303 | env_name = "unknown" 304 | if inline_large_json is None: 305 | inline_large_json = "://" not in output_dir 306 | layer_kwargs.setdefault("name_contains_one_of", None) 307 | layer_kwargs.setdefault("op_is_one_of", ["relu"]) 308 | layer_kwargs.setdefault("bottleneck_only", True) 309 | layer_kwargs.setdefault("discard_first_n", 0) 310 | if observations_full is None: 311 | observations_full = observations 312 | if "observations_full" not in trajectories: 313 | trajectories["observations_full"] = 
trajectories["observations"] 314 | if np.issubdtype(observations.dtype, np.integer): 315 | observations = observations / np.float32(255) 316 | if np.issubdtype(observations_full.dtype, np.integer): 317 | observations_full = observations_full / np.float32(255) 318 | if np.issubdtype(trajectories["observations"].dtype, np.integer): 319 | trajectories["observations"] = trajectories["observations"] / np.float32(255) 320 | if np.issubdtype(trajectories["observations_full"].dtype, np.integer): 321 | trajectories["observations_full"] = trajectories[ 322 | "observations_full" 323 | ] / np.float32(255) 324 | if action_combos is None: 325 | num_actions = get_shape(model, policy_logits_name)[-1] 326 | action_combos = list(map(lambda x: (str(x),), range(num_actions))) 327 | if env_name == "coinrun_old": 328 | action_combos = [ 329 | (), 330 | ("RIGHT",), 331 | ("LEFT",), 332 | ("UP",), 333 | ("RIGHT", "UP"), 334 | ("LEFT", "UP"), 335 | ("DOWN",), 336 | ("A",), 337 | ("B",), 338 | ][:num_actions] 339 | if gae_gamma is None: 340 | gae_gamma = 0.999 341 | if gae_lambda is None: 342 | gae_lambda = 0.95 343 | 344 | layer_names = get_layer_names( 345 | model, [policy_logits_name, value_function_name], **layer_kwargs 346 | ) 347 | if not layer_names: 348 | raise RuntimeError( 349 | "No appropriate layers found. " 350 | "Please adapt layer_kwargs to your architecture" 351 | ) 352 | squash = lambda s: s.replace("/", "").replace("_", "") 353 | if len(set([squash(layer_key) for layer_key in layer_names.keys()])) < len( 354 | layer_names 355 | ): 356 | raise RuntimeError( 357 | "Error squashing abbreviated layer names. 
" 358 | "Different substitutions must be used" 359 | ) 360 | mpi_enumerate = lambda l: ( 361 | lambda indices: list(enumerate(l))[indices[rank] : indices[rank + 1]] 362 | )(np.linspace(0, len(l), comm.Get_size() + 1).astype(int)) 363 | save_image = lambda image, path: save( 364 | image, os.path.join(output_dir, path), domain=(0, 1) 365 | ) 366 | save_images = lambda images, path: save_image( 367 | concatenate_horizontally(images), path 368 | ) 369 | json_preloaded = {} 370 | save_json = lambda data, path: ( 371 | json_preloaded.update({path: data}) 372 | if inline_large_json 373 | else save(data, os.path.join(output_dir, path), indent=None) 374 | ) 375 | get_scrub_slice = lambda width: slice( 376 | int(np.round(scrub_range[0] * width)), 377 | int( 378 | np.maximum( 379 | np.round(scrub_range[1] * width), np.round(scrub_range[0] * width) + 1 380 | ) 381 | ), 382 | ) 383 | action_groups = [ 384 | [action for action, combo in enumerate(action_combos) if group_fn(combo)] 385 | for group_fn in action_group_fns 386 | ] 387 | action_groups = list( 388 | filter(lambda action_group: len(action_group) > 1, action_groups) 389 | ) 390 | 391 | for index, observation in mpi_enumerate(observations_full): 392 | observation_path = os.path.join(observations_subdir, f"{index}.png") 393 | save_image(observation, observation_path) 394 | for index, trajectory_observations in mpi_enumerate( 395 | trajectories["observations_full"] 396 | ): 397 | trajectory_path = os.path.join(trajectories_subdir, f"{index}.png") 398 | save_images(trajectory_observations, trajectory_path) 399 | scrub_slice = get_scrub_slice(trajectory_observations.shape[2]) 400 | scrub = trajectory_observations[:, :, scrub_slice, :] 401 | scrub_path = os.path.join(trajectories_scrub_subdir, f"{index}.png") 402 | save_images(scrub, scrub_path) 403 | 404 | trajectories["policy_logits"] = [] 405 | trajectories["values"] = [] 406 | for trajectory_observations in trajectories["observations"]: 407 | 
trajectories["policy_logits"].append( 408 | batched_get( 409 | trajectory_observations, 410 | batch_size, 411 | lambda minibatch: get_acts(model, policy_logits_name, minibatch), 412 | ) 413 | ) 414 | trajectories["values"].append( 415 | batched_get( 416 | trajectory_observations, 417 | batch_size, 418 | lambda minibatch: get_acts(model, value_function_name, minibatch), 419 | ) 420 | ) 421 | trajectories["policy_logits"] = np.array(trajectories["policy_logits"]) 422 | trajectories["values"] = np.array(trajectories["values"]) 423 | trajectories["advantages"] = compute_gae( 424 | trajectories, gae_gamma=gae_gamma, gae_lambda=gae_lambda 425 | ) 426 | if "dones" not in trajectories: 427 | trajectories["dones"] = np.concatenate( 428 | [ 429 | trajectories["firsts"][:, 1:], 430 | np.zeros_like(trajectories["firsts"][:, :1]), 431 | ], 432 | axis=-1, 433 | ) 434 | 435 | bookmarks = { 436 | "high": get_bookmarks(trajectories, sign=1, num=trajectory_bookmarks), 437 | "low": get_bookmarks(trajectories, sign=-1, num=trajectory_bookmarks), 438 | } 439 | 440 | nmf_kwargs = {"attr_layer_name": value_function_name} 441 | if nmf_attr_opts is not None: 442 | nmf_kwargs["attr_opts"] = nmf_attr_opts 443 | nmfs = { 444 | layer_key: LayerNMF( 445 | model, 446 | layer_name, 447 | observations, 448 | obses_full=observations_full, 449 | features=nmf_features, 450 | **nmf_kwargs, 451 | ) 452 | for layer_key, layer_name in layer_names.items() 453 | } 454 | 455 | features = [] 456 | attributions = [] 457 | attribution_totals = [] 458 | 459 | for layer_key, layer_name in layer_names.items(): 460 | nmf = nmfs[layer_key] 461 | 462 | if rank == 0: 463 | thumbnails = [] 464 | for number in range(nmf.features): 465 | thumbnail = nmf.vis_dataset_thumbnail( 466 | number, 467 | num_mult=vis_thumbnail_num_mult, 468 | expand_mult=vis_thumbnail_expand_mult, 469 | )[0] 470 | thumbnail = rescale_opacity(thumbnail, max_scale=1, keep_zeros=True) 471 | thumbnails.append(thumbnail) 472 | thumbnails_path = 
os.path.join( 473 | thumbnails_subdir, f"{squash(layer_key)}.png" 474 | ) 475 | save_images(thumbnails, thumbnails_path) 476 | 477 | for _, number in mpi_enumerate(range(nmf.features)): 478 | feature = { 479 | "layer": layer_key, 480 | "number": number, 481 | "images": [], 482 | "overlay_grids": [], 483 | "metadata": {"subdiv_mult": [], "expand_mult": []}, 484 | } 485 | for subdiv_mult in vis_subdiv_mults: 486 | for expand_mult in vis_expand_mults: 487 | image, overlay_grid = nmf.vis_dataset( 488 | number, subdiv_mult=subdiv_mult, expand_mult=expand_mult 489 | ) 490 | image = rescale_opacity(image) 491 | filename_root = ( 492 | f"{squash(layer_key)}_" 493 | f"feature{number}_" 494 | f"{number_to_string(subdiv_mult)}_" 495 | f"{number_to_string(expand_mult)}" 496 | ) 497 | image_filename = filename_root + ".png" 498 | overlay_grid_filename = filename_root + ".json" 499 | image_path = os.path.join(features_subdir, image_filename) 500 | overlay_grid_path = os.path.join( 501 | features_grids_subdir, overlay_grid_filename 502 | ) 503 | save_image(image, image_path) 504 | save_json(overlay_grid, overlay_grid_path) 505 | feature["images"].append(image_filename) 506 | feature["overlay_grids"].append(overlay_grid_filename) 507 | feature["metadata"]["subdiv_mult"].append(subdiv_mult) 508 | feature["metadata"]["expand_mult"].append(expand_mult) 509 | features.append(feature) 510 | 511 | for layer_key, layer_name in ( 512 | [(input_layer_name, None)] if input_layer_include else [] 513 | ) + list(layer_names.items()): 514 | if layer_name is None: 515 | nmf = None 516 | else: 517 | nmf = nmfs[layer_key] 518 | 519 | for index, trajectory_observations in mpi_enumerate( 520 | trajectories["observations"] 521 | ): 522 | attribution = { 523 | "layer": layer_key, 524 | "trajectory": index, 525 | "images": [], 526 | "metadata": {"type": [], "data": [], "direction": [], "channel": []}, 527 | } 528 | if layer_name is not None: 529 | totals = { 530 | "layer": layer_key, 531 | 
"trajectory": index, 532 | "channels": [], 533 | "residuals": [], 534 | "metadata": {"type": [], "data": []}, 535 | } 536 | 537 | def get_attr_minibatch( 538 | minibatch, output_name, *, score_fn=default_score_fn 539 | ): 540 | if layer_name is None: 541 | return get_grad(model, output_name, minibatch, score_fn=score_fn) 542 | elif attr_max_paths is None: 543 | return get_attr( 544 | model, 545 | output_name, 546 | layer_name, 547 | minibatch, 548 | score_fn=score_fn, 549 | integrate_steps=attr_integrate_steps, 550 | ) 551 | else: 552 | return get_multi_path_attr( 553 | model, 554 | output_name, 555 | layer_name, 556 | minibatch, 557 | nmf, 558 | score_fn=score_fn, 559 | integrate_steps=attr_integrate_steps, 560 | max_paths=attr_max_paths, 561 | ) 562 | 563 | def get_attr_batched(output_name, *, score_fn=default_score_fn): 564 | return batched_get( 565 | trajectory_observations, 566 | batch_size, 567 | lambda minibatch: get_attr_minibatch( 568 | minibatch, output_name, score_fn=score_fn 569 | ), 570 | ) 571 | 572 | def transform_attr(attr): 573 | if layer_name is None: 574 | return attr, None 575 | else: 576 | attr_trans = nmf.transform(np.maximum(attr, 0)) - nmf.transform( 577 | np.maximum(-attr, 0) 578 | ) 579 | attr_res = ( 580 | attr 581 | - ( 582 | nmf.inverse_transform(np.maximum(attr_trans, 0)) 583 | - nmf.inverse_transform(np.maximum(-attr_trans, 0)) 584 | ) 585 | ).sum(-1, keepdims=True) 586 | nmf_norms = nmf.channel_dirs.sum(-1) 587 | return attr_trans * nmf_norms[None, None, None], attr_res 588 | 589 | def save_attr(attr, attr_res, *, type_, data): 590 | if attr_res is None: 591 | attr_res = np.zeros_like(attr).sum(-1, keepdims=True) 592 | filename_root = f"{squash(layer_key)}_{index}_{type_}" 593 | if data is not None: 594 | filename_root = f"{filename_root}_{data}" 595 | if layer_name is not None: 596 | channels_filename = f"{filename_root}_channels.json" 597 | residuals_filename = f"{filename_root}_residuals.json" 598 | channels_path = os.path.join( 
599 | attribution_totals_subdir, channels_filename 600 | ) 601 | residuals_path = os.path.join( 602 | attribution_totals_subdir, residuals_filename 603 | ) 604 | save_json(attr.sum(-2).sum(-2), channels_path) 605 | save_json(attr_res[..., 0].sum(-1).sum(-1), residuals_path) 606 | totals["channels"].append(channels_filename) 607 | totals["residuals"].append(residuals_filename) 608 | totals["metadata"]["type"].append(type_) 609 | totals["metadata"]["data"].append(data) 610 | attr_scale = np.median(attr.max(axis=(-3, -2, -1))) 611 | if attr_scale == 0: 612 | attr_scale = attr.max() 613 | if attr_scale == 0: 614 | attr_scale = 1 615 | attr_scaled = attr / attr_scale 616 | attr_res_scaled = attr_res / attr_scale 617 | channels = ["prin", "all"] 618 | if attr_single_channels and layer_name is not None: 619 | channels += list(range(nmf.features)) + ["res"] 620 | for direction in ["abs", "pos", "neg"]: 621 | if direction == "abs": 622 | attr = np.abs(attr_scaled) 623 | attr_res = np.abs(attr_res_scaled) 624 | elif direction == "pos": 625 | attr = np.maximum(attr_scaled, 0) 626 | attr_res = np.maximum(attr_res_scaled, 0) 627 | elif direction == "neg": 628 | attr = np.maximum(-attr_scaled, 0) 629 | attr_res = np.maximum(-attr_res_scaled, 0) 630 | for channel in channels: 631 | if isinstance(channel, int): 632 | attr_single = attr.copy() 633 | attr_single[..., :channel] = 0 634 | attr_single[..., (channel + 1) :] = 0 635 | images = channels_to_rgb(attr_single) 636 | elif channel == "res": 637 | images = attr_res.repeat(3, axis=-1) 638 | else: 639 | images = channels_to_rgb(attr) 640 | if channel == "all": 641 | images += attr_res.repeat(3, axis=-1) 642 | images = brightness_to_opacity( 643 | conv2d(images, filter_=norm_filter(15)) 644 | ) 645 | suffix = f"{direction}_{channel}" 646 | images_filename = f"{filename_root}_{suffix}.png" 647 | images_path = os.path.join(attribution_subdir, images_filename) 648 | save_images(images, images_path) 649 | scrub = images[:, :, 
get_scrub_slice(images.shape[2]), :] 650 | scrub_path = os.path.join( 651 | attribution_scrub_subdir, images_filename 652 | ) 653 | save_images(scrub, scrub_path) 654 | attribution["images"].append(images_filename) 655 | attribution["metadata"]["type"].append(type_) 656 | attribution["metadata"]["data"].append(data) 657 | attribution["metadata"]["direction"].append(direction) 658 | attribution["metadata"]["channel"].append(channel) 659 | 660 | attr_v = get_attr_batched(value_function_name) 661 | attr_v_trans, attr_v_res = transform_attr(attr_v) 662 | save_attr(attr_v_trans, attr_v_res, type_="v", data=None) 663 | if attr_policy: 664 | attr_actions = np.array( 665 | [ 666 | get_attr_batched( 667 | policy_logits_name, score_fn=lambda t: t[..., action], 668 | ) 669 | for action in range(len(action_combos)) 670 | ] 671 | ) 672 | # attr_pi = attr_actions.sum(axis=-1).transpose( 673 | # (1, 2, 3, 0)) 674 | # attr_pi = np.concatenate([ 675 | # attr_pi[..., group].sum(axis=-1, keepdims=True) 676 | # for group in attr_action_groups 677 | # ], 678 | # axis=-1) 679 | # save_attr(attr_pi, None, type_='pi', data=None) 680 | for action, attr in enumerate(attr_actions): 681 | attr_trans, attr_res = transform_attr(attr) 682 | save_attr(attr_trans, attr_res, type_="action", data=action) 683 | for action_group, actions in enumerate(action_groups): 684 | attr = attr_actions[actions].sum(axis=0) 685 | attr_trans, attr_res = transform_attr(attr) 686 | save_attr( 687 | attr_trans, attr_res, type_="action_group", data=action_group 688 | ) 689 | attributions.append(attribution) 690 | if layer_name is not None: 691 | attribution_totals.append(totals) 692 | 693 | features = comm.gather(features, root=0) 694 | attributions = comm.gather(attributions, root=0) 695 | attribution_totals = comm.gather(attribution_totals, root=0) 696 | 697 | if rank == 0: 698 | features = [feature for l in features for feature in l] 699 | attributions = [attribution for l in attributions for attribution in l] 700 
| attribution_totals = [totals for l in attribution_totals for totals in l] 701 | layer_keys = ([input_layer_name] if input_layer_include else []) + list( 702 | layer_names.keys() 703 | ) 704 | action_colors = get_html_colors( 705 | len(action_combos), 706 | grayscale=True, 707 | mix_with=np.array([0.75, 0.75, 0.75]), 708 | mix_weight=0.25, 709 | ) 710 | props = { 711 | "input_layer": input_layer_name, 712 | "layers": layer_keys, 713 | "features": features, 714 | "attributions": attributions, 715 | "attribution_policy": attr_policy, 716 | "attribution_single_channels": attr_single_channels, 717 | "attribution_totals": attribution_totals, 718 | "colors": { 719 | "features": get_html_colors(nmf_features), 720 | "actions": action_colors, 721 | "graphs": graph_colors, 722 | "trajectory": trajectory_color, 723 | }, 724 | "action_combos": action_combos, 725 | "action_groups": action_groups, 726 | "trajectories": { 727 | "actions": trajectories["actions"], 728 | "rewards": trajectories["rewards"], 729 | "dones": trajectories["dones"], 730 | "policy_logits": trajectories["policy_logits"], 731 | "values": trajectories["values"], 732 | "advantages": trajectories["advantages"], 733 | }, 734 | "bookmarks": bookmarks, 735 | "vis_defaults": { 736 | "subdiv_mult": vis_subdiv_mult_default, 737 | "expand_mult": vis_expand_mult_default, 738 | }, 739 | "subdirs": { 740 | "observations": observations_subdir, 741 | "trajectories": trajectories_subdir, 742 | "trajectories_scrub": trajectories_scrub_subdir, 743 | "features": features_subdir, 744 | "thumbnails": thumbnails_subdir, 745 | "attribution": attribution_subdir, 746 | "attribution_scrub": attribution_scrub_subdir, 747 | "features_grids": features_grids_subdir, 748 | "attribution_totals": attribution_totals_subdir, 749 | }, 750 | "formatting": { 751 | "video_height": video_height, 752 | "video_width": video_width, 753 | "video_speed": video_speed, 754 | "policy_display_height": policy_display_height, 755 | "policy_display_width": 
policy_display_width, 756 | "navigator_width": navigator_width, 757 | "scrubber_height": scrubber_height, 758 | "scrubber_width": scrubber_width, 759 | "scrubber_visible_duration": scrubber_visible_duration, 760 | "legend_item_height": legend_item_height, 761 | "legend_item_width": legend_item_width, 762 | "feature_viewer_height": feature_viewer_height, 763 | "feature_viewer_width": feature_viewer_width, 764 | "attribution_weight": attribution_weight, 765 | }, 766 | "json_preloaded": json_preloaded, 767 | } 768 | 769 | if inline_js: 770 | js_path = js_source_path 771 | else: 772 | with open(js_source_path, "r") as fp: 773 | js_code = fp.read() 774 | js_path = os.path.join(output_dir, "interface.js") 775 | with write_handle(js_path, "w") as fp: 776 | fp.write(js_code) 777 | html_path = os.path.join(output_dir, "interface.html") 778 | compile_html( 779 | js_path, 780 | html_path=html_path, 781 | props=props, 782 | precision=numpy_precision, 783 | inline_js=inline_js, 784 | svelte_to_js=False, 785 | ) 786 | if output_dir.startswith("gs://"): 787 | if not inline_js: 788 | subprocess.run( 789 | [ 790 | "gsutil", 791 | "setmeta", 792 | "-h", 793 | "Content-Type: text/javascript", 794 | js_path, 795 | ] 796 | ) 797 | subprocess.run( 798 | ["gsutil", "setmeta", "-h", "Content-Type: text/html", html_path] 799 | ) 800 | elif output_dir.startswith("https://"): 801 | output_dir_parsed = urllib.parse.urlparse(output_dir) 802 | az_account, az_hostname = output_dir_parsed.netloc.split(".", 1) 803 | if az_hostname == "blob.core.windows.net": 804 | az_container = removeprefix(output_dir_parsed.path, "/").split("/")[0] 805 | az_prefix = f"https://{az_account}.{az_hostname}/{az_container}/" 806 | if not inline_js: 807 | js_az_name = removeprefix(js_path, az_prefix) 808 | subprocess.run( 809 | [ 810 | "az", 811 | "storage", 812 | "blob", 813 | "update", 814 | "--container-name", 815 | az_container, 816 | "--name", 817 | js_az_name, 818 | "--account-name", 819 | az_account, 820 | 
# understanding_rl_vision/rl_clarity/loading.py
"""Export a trained RL checkpoint as a lucid model plus sampled observations/trajectories."""
import numpy as np
import tensorflow as tf
from contextlib import contextmanager
import os
import re
import tempfile
from lucid.modelzoo.vision_base import Model
from lucid.misc.io.reading import read
from lucid.scratch.rl_util.joblib_wrapper import load_joblib, save_joblib
from .training import create_env, get_arch


def load_params(params, *, sess):
    """Load parameter values into the default graph's variables, matched by exact name.

    Best-effort: names in `params` with no matching graph variable are silently
    skipped (and graph variables absent from `params` keep their current values).
    """
    var_list = tf.global_variables()
    for name, var_value in params.items():
        matching_vars = [var for var in var_list if var.name == name]
        if matching_vars:
            matching_vars[0].load(var_value, sess)


def save_lucid_model(config, params, *, model_path, metadata_path):
    """Rebuild the policy graph, load `params` into it, and save it as a lucid Model.

    Args:
        config: training config dict (must contain "num_envs"; "library" defaults
            to "baselines" — the only library currently supported).
        params: mapping from variable name to value, as stored in the checkpoint.
        model_path: where to write the frozen lucid model (.pb).
        metadata_path: where to write the joblib metadata dict.

    Returns:
        dict with the serialized model bytes plus the metadata entries.

    Raises:
        ValueError: if config requests an unsupported library.
    """
    config = config.copy()
    config.pop("num_envs")
    library = config.get("library", "baselines")
    venv = create_env(1, **config)
    arch = get_arch(**config)

    with tf.Graph().as_default(), tf.Session() as sess:
        observation_space = venv.observation_space
        observations_placeholder = tf.placeholder(
            shape=(None,) + observation_space.shape, dtype=tf.float32
        )

        if library == "baselines":
            from baselines.common.policies import build_policy

            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                policy = policy_fn(
                    nbatch=None,
                    nsteps=1,
                    sess=sess,
                    # the saved model expects [0, 1] inputs, baselines expects [0, 255]
                    observ_placeholder=(observations_placeholder * 255),
                )
                pd = policy.pd
                vf = policy.vf

        else:
            raise ValueError(f"Unsupported library: {library}")

        load_params(params, sess=sess)

        Model.save(
            model_path,
            input_name=observations_placeholder.op.name,
            output_names=[pd.logits.op.name, vf.op.name],
            image_shape=observation_space.shape,
            image_value_range=[0.0, 1.0],
        )

    metadata = {
        "policy_logits_name": pd.logits.op.name,
        "value_function_name": vf.op.name,
        "env_name": config.get("env_name"),
        "gae_gamma": config.get("gamma"),
        "gae_lambda": config.get("lambda"),
    }
    # Walk down the wrapper stack looking for an env exposing button combos.
    env = venv
    while hasattr(env, "env") and (not hasattr(env, "combos")):
        env = env.env
    if hasattr(env, "combos"):
        metadata["action_combos"] = env.combos
    else:
        metadata["action_combos"] = None

    save_joblib(metadata, metadata_path)
    return {"model_bytes": read(model_path, cache=False, mode="rb"), **metadata}


@contextmanager
def get_step_fn(config, params, *, num_envs, full_resolution):
    """Context manager yielding a closure that steps a vectorized env with the policy.

    Each call to the yielded `step_fn()` returns a dict with keys "ob", "first",
    "ac", "reward", "info" (plus "ob_full" when `full_resolution` is set), for the
    observation *before* the step and the transition produced by it.
    """
    config = config.copy()
    config.pop("num_envs")
    library = config.get("library", "baselines")
    venv = create_env(num_envs, **config)
    arch = get_arch(**config)

    with tf.Graph().as_default(), tf.Session() as sess:
        if library == "baselines":
            from baselines.common.policies import build_policy

            with tf.variable_scope("ppo2_model", reuse=tf.AUTO_REUSE):
                policy_fn = build_policy(venv, arch)
                policy = policy_fn(nbatch=venv.num_envs, nsteps=1, sess=sess)

            # Mutable step state shared with the closure below.
            stepdata = {
                "ob": venv.reset(),
                "state": policy.initial_state,
                "first": np.ones((venv.num_envs,), bool),
            }
            if full_resolution:
                stepdata["ob_full"] = np.stack(
                    [info["rgb"] for info in venv.env.get_info()], axis=0
                )

            def step_fn():
                result = {
                    "ob": stepdata["ob"],
                    "first": stepdata["first"].astype(bool),
                }
                if full_resolution:
                    result["ob_full"] = stepdata["ob_full"]
                result["ac"], _, stepdata["state"], _ = policy.step(
                    stepdata["ob"],
                    S=stepdata["state"],
                    M=stepdata["first"].astype(float),
                )
                (
                    stepdata["ob"],
                    result["reward"],
                    stepdata["first"],
                    result["info"],
                ) = venv.step(result["ac"])
                if full_resolution:
                    stepdata["ob_full"] = np.stack(
                        [info["rgb"] for info in result["info"]], axis=0
                    )
                return result

        else:
            raise ValueError(f"Unsupported library: {library}")

        load_params(params, sess=sess)

        yield step_fn


def save_observations(
    config, params, *, observations_path, num_envs, num_obs, obs_every, full_resolution
):
    """Sample `num_obs` batches of observations (one every `obs_every` steps) and save them."""
    with get_step_fn(
        config, params, num_envs=num_envs, full_resolution=full_resolution
    ) as step_fn:
        observations = []
        if full_resolution:
            observations_full = []
        for _ in range(num_obs):
            # Step obs_every times, keeping only the final batch of each group.
            for _ in range(obs_every):
                step_result = step_fn()
            observations.append(step_result["ob"])
            if full_resolution:
                observations_full.append(step_result["ob_full"])
        observations = np.concatenate(observations, axis=0)
        if full_resolution:
            observations_full = np.concatenate(observations_full, axis=0)

    result = {"observations": observations}
    if full_resolution:
        result["observations_full"] = observations_full
    save_joblib(result, observations_path)
    return result


def save_trajectories(
    config, params, *, trajectories_path, num_envs, num_steps, full_resolution
):
    """Roll out `num_steps` steps in `num_envs` parallel envs and save the trajectories.

    Returned arrays are stacked along axis 1, i.e. shaped (num_envs, num_steps, ...).
    """
    with get_step_fn(
        config, params, num_envs=num_envs, full_resolution=full_resolution
    ) as step_fn:
        step_fn()  # discard the very first transition (warm-up step)
        trajectories = [step_fn() for _ in range(num_steps)]

        def get_and_stack(ds, key, axis=1):
            # Stack one field across all steps; axis=1 puts time second.
            return np.stack([d[key] for d in ds], axis=axis)

        result = {
            "observations": get_and_stack(trajectories, "ob"),
            "actions": get_and_stack(trajectories, "ac"),
            "rewards": get_and_stack(trajectories, "reward"),
            "firsts": get_and_stack(trajectories, "first"),
        }
        if full_resolution:
            result["observations_full"] = get_and_stack(trajectories, "ob_full")

    save_joblib(result, trajectories_path)
    return {"trajectories": result}


def load(
    checkpoint_path,
    *,
    resample=True,
    model_path=None,
    metadata_path=None,
    trajectories_path=None,
    observations_path=None,
    trajectories_kwargs=None,
    observations_kwargs=None,
    full_resolution=False,
    temp_files=False,
):
    """Load a checkpoint, converting it to a lucid model plus sampled data.

    With `resample=True`, rebuilds the model and re-samples observations and
    trajectories, saving everything to the derived (or given) paths. With
    `resample=False`, simply reloads previously saved artifacts.

    Note: `trajectories_kwargs`/`observations_kwargs` default to `None` rather
    than `{}` — the original mutable-default signature let `setdefault` calls
    below leak state across invocations. Passed-in dicts are copied, never
    mutated.
    """
    trajectories_kwargs = {} if trajectories_kwargs is None else dict(trajectories_kwargs)
    observations_kwargs = {} if observations_kwargs is None else dict(observations_kwargs)

    if temp_files:
        default_path = lambda suffix: tempfile.mkstemp(suffix=suffix)[1]
    else:
        # Strip the final extension (if any) and nest outputs under rl-clarity/.
        path_stem = re.split(r"(?<=[^/])\.[^/\.]*$", checkpoint_path)[0]
        path_stem = os.path.join(
            os.path.dirname(path_stem), "rl-clarity", os.path.basename(path_stem)
        )
        default_path = lambda suffix: path_stem + suffix
    if model_path is None:
        model_path = default_path(".model.pb")
    if metadata_path is None:
        metadata_path = default_path(".metadata.jd")
    if trajectories_path is None:
        trajectories_path = default_path(".trajectories.jd")
    if observations_path is None:
        observations_path = default_path(".observations.jd")

    if resample:
        trajectories_kwargs.setdefault("num_envs", 8)
        trajectories_kwargs.setdefault("num_steps", 512)
        observations_kwargs.setdefault("num_envs", 32)
        observations_kwargs.setdefault("num_obs", 128)
        observations_kwargs.setdefault("obs_every", 128)

        checkpoint_dict = load_joblib(checkpoint_path, cache=False)
        config = checkpoint_dict["args"]
        if full_resolution:
            config["render_human"] = True
        if config.get("use_lstm", 0):
            raise ValueError("Recurrent networks not yet supported by this interface.")
        params = checkpoint_dict["params"]
        config["coinrun_old_extra_actions"] = 0
        if config.get("env_name") == "coinrun_old":
            # we may need to add extra actions depending on the size of the policy head
            policy_bias_keys = [
                k for k in checkpoint_dict["params"] if k.endswith("pi/b:0")
            ]
            if policy_bias_keys:
                [policy_bias_key] = policy_bias_keys
                (num_actions,) = checkpoint_dict["params"][policy_bias_key].shape
                if num_actions == 9:
                    config["coinrun_old_extra_actions"] = 2

        return {
            **save_lucid_model(
                config, params, model_path=model_path, metadata_path=metadata_path
            ),
            **save_observations(
                config,
                params,
                observations_path=observations_path,
                num_envs=observations_kwargs["num_envs"],
                num_obs=observations_kwargs["num_obs"],
                obs_every=observations_kwargs["obs_every"],
                full_resolution=full_resolution,
            ),
            **save_trajectories(
                config,
                params,
                trajectories_path=trajectories_path,
                num_envs=trajectories_kwargs["num_envs"],
                num_steps=trajectories_kwargs["num_steps"],
                full_resolution=full_resolution,
            ),
        }

    else:
        observations = load_joblib(observations_path, cache=False)
        if not isinstance(observations, dict):
            # Older saves stored the bare array rather than a dict.
            observations = {"observations": observations}
        return {
            "model_bytes": read(model_path, cache=False, mode="rb"),
            **observations,
            "trajectories": load_joblib(trajectories_path, cache=False),
            **load_joblib(metadata_path, cache=False),
        }
-------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/attribution_selector.svelte: -------------------------------------------------------------------------------- 1 | 22 | 23 | 33 | 34 | 35 | 36 |

37 | 38 |

39 | 40 |

41 | policy logits:
42 | {#each action_htmls as action_html, action} 43 | 47 | 48 | {/each} 49 |

50 | 51 |

52 | sums of policy logits:
53 | {#each action_groups as actions, action_group} 54 | 58 | 59 | {/each} 60 |

61 | 62 | 63 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/attribution_viewer.svelte: -------------------------------------------------------------------------------- 1 | 205 | 206 | 214 | 215 |

216 | 217 | 218 | 219 |

220 | 221 | 222 | 223 | {#if attribution_policy} 224 | 225 | {/if} 226 | 227 | {#if attribution_abs} 228 | 229 | {:else} 230 | 231 | 232 | {/if} 233 | 234 | {#each attribution_kinds as attribution_kind, kind_index} 235 | {#if attribution_kind !== null} 236 | 237 | {#if attribution_policy} 238 | 245 | {/if} 246 | {#each attribution_options[kind_index] as options, options_index} 247 | {#if options !== null} 248 | 283 | {/if} 284 | {/each} 285 | 288 | 289 | {#if attribution_policy} 290 | 291 | 296 | 297 | {/if} 298 | {/if} 299 | {/each} 300 | {#if attribution_policy} 301 | 302 | 305 | 306 | {/if} 307 |
ObservationPositive and negative attributionPositive attributionNegative attribution
249 |
250 | 260 |
261 | 268 | {#if options_index === 0} 269 |
270 |
271 | 280 |
281 | {/if} 282 |
292 | 295 |
303 | 304 |
308 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/chart.svelte: -------------------------------------------------------------------------------- 1 | 79 | 80 | {#if values_to_display !== null} 81 |
82 |
88 |
0
96 | {#each Array.from(Array(ticks_count).keys()) as tick_number} 97 |
105 | {/each} 106 |
114 |
115 |
121 | {#each values_to_display as value, index} 122 |
132 | {/each} 133 |
141 | {#each Array.from(Array(ticks_count).keys()) as tick_number} 142 | {#if tick_number !== 0} 143 |
151 | {/if} 152 | {/each} 153 |
154 |
155 | {/if} 156 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/css_manipulate.js: -------------------------------------------------------------------------------- 1 | export const css_manipulate = function(css_value, numeric_fn) { 2 | let last_digit_index = css_value.search(/\d(?!.*\d.*)/); 3 | if (last_digit_index === -1) { 4 | return css_value; 5 | } 6 | else { 7 | let numeric_value = parseFloat(css_value.substring(0, last_digit_index + 1)); 8 | return numeric_fn(numeric_value).toString() + css_value.substring(last_digit_index + 1); 9 | } 10 | }; 11 | 12 | export const css_multiply = function(css_value, multiplier) { 13 | return css_manipulate(css_value, function(numeric_value) { 14 | return numeric_value * multiplier; 15 | }); 16 | }; 17 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/feature_viewer.svelte: -------------------------------------------------------------------------------- 1 | 165 | 166 | 183 | 184 | 185 | 186 |
199 | {#if images === null} 200 |
Select a feature
201 | {:else} 202 | {#each images as image, index} 203 |
209 | {/each} 210 |
216 | {/if} 217 |
218 | 219 |

Feature visualization

220 | 221 | 222 | {#each metadata_configs as config, config_index} 223 | 224 | {#if config_index === 0 && layer !== null && number !== null} 225 | 230 | {/if} 231 | 232 | 243 | 244 | 245 | 246 | {/each} 247 | 248 | 251 | 252 |
226 | Layer {layer}, feature {number + 1}
227 | Dataset examples by spatial position
228 | Click to view example, scroll to zoom
229 |
{config.text[0]} 233 | config.current_index = event.currentTarget.value} 240 | on:input={(event) => config.current_index = event.currentTarget.value} 241 | > 242 | {config.text[1]}
249 | 250 |
253 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/graph.svelte: -------------------------------------------------------------------------------- 1 | 56 | 57 | 74 | 75 |
76 |
77 | 88 | {#each Array.from(series.keys()) as index} 89 | {#each Array.from(Array(series[index].length - 1).keys()) as position} 90 | {#if !dones[index][position]} 91 | 98 | {:else} 99 | 106 | {/if} 107 | {/each} 108 | {/each} 109 | 110 |
111 |
112 | {#each Array.from(series.keys()) as index} 113 | {#if typeof(series[index][state.position]) !== "undefined"} 114 |
123 | {@html titles[index]} 124 |
125 |
133 | {@html format_number(series[index][state.position])} 134 |
135 | {/if} 136 | {/each} 137 |
138 |
139 | {#if scrubber} 140 | 141 | {/if} 142 |
143 |
144 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/interface.svelte: -------------------------------------------------------------------------------- 1 | 256 | 257 | 373 | 374 | show_feature_viewer = false}/> 375 | 376 | 382 | 383 |
384 | 385 |
386 | 387 |

Trajectories

388 |

389 | {#each Array.from(Array(num_trajectories).keys()) as trajectory} 390 | 394 |
395 | {/each} 396 |

397 | 398 |

Bookmarks

399 |

Lowest advantage
episodes
(unexpected failures):

400 |

401 | {#each bookmarks.low as bookmark, bookmark_index} 402 | {selected_attribution_id.trajectory = bookmark[0]; video_state.position = bookmark[1];}}> 403 | trajectory {bookmark[0] + 1}, frame {bookmark[1] + 1} 404 | 405 |
406 | {/each} 407 |

408 |

Highest advantage
episodes
(unexpected successes):

409 |

410 | {#each bookmarks.high as bookmark, bookmark_index} 411 | {selected_attribution_id.trajectory = bookmark[0]; video_state.position = bookmark[1];}}> 412 | trajectory {bookmark[0] + 1}, frame {bookmark[1] + 1} 413 | 414 |
415 | {/each} 416 |

417 | 418 |
419 | 420 |
421 |

Layers

422 |

423 | {#each layers as layer} 424 |
425 | {/each} 426 |

427 |
428 | 429 |
430 | 431 |

Timeline

432 | 433 | 445 | 446 | 452 | 453 |
459 | 460 |
468 | 469 | 478 | 479 | {#each graphs as graph} 480 | 490 | {/each} 491 | 492 |
493 | 494 |
495 | 496 |
497 | 498 |
499 | 500 |
501 | 502 |
506 |
515 |

{#if selected_attribution.layer !== input_layer}Attribution{:else}Gradients{/if}

516 | 541 |
542 | 543 |
544 | 545 |
546 | 547 | {#if selected_attribution.layer !== input_layer} 548 |

Attribution legend

549 |

550 | Click to expand feature 551 | {#if attribution_single_channels} 552 |
Hover to isolate 553 | {/if} 554 |

555 | {selected_feature_id = { layer: selected_attribution.layer, number: event.detail }; show_feature_viewer = true;}} 565 | /> 566 | {:else} 567 |

Gradients legend

568 |

Colors correspond
to input colors

569 | {/if} 570 | 571 |

Hotkeys

572 |

573 | go backwards
574 | go forwards
575 | toggle play/pause
576 |

577 | 578 |
579 | 580 |
581 | 582 |
event.stopPropagation()} 597 | > 598 | show_feature_viewer = false} 612 | /> 613 |
614 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/json_load.js: -------------------------------------------------------------------------------- 1 | // reduced version of lucid/scratch/js/src/load.js 2 | 3 | const active_requests = new Map(); 4 | const cache = new Map(); 5 | 6 | const handle_errors = function(response) { 7 | if (response.ok) { 8 | return response; 9 | } else { 10 | throw new Error(response.status + ':' + response.statusText); 11 | } 12 | }; 13 | 14 | const json_loader = function(url, json_preloaded) { 15 | if (typeof(json_preloaded) !== "undefined" && typeof(json_preloaded[url]) !== "undefined") { 16 | return new Promise((resolve) => { 17 | resolve(json_preloaded[url]); 18 | }); 19 | } 20 | else if (cache.has(url)) { 21 | return cache.get(url); 22 | } 23 | else { 24 | let promise = fetch(url).then(handle_errors).then(response => response.json()); 25 | cache.set(url, promise); 26 | return promise; 27 | } 28 | }; 29 | 30 | 31 | export const json_load = function(url, namespace, json_preloaded) { 32 | let request_id = 0; 33 | if (typeof(namespace) !== "undefined") { 34 | if (active_requests.has(namespace)){ 35 | request_id = active_requests.get(namespace) + 1; 36 | } 37 | active_requests.set(namespace, request_id); 38 | } 39 | return new Promise((resolve, reject) => { 40 | let promise; 41 | if (Array.isArray(url)) { 42 | promise = Promise.all(url.map((u) => json_loader(u, json_preloaded))); 43 | } 44 | else { 45 | promise = json_loader(url, json_preloaded); 46 | } 47 | promise.then((response) => { 48 | if (typeof(namespace) === "undefined" || active_requests.get(namespace) === request_id) { 49 | resolve(response); 50 | } 51 | }).catch((error) => { 52 | if (typeof(namespace) === "undefined" || active_requests.get(namespace) === request_id) { 53 | reject(error); 54 | } 55 | }); 56 | }); 57 | }; 58 | 
-------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/legend.svelte: -------------------------------------------------------------------------------- 1 | 17 | 18 | 39 | 40 | {#each colors as color, index} 41 |
{dispatch('select', index); event.stopPropagation();}} 48 | on:mouseover={() => {if (enable_hover) {selected_channel = index;}}} 49 | on:mouseout={() => {if (enable_hover) {selected_channel = null;}}} 50 | > 51 |
57 |
{index + 1}
69 | {#if image !== null} 70 |
71 |
78 |
79 | {/if} 80 | {#if labels[index] !== null} 81 |
{labels[index]}
82 | {/if} 83 |
84 |
85 | {/each} 86 | 87 |
{if (enable_hover) {selected_channel = "res";}}} 94 | on:mouseout={() => {if (enable_hover) {selected_channel = null;}}} 95 | > 96 |
102 |
116 | {#if !(show_residual || selected_channel == 'res')} 117 | not
shown 118 | {/if} 119 |
120 |
121 | residual
(everything
else) 122 |
123 |
124 |
125 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/navigator.svelte: -------------------------------------------------------------------------------- 1 | 62 | 63 | 74 | 75 | 76 | 77 |
83 |
84 | 87 | 90 |
91 |
92 | 93 | 96 | 99 | 100 |
101 |
102 | fps 103 |
104 |
105 |
106 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/query.svelte: -------------------------------------------------------------------------------- 1 | 93 | 94 | 95 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/screen.svelte: -------------------------------------------------------------------------------- 1 | 38 | 39 | 52 | 53 |
54 |
55 |
56 |
65 |
66 | {#each images as image, index} 67 |
68 |
76 |
77 | {/each} 78 |
79 | {#if scrubber} 80 |
81 | 82 |
83 | {/if} 84 |
85 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/scrubber.svelte: -------------------------------------------------------------------------------- 1 | 30 | 31 | 47 | 48 | 49 | 50 |
56 |
62 |
63 |
64 | -------------------------------------------------------------------------------- /understanding_rl_vision/rl_clarity/svelte/trajectory_display.svelte: -------------------------------------------------------------------------------- 1 | 26 | 27 | 49 | 50 | 51 | 52 | 58 | 63 | 64 | 65 | 71 | 72 | 73 | 90 | 91 |
53 | {#if state.position > 0 && rewards[state.position - 1] !== 0} 54 | last reward: 55 | {rewards[state.position - 1]} 56 | {/if} 57 | 59 | {#if state.position > 0 && dones[state.position - 1]} 60 | new episode 61 | {/if} 62 | frame: {state.position + 1}policy: 66 | next action: 67 | 68 | {@html action_htmls[actions[state.position]]} 69 | 70 |
74 |
79 | {#each policy_probs as prob, action} 80 |
{@html action_htmls[action]}
87 | {/each} 88 |
89 |
# understanding_rl_vision/rl_clarity/training.py
# (the trailing `create_env` definition continues beyond this chunk and is not shown)
"""Environment construction helpers and VecEnv wrappers used for training and sampling."""
import time
import os
import tempfile
import numpy as np
from mpi4py import MPI
import gym
from baselines.common.vec_env import (
    VecEnv,
    VecEnvWrapper,
    VecFrameStack,
    VecMonitor,
    VecNormalize,
)
from baselines.common.mpi_util import setup_mpi_gpus
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
from baselines.common.atari_wrappers import make_atari, wrap_deepmind
from lucid.scratch.rl_util import save_joblib

# Names of the supported Procgen environments.
PROCGEN_ENV_NAMES = [
    "bigfish",
    "bossfight",
    "caveflyer",
    "chaser",
    "climber",
    "coinrun",
    "dodgeball",
    "fruitbot",
    "heist",
    "jumper",
    "leaper",
    "maze",
    "miner",
    "ninja",
    "plunder",
    "starpilot",
]

# Config keys that are forwarded to the Procgen environment constructor.
PROCGEN_KWARG_KEYS = [
    "num_levels",
    "start_level",
    "fixed_difficulty",
    "use_easy_jump",
    "paint_vel_info",
    "use_generated_assets",
    "use_monochrome_assets",
    "restrict_themes",
    "use_backgrounds",
    "plain_assets",
    "is_high_difficulty",
    "is_uniform_difficulty",
    "distribution_mode",
    "use_sequential_levels",
    "fix_background",
    "physics_mode",
    "debug_mode",
    "center_agent",
    "env_name",
    "game_type",
    "game_mechanics",
    "sample_game_mechanics",
    "render_human",
]

# Gym Atari environment ids (CamelCase, without version suffix).
ATARI_ENV_IDS = [
    "AirRaid",
    "Alien",
    "Amidar",
    "Assault",
    "Asterix",
    "Asteroids",
    "Atlantis",
    "BankHeist",
    "BattleZone",
    "BeamRider",
    "Berzerk",
    "Bowling",
    "Boxing",
    "Breakout",
    "Carnival",
    "Centipede",
    "ChopperCommand",
    "CrazyClimber",
    "DemonAttack",
    "DoubleDunk",
    "ElevatorAction",
    "Enduro",
    "FishingDerby",
    "Freeway",
    "Frostbite",
    "Gopher",
    "Gravitar",
    "Hero",
    "IceHockey",
    "Jamesbond",
    "JourneyEscape",
    "Kangaroo",
    "Krull",
    "KungFuMaster",
    "MontezumaRevenge",
    "MsPacman",
    "NameThisGame",
    "Phoenix",
    "Pitfall",
    "Pong",
    "Pooyan",
    "PrivateEye",
    "Qbert",
    "Riverraid",
    "RoadRunner",
    "Robotank",
    "Seaquest",
    "Skiing",
    "Solaris",
    "SpaceInvaders",
    "StarGunner",
    "Tennis",
    "TimePilot",
    "Tutankham",
    "UpNDown",
    "Venture",
    "VideoPinball",
    "WizardOfWor",
    "YarsRevenge",
    "Zaxxon",
]

# Lookup from lowercased name to the canonical CamelCase Atari id.
ATARI_ENV_DICT = {envid.lower(): envid for envid in ATARI_ENV_IDS}


class EpsilonGreedy(VecEnvWrapper):
    """
    Override actions with random actions with probability epsilon

    Args:
        epsilon: the probability actions will be overridden with random actions
    """

    def __init__(self, venv: VecEnv, epsilon: float):
        super().__init__(venv)
        # Random sampling below only makes sense for these action spaces.
        assert isinstance(self.action_space, gym.spaces.Discrete) or isinstance(
            self.action_space, gym.spaces.MultiBinary
        )
        self.epsilon = epsilon

    def reset(self):
        return self.venv.reset()

    def step_async(self, actions):
        # Independently per env, replace the agent's action with a random one.
        mask = np.random.uniform(size=self.num_envs) < self.epsilon
        new_actions = np.array(
            [
                self.action_space.sample() if mask[i] else actions[i]
                for i in range(self.num_envs)
            ]
        )
        self.venv.step_async(new_actions)

    def step_wait(self):
        return self.venv.step_wait()


class VecRewardScale(VecEnvWrapper):
    """
    Multiply every reward from the wrapped VecEnv by a constant factor.

    (The previous docstring here was copy-pasted from an unrelated task-id
    wrapper and described behavior this class does not have.)

    Args:
        venv: A set of environments
        scale: the multiplicative factor applied to each reward
    """

    def __init__(self, venv: VecEnv, scale: float):
        super().__init__(venv)
        self._scale = scale

    def reset(self):
        return self.venv.reset()

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        rews = rews * self._scale
        return obs, rews, dones, infos


# our internal version of CoinRun old ended up with 2 additional actions, so
# the pre-trained models require this wrapper.
class VecExtraActions(VecEnvWrapper):
    """Enlarge a Discrete action space; out-of-range actions map to `default_action`."""

    def __init__(self, venv, *, extra_actions, default_action):
        assert isinstance(venv.action_space, gym.spaces.Discrete)
        super().__init__(
            venv, action_space=gym.spaces.Discrete(venv.action_space.n + extra_actions)
        )
        self.default_action = default_action

    def reset(self):
        return self.venv.reset()

    def step_async(self, actions):
        actions = actions.copy()
        for i in range(len(actions)):
            # Any action the underlying env doesn't know is coerced to the default.
            if actions[i] >= self.venv.action_space.n:
                actions[i] = self.default_action
        self.venv.step_async(actions)

    def step_wait(self):
        return self.venv.step_wait()


# hack to fix a bug caused by observations being modified in-place
class VecShallowCopy(VecEnvWrapper):
    """Shallow-copy all arrays crossing the wrapper boundary to guard against in-place mutation."""

    def step_async(self, actions):
        actions = actions.copy()
        self.venv.step_async(actions)

    def reset(self):
        obs = self.venv.reset()
        return obs.copy()

    def step_wait(self):
        obs, rews, dones, infos = self.venv.step_wait()
        return obs.copy(), rews.copy(), dones.copy(), infos.copy()


# Set once the (old) CoinRun package has been initialized — used by create_env below.
coinrun_initialized = False
{k: v for k, v in kwargs.items() if v is not None} 240 | env_name = env_kwargs.pop("env_name") 241 | 242 | if env_name == "coinrun_old": 243 | import coinrun 244 | from coinrun.config import Config 245 | 246 | Config.initialize_args(use_cmd_line_args=False, **env_kwargs) 247 | global coinrun_initialized 248 | if not coinrun_initialized: 249 | coinrun.init_args_and_threads() 250 | coinrun_initialized = True 251 | venv = coinrun.make("standard", num_envs) 252 | if coinrun_old_extra_actions > 0: 253 | venv = VecExtraActions( 254 | venv, extra_actions=coinrun_old_extra_actions, default_action=0 255 | ) 256 | 257 | else: 258 | from procgen import ProcgenGym3Env 259 | import gym3 260 | 261 | env_kwargs = { 262 | k: v for k, v in env_kwargs.items() if k in PROCGEN_KWARG_KEYS 263 | } 264 | env = ProcgenGym3Env(num_envs, env_name=env_name, **env_kwargs) 265 | env = gym3.ExtractDictObWrapper(env, "rgb") 266 | venv = gym3.ToBaselinesVecEnv(env) 267 | 268 | elif env_kind == "atari": 269 | game_version = "v0" if use_sticky_actions == 1 else "v4" 270 | 271 | def make_atari_env(lower_env_id, num_env): 272 | env_id = ATARI_ENV_DICT[lower_env_id] + f"NoFrameskip-{game_version}" 273 | 274 | def make_atari_env_fn(): 275 | env = make_atari(env_id) 276 | env = wrap_deepmind(env, frame_stack=False, clip_rewards=False) 277 | 278 | return env 279 | 280 | return SubprocVecEnv([make_atari_env_fn for i in range(num_env)]) 281 | 282 | lower_env_id = kwargs["env_id"] 283 | 284 | venv = make_atari_env(lower_env_id, num_envs) 285 | 286 | else: 287 | raise ValueError(f"Unsupported env_kind: {env_kind}") 288 | 289 | if frame_stack > 1: 290 | venv = VecFrameStack(venv=venv, nstack=frame_stack) 291 | 292 | if reward_scale != 1: 293 | venv = VecRewardScale(venv, reward_scale) 294 | 295 | venv = VecMonitor(venv=venv, filename=None, keep_buf=100) 296 | 297 | if epsilon_greedy > 0: 298 | venv = EpsilonGreedy(venv, epsilon_greedy) 299 | 300 | venv = VecShallowCopy(venv) 301 | 302 | return venv 303 | 304 
# --------------------------------------------------------------------------- #
# understanding_rl_vision/rl_clarity/training.py — part 2: model + training
# --------------------------------------------------------------------------- #
def get_arch(
    *,
    library="baselines",
    cnn="clear",
    use_lstm=0,
    stack_channels="16_32_32",
    emb_size=256,
    **kwargs,
):
    """Build the policy network constructor.

    Args:
        library: only "baselines" is supported
        cnn: "impala", "nature" or "clear"
        use_lstm: if truthy, wrap the CNN in a 256-unit LSTM
        stack_channels: underscore-separated conv depths, e.g. "16_32_32"
            (used by the "impala" cnn only)
        emb_size: embedding size for the "impala" cnn
        **kwargs: ignored; allows passing the full training kwargs dict

    Returns:
        A network constructor suitable for baselines' ppo2.learn.

    Raises:
        ValueError: on unsupported library/cnn.
    """
    stack_channels = [int(x) for x in stack_channels.split("_")]

    if library == "baselines":
        if cnn == "impala":
            from baselines.common.models import build_impala_cnn

            conv_fn = lambda x: build_impala_cnn(
                x, depths=stack_channels, emb_size=emb_size
            )
        elif cnn == "nature":
            from baselines.common.models import nature_cnn

            conv_fn = nature_cnn
        elif cnn == "clear":
            from lucid.scratch.rl_util.arch import clear_cnn

            conv_fn = clear_cnn
        else:
            raise ValueError(f"Unsupported cnn: {cnn}")

        if use_lstm:
            from baselines.common.models import cnn_lstm

            arch = cnn_lstm(nlstm=256, conv_fn=conv_fn)
        else:
            arch = conv_fn

    else:
        raise ValueError(f"Unsupported library: {library}")

    return arch


def create_tf_session():
    """
    Create a TensorFlow session with GPU memory growth enabled.
    """
    import tensorflow as tf

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    return tf.Session(config=config)


def get_tf_params(scope):
    """
    Get a dictionary of parameters from TensorFlow for the specified scope,
    mapping variable name -> evaluated numpy value, excluding optimizer state.
    """
    import tensorflow as tf
    from baselines.common.tf_util import get_session

    sess = get_session()
    allvars = tf.trainable_variables(scope)
    # Filter out optimizer slots and key/value buffers.
    nonopt_vars = [
        v
        for v in allvars
        if all(veto not in v.name for veto in ["optimizer", "kbuf", "vbuf"])
    ]
    name2var = {v.name: v for v in nonopt_vars}
    return sess.run(name2var)


def save_data(*, save_dir, args_dict, params, step=None, extra=None):
    """
    Save the global config object as well as the current model params to a
    local file.

    Args:
        save_dir: directory (local path or URL) for the checkpoint
        args_dict: training configuration to store alongside params
        params: model parameters (as returned by get_tf_params)
        step: optional update number, appended to the checkpoint filename
        extra: optional dict of additional data to store

    Returns:
        The path the checkpoint was written to.
    """
    # Fix: the previous signature used a mutable default (extra={}), which is
    # shared across calls; default to None and create a fresh dict instead.
    extra = {} if extra is None else extra
    data_dict = dict(args=args_dict, params=params, extra=extra, time=time.time())

    step_str = "" if step is None else f"-{step}"
    save_path = os.path.join(save_dir, f"checkpoint{step_str}.jd")

    # Only create directories for local paths, not remote URLs.
    if "://" not in save_dir:
        os.makedirs(save_dir, exist_ok=True)

    save_joblib(data_dict, save_path)

    return save_path


class VecClipReward(VecEnvWrapper):
    """Clip rewards to their sign, as in the DeepMind Atari setup."""

    def reset(self):
        return self.venv.reset()

    def step_wait(self):
        """Bin reward to {+1, 0, -1} by its sign."""
        obs, rews, dones, infos = self.venv.step_wait()
        return obs, np.sign(rews), dones, infos


def train(comm=None, *, save_dir=None, **kwargs):
    """
    Train a model using Baselines' PPO2, and save a checkpoint file in the
    required format.

    There is one required kwarg: either env_name (for env_kind="procgen") or
    env_id (for env_kind="atari").

    Models for the paper were trained with 16 parallel MPI workers.

    Note: this code has not been well-tested.

    Args:
        comm: MPI communicator; defaults to MPI.COMM_WORLD
        save_dir: checkpoint directory; defaults to a fresh temp directory

    Returns:
        The directory checkpoints were saved to.
    """
    kwargs.setdefault("env_kind", "procgen")
    kwargs.setdefault("num_envs", 64)
    kwargs.setdefault("learning_rate", 5e-4)
    kwargs.setdefault("entropy_coeff", 0.01)
    kwargs.setdefault("gamma", 0.999)
    kwargs.setdefault("lambda", 0.95)
    kwargs.setdefault("num_steps", 256)
    kwargs.setdefault("num_minibatches", 8)
    kwargs.setdefault("library", "baselines")
    kwargs.setdefault("save_all", False)
    kwargs.setdefault("ppo_epochs", 3)
    kwargs.setdefault("clip_range", 0.2)
    kwargs.setdefault("timesteps_per_proc", 1_000_000_000)
    kwargs.setdefault("cnn", "clear")
    kwargs.setdefault("use_lstm", 0)
    kwargs.setdefault("stack_channels", "16_32_32")
    kwargs.setdefault("emb_size", 256)
    kwargs.setdefault("epsilon_greedy", 0.0)
    kwargs.setdefault("reward_scale", 1.0)
    kwargs.setdefault("frame_stack", 1)
    kwargs.setdefault("use_sticky_actions", 0)
    kwargs.setdefault("clip_vf", 1)
    kwargs.setdefault("reward_processing", "none")
    kwargs.setdefault("save_interval", 10)

    if comm is None:
        comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    setup_mpi_gpus()

    if save_dir is None:
        save_dir = tempfile.mkdtemp(prefix="rl_clarity_train_")

    create_env_kwargs = kwargs.copy()
    num_envs = create_env_kwargs.pop("num_envs")
    venv = create_env(num_envs, **create_env_kwargs)

    library = kwargs["library"]
    if library == "baselines":
        reward_processing = kwargs["reward_processing"]
        if reward_processing == "none":
            pass
        elif reward_processing == "clip":
            venv = VecClipReward(venv=venv)
        elif reward_processing == "normalize":
            venv = VecNormalize(venv=venv, ob=False, per_env=False)
        else:
            raise ValueError(f"Unsupported reward processing: {reward_processing}")

        scope = "ppo2_model"

        # Callback invoked by ppo2 after each update; the params kwarg is part
        # of ppo2's update_fn signature but is recomputed here via TensorFlow.
        def update_fn(update, params=None):
            # Only rank 0 writes checkpoints.
            if rank == 0:
                save_interval = kwargs["save_interval"]
                if save_interval > 0 and update % save_interval == 0:
                    print("Saving...")
                    params = get_tf_params(scope)
                    save_path = save_data(
                        save_dir=save_dir,
                        args_dict=kwargs,
                        params=params,
                        step=(update if kwargs["save_all"] else None),
                    )
                    print(f"Saved to: {save_path}")

        sess = create_tf_session()
        sess.__enter__()

        if kwargs["use_lstm"]:
            raise ValueError("Recurrent networks not yet supported.")
        arch = get_arch(**kwargs)

        from baselines.ppo2 import ppo2

        ppo2.learn(
            env=venv,
            network=arch,
            total_timesteps=kwargs["timesteps_per_proc"],
            save_interval=0,
            nsteps=kwargs["num_steps"],
            nminibatches=kwargs["num_minibatches"],
            lam=kwargs["lambda"],
            gamma=kwargs["gamma"],
            noptepochs=kwargs["ppo_epochs"],
            log_interval=1,
            ent_coef=kwargs["entropy_coeff"],
            mpi_rank_weight=1.0,
            clip_vf=bool(kwargs["clip_vf"]),
            comm=comm,
            lr=kwargs["learning_rate"],
            cliprange=kwargs["clip_range"],
            update_fn=update_fn,
            init_fn=None,
            vf_coef=0.5,
            max_grad_norm=0.5,
        )
    else:
        raise ValueError(f"Unsupported library: {library}")

    return save_dir
# --------------------------------------------------------------------------- #
# /understanding_rl_vision/svelte3/__init__.py:
# --------------------------------------------------------------------------- #
from .compiling import compile_js, compile_html
# --------------------------------------------------------------------------- #
# /understanding_rl_vision/svelte3/compiling.py: (module header)
# --------------------------------------------------------------------------- #
from contextlib import contextmanager
import json
import tempfile
import subprocess
import os
from lucid.misc.io.writing import write_handle
from lucid.misc.io.reading import read_handle
from .json_encoding import encoder

# Temp directory holding generated eslint/rollup configs and inline-compiled JS.
_temp_config_dir = tempfile.mkdtemp(prefix="svelte3_")
_default_js_name = "App"
_default_div_id = "appdiv"


class CompileError(Exception):
    """Raised when a shell compilation step fails."""

    pass


def replace_file_extension(path, extension):
    """Replace the file extension of a path with a new extension."""
    if not extension.startswith("."):
        extension = "." + extension
    dir_, filename = os.path.split(path)
    if not filename.endswith(extension):
        filename = filename.rsplit(".", 1)[0]
        return os.path.join(dir_, filename + extension)
    return os.path.join(dir_, filename)


@contextmanager
def use_cwd(dir_):
    """Context manager for working in a different directory."""
    cwd = os.getcwd()
    try:
        os.chdir(dir_)
        yield
    finally:
        os.chdir(cwd)


def shell_command(command, **kwargs):
    """Wrapper around subprocess.check_output. Should be used with care:
    https://docs.python.org/3/library/subprocess.html#security-considerations
    """
    try:
        return subprocess.check_output(
            command,
            stderr=subprocess.STDOUT,
            shell=True,
            universal_newlines=True,
            **kwargs
        ).strip("\n")
    except subprocess.CalledProcessError as exn:
        raise CompileError(
            "Command '%s' failed with output:\n\n%s" % (command, exn.output)
        ) from exn


def compile_js(svelte_path, js_path=None, *, js_name=None, js_lint=None):
    """Compile Svelte to JavaScript.

    Arguments:
        svelte_path: path to input Svelte file
        js_path: path to output JavaScript file
            defaults to svelte_path with a new .js suffix
        js_name: name of JavaScript global variable
            defaults to _default_js_name
        js_lint: whether to use eslint
            defaults to True
    """
    if js_path is None:
        js_path = replace_file_extension(svelte_path, ".js")
    if js_name is None:
        js_name = _default_js_name
    if js_lint is None:
        js_lint = True

    eslint_config_fd, eslint_config_path = tempfile.mkstemp(
        suffix=".config.json", prefix="eslint_", dir=_temp_config_dir, text=True
    )
    eslint_config_path = os.path.abspath(eslint_config_path)
    rollup_config_fd, rollup_config_path = tempfile.mkstemp(
        suffix=".config.js", prefix="rollup_", dir=_temp_config_dir, text=True
    )
    rollup_config_path = os.path.abspath(rollup_config_path)

    svelte_dir = os.path.dirname(svelte_path) or os.curdir
    svelte_relpath = os.path.relpath(svelte_path, start=svelte_dir)
    js_relpath = os.path.relpath(js_path, start=svelte_dir)

    with open(eslint_config_fd, "w") as eslint_config_file:
        json.dump(
            {
                "env": {"browser": True, "es6": True},
                "extends": "eslint:recommended",
                "globals": {"Atomics": "readonly", "SharedArrayBuffer": "readonly"},
                "parserOptions": {"ecmaVersion": 2018, "sourceType": "module"},
                "plugins": ["svelte3"],
                "overrides": [{"files": ["*.svelte"], "processor": "svelte3/svelte3"}],
                "rules": {},
            },
            eslint_config_file,
        )

    with open(rollup_config_fd, "w") as rollup_config_file:
        rollup_config_file.write(
            """import svelte from 'rollup-plugin-svelte';
import resolve from 'rollup-plugin-node-resolve';
import { eslint } from 'rollup-plugin-eslint';
import babel from 'rollup-plugin-babel';
import commonjs from 'rollup-plugin-commonjs';
import path from 'path';

export default {
  input: '"""
            + svelte_relpath
            + """',
  output: {
    file: '"""
            + js_relpath
            + """',
    format: 'iife',
    name: '"""
            + js_name
            + """'
  },
  plugins: [
    eslint({
      include: ['**'],
      """
            + ("" if js_lint else "exclude: ['**'],")
            + """
      configFile: '"""
            + eslint_config_path
            + """'
    }),
    svelte({
      include: ['"""
            + svelte_relpath
            + """', '**/*.svelte']
    }),
    resolve({
      customResolveOptions: {
        paths: process.env.NODE_PATH.split( /[;:]/ )
      }
    }),
    commonjs(),
    babel({
      include: ['**', path.resolve(process.env.NODE_PATH, 'svelte/**')],
      extensions: ['.js', '.jsx', '.es6', '.es', '.mjs', '.svelte'],
      babelrc: false,
      cwd: process.env.NODE_PATH,
      presets: [['@babel/preset-env', {useBuiltIns: 'usage', corejs: 3}]]
    })
  ]
}
"""
        )

    with use_cwd(os.path.dirname(os.path.realpath(__file__)) or os.curdir):
        try:
            npm_root = shell_command("npm root --quiet")
        except CompileError as exn:
            raise CompileError(
                "Unable to find npm root.\nHave you installed Node.js?"
            ) from exn
        try:
            shell_command("npm ls")
        except CompileError:
            # node_modules missing or out of date: install dependencies.
            shell_command("npm install")

    with use_cwd(svelte_dir):
        env = os.environ.copy()
        # Fix: use os.pathsep instead of a hard-coded ":" so the PATH
        # manipulation also works on Windows.
        env["PATH"] = os.path.join(npm_root, ".bin") + os.pathsep + env["PATH"]
        env["NODE_PATH"] = npm_root
        command_output = shell_command("rollup -c " + rollup_config_path, env=env)

    return {
        "js_path": js_path,
        "js_name": js_name,
        "command_output": command_output,
    }


def compile_html(
    input_path,
    html_path=None,
    *,
    props=None,
    precision=None,
    title=None,
    div_id=None,
    inline_js=None,
    svelte_to_js=None,
    js_path=None,
    js_name=None,
    js_lint=None
):
    """Compile Svelte or JavaScript to HTML.

    Arguments:
        input_path: path to input Svelte or JavaScript file
        html_path: path to output HTML file
            defaults to input_path with a new .html suffix
        props: JSON-serializable object to pass to Svelte script
            defaults to an empty object
        precision: number of significant figures to round numpy arrays to
            defaults to no rounding
        title: title of HTML page
            defaults to html_path filename without suffix
        div_id: HTML id of div containing Svelte component
            defaults to _default_div_id
        inline_js: whether to insert the JavaScript into the HTML page inline
            defaults to svelte_to_js
        svelte_to_js: whether to first compile from Svelte to JavaScript
            defaults to whether input_path doesn't have a .js suffix
        js_path: path to output JavaScript file if compiling from Svelte
            and not inserting the JavaScript inline
            defaults to compile_js default
        js_name: name of JavaScript global variable
            should match existing name if compiling from JavaScript
            defaults to _default_js_name
        js_lint: whether to use eslint if compiling from Svelte
            defaults to compile_js default
    """
    if html_path is None:
        html_path = replace_file_extension(input_path, ".html")
    if props is None:
        props = {}
    if title is None:
        title = os.path.basename(html_path).rsplit(".", 1)[0]
    if div_id is None:
        div_id = _default_div_id
    if svelte_to_js is None:
        svelte_to_js = not input_path.endswith(".js")
    if inline_js is None:
        inline_js = svelte_to_js

    if svelte_to_js:
        if inline_js:
            if js_path is None:
                js_path = replace_file_extension(input_path, ".js")
            prefix = "svelte_" + os.path.basename(js_path)
            if prefix.endswith(".js"):
                prefix = prefix[:-3]
            _, js_path = tempfile.mkstemp(
                suffix=".js", prefix=prefix + "_", dir=_temp_config_dir, text=True
            )
        try:
            compile_js_result = compile_js(
                input_path, js_path, js_name=js_name, js_lint=js_lint
            )
        except CompileError as exn:
            raise CompileError(
                "Unable to compile Svelte source.\n"
                "See the above advice or try supplying pre-compiled JavaScript."
            ) from exn
        js_path = compile_js_result["js_path"]
        js_name = compile_js_result["js_name"]
        command_output = compile_js_result["command_output"]
    else:
        js_path = input_path
        if js_name is None:
            js_name = _default_js_name
        command_output = None

    if inline_js:
        with read_handle(js_path, cache=False, mode="r") as js_file:
            js_code = js_file.read().rstrip("\n")
        # NOTE(review): the inline <script> wrapper markup around js_code
        # appears to have been stripped from this dump (js_code is otherwise
        # unused here) — restore from the upstream source before relying on
        # this function.
        js_html = ""
        js_path = None
    else:
        js_relpath = os.path.relpath(js_path, start=os.path.dirname(html_path))
        # NOTE(review): the script-src tag markup referencing js_relpath
        # appears to have been stripped by text extraction — TODO confirm
        # against upstream.
        js_html = ''

    with write_handle(html_path, "w") as html_file:
        # NOTE(review): the HTML template literals below were garbled by text
        # extraction (tags and the component-instantiation script using props,
        # div_id, precision and encoder are missing); only the visible
        # concatenation structure is preserved here.
        html_file.write(
            """



"""
            + title
            + '''



'''
            + js_html
            + """


"""
        )
    return {
        "html_path": html_path,
        "js_path": js_path if svelte_to_js else None,
        "title": title,
        "div_id": div_id,
        "js_name": js_name,
        "command_output": command_output,
    }
# --------------------------------------------------------------------------- #
# /understanding_rl_vision/svelte3/json_encoding.py:
# --------------------------------------------------------------------------- #
import numpy as np
import json


def maybe_round(x, *, precision):
    """Recursively convert x (a float or nested list of floats) to Python
    floats, rounding to `precision` significant figures when precision is not
    None."""
    if isinstance(x, list):
        return [maybe_round(y, precision=precision) for y in x]
    else:
        if precision is None:
            return float(x)
        else:
            # format_float_positional with fractional=False rounds to
            # significant figures rather than decimal places.
            return float(
                np.format_float_positional(
                    x, precision=precision, unique=False, fractional=False
                )
            )


def encoder(precision=None):
    """Return a json.JSONEncoder subclass that handles numpy scalars/arrays
    (optionally rounded to `precision` significant figures), tuples, sets,
    and objects exposing a to_json() method."""

    class CustomJSONEncoder(json.JSONEncoder):
        def default(self, obj):
            if isinstance(obj, (tuple, set)):
                return list(obj)
            elif isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return maybe_round(obj, precision=precision)
            elif isinstance(obj, np.ndarray):
                return maybe_round(obj.tolist(), precision=precision)
            elif hasattr(obj, "to_json"):
                return obj.to_json()
            return json.JSONEncoder.default(self, obj)

    return CustomJSONEncoder
# --------------------------------------------------------------------------- #
# /understanding_rl_vision/svelte3/package.json:
# --------------------------------------------------------------------------- #
# {
#   "name": "svelte3",
#   "description": "Svelte 3 compiler with a Python API",
#   "version": "0.0.1",
#   "dependencies": {
#     "@babel/core": "^7.6.2",
#     "@babel/preset-env": "^7.6.2",
#     "commonjs": "0.0.1",
#     "core-js": "^3.2.1",
#     "eslint": "^6.4.0",
#     "eslint-plugin-svelte3": "^2.7.3",
#     "rollup": "^1.21.4",
#     "rollup-plugin-babel": "^4.3.3",
#     "rollup-plugin-commonjs": "^10.1.0",
#     "rollup-plugin-eslint": "^7.0.0",
#     "rollup-plugin-node-resolve": "^5.2.0",
#     "rollup-plugin-svelte": "^5.1.0",
#     "svelte": "^3.12.1"
#   }
# }