├── .floydignore ├── .gitattributes ├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── a2c ├── LICENSE ├── a2c │ ├── a2c.py │ ├── diff_to_baselines.txt │ ├── policies.py │ └── utils.py ├── common │ ├── __init__.py │ ├── atari_wrappers.py │ ├── math_util.py │ ├── misc_util.py │ ├── schedules.py │ ├── tests │ │ └── test_schedules.py │ └── vec_env │ │ ├── __init__.py │ │ └── subproc_vec_env.py └── logger.py ├── enduro_wrapper.py ├── floydhub_utils ├── create_floyd_base.sh ├── floyd_wrapper.sh ├── floyd_wrapper_base.sh ├── get_dir.py ├── get_events.py └── monitor_jobs.py ├── images ├── diagram.png ├── enduro.gif ├── moving-dot-graphs.png ├── moving-dot.gif ├── pong-graphs.png └── pong.gif ├── mem_utils └── plot_mems.py ├── nn_layers.py ├── params.py ├── pref_db.py ├── pref_db_test.py ├── pref_interface.py ├── pref_interface_test.py ├── reward_predictor.py ├── reward_predictor_core_network.py ├── reward_predictor_test.py ├── run.py ├── run_checkpoint.py ├── run_test.py ├── show_prefs.py ├── utils.py └── utils_test.py /.floydignore: -------------------------------------------------------------------------------- 1 | # Directories and files to ignore when uploading code to floyd 2 | 3 | .git 4 | .eggs 5 | eggs 6 | lib 7 | lib64 8 | parts 9 | sdist 10 | var 11 | *.pyc 12 | *.swp 13 | .DS_Store 14 | 15 | runs 16 | images 17 | floydruns 18 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | pref_db_train.pkl filter=lfs diff=lfs merge=lfs -text 2 | pref_db_val.pkl filter=lfs diff=lfs merge=lfs -text 3 | *.pkl filter=lfs diff=lfs merge=lfs -text 4 | reward_network.ckpt* filter=lfs diff=lfs merge=lfs -text 5 | checkpoint* filter=lfs diff=lfs merge=lfs -text 6 | *.ckpt* filter=lfs diff=lfs merge=lfs -text 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.pyc 3 | 4 | .DS_Store 5 | 6 | tags 7 | .idea 8 | 9 | .floydexpt 10 | floydruns 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Matthew Rahtz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | url = "https://pypi.python.org/simple" 3 | verify_ssl = true 4 | name = "pypi" 5 | 6 | [packages] 7 | scipy = "*" 8 | cloudpickle = "*" 9 | matplotlib = "*" 10 | gym = {extras = ["atari"], version = "==0.9.3"} 11 | easy-tf-log = "==1.1" 12 | gym-moving-dot = {git = "https://github.com/mrahtz/gym-moving-dot"} 13 | numpy = "*" 14 | 15 | [dev-packages] 16 | nose = "*" 17 | termcolor = "*" 18 | memory-profiler = "*" 19 | 20 | [requires] 21 | python_version = "3" 22 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "faa502d91b4ab5620b66c7d2cd3e1315accb3bd8eb732815a701b500dcf996c3" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.python.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "atari-py": { 20 | "hashes": [ 21 | "sha256:981fe7c6a4ab68ffcde8499ed0a859697320ac5b14233d52012bb358012b7bd5" 22 | ], 23 | "version": "==0.1.1" 24 | }, 25 | "certifi": { 26 | "hashes": [ 27 | "sha256:13e698f54293db9f89122b0581843a782ad0934a4fe0172d2a980ba77fc61bb7", 28 | "sha256:9fa520c1bacfb634fa7af20a76bcbd3d5fb390481724c597da32c719a7dca4b0" 29 | ], 30 | "version": "==2018.4.16" 31 | }, 32 | "chardet": { 33 | "hashes": [ 34 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae", 35 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691" 36 | ], 37 | "version": "==3.0.4" 38 | }, 39 | "cloudpickle": { 40 | "hashes": [ 41 | "sha256:54858c7b7dc763ed894ff91059c1d0b017d593fe23850d3d8d75f47d98398197", 42 | "sha256:ac7acd0dcb1c52fa14873f801efe0c53e6457bcc5f01542cb445a6515d9a4b72" 43 | ], 44 | "index": "pypi", 45 | "version": "==0.5.3" 46 | }, 47 | "cycler": { 48 | "hashes": [ 49 | "sha256:1d8a5ae1ff6c5cf9b93e8811e581232ad8920aeec647c37316ceac982b08cb2d", 50 | "sha256:cd7b2d1018258d7247a71425e9f26463dfb444d411c39569972f4ce586b0c9d8" 51 | ], 52 | "version": "==0.10.0" 53 | }, 54 | "easy-tf-log": { 55 | "hashes": [ 56 | "sha256:b27970e1b16fac2a1b37578eca69ae81c1853bd5fe990243c200b67b4f502353" 57 | ], 58 | "index": "pypi", 59 | "version": "==1.1" 60 | }, 61 | "future": { 62 | "hashes": [ 63 | "sha256:e39ced1ab767b5936646cedba8bcce582398233d6a627067d4c6a454c90cfedb" 64 | ], 65 | "version": "==0.16.0" 66 | }, 67 | "gym": { 68 | "hashes": [ 69 | "sha256:b94939a994a592df8354058e218d820869514b67283fbf59e6c16bd57f489be7" 70 | ], 71 | "index": "pypi", 72 | "version": "==0.9.3" 73 | }, 74 | "gym-moving-dot": { 75 | "git": "https://github.com/mrahtz/gym-moving-dot" 76 | }, 77 | "idna": { 78 | "hashes": [ 79 | "sha256:2c6a5de3089009e3da7c5dde64a141dbc8551d5b7f6cf4ed7c2568d0cc520a8f", 80 | "sha256:8c7309c718f94b3a625cb648ace320157ad16ff131ae0af362c9f21b80ef6ec4" 81 | ], 82 | "version": "==2.6" 83 | }, 84 | "kiwisolver": { 85 | "hashes": [ 86 | "sha256:0ee4ed8b3ae8f5f712b0aa9ebd2858b5b232f1b9a96b0943dceb34df2a223bc3", 87 | "sha256:0f7f532f3c94e99545a29f4c3f05637f4d2713e7fd91b4dd8abfc18340b86cd5", 88 | "sha256:1a078f5dd7e99317098f0e0d490257fd0349d79363e8c923d5bb76428f318421", 89 | "sha256:1aa0b55a0eb1bd3fa82e704f44fb8f16e26702af1a073cc5030eea399e617b56", 90 | 
"sha256:2874060b91e131ceeff00574b7c2140749c9355817a4ed498e82a4ffa308ecbc", 91 | "sha256:379d97783ba8d2934d52221c833407f20ca287b36d949b4bba6c75274bcf6363", 92 | "sha256:3b791ddf2aefc56382aadc26ea5b352e86a2921e4e85c31c1f770f527eb06ce4", 93 | "sha256:4329008a167fac233e398e8a600d1b91539dc33c5a3eadee84c0d4b04d4494fa", 94 | "sha256:45813e0873bbb679334a161b28cb9606d9665e70561fd6caa8863e279b5e464b", 95 | "sha256:53a5b27e6b5717bdc0125338a822605084054c80f382051fb945d2c0e6899a20", 96 | "sha256:66f82819ff47fa67a11540da96966fb9245504b7f496034f534b81cacf333861", 97 | "sha256:79e5fe3ccd5144ae80777e12973027bd2f4f5e3ae8eb286cabe787bed9780138", 98 | "sha256:8b6a7b596ce1d2a6d93c3562f1178ebd3b7bb445b3b0dd33b09f9255e312a965", 99 | "sha256:9576cb63897fbfa69df60f994082c3f4b8e6adb49cccb60efb2a80a208e6f996", 100 | "sha256:95a25d9f3449046ecbe9065be8f8380c03c56081bc5d41fe0fb964aaa30b2195", 101 | "sha256:aaec1cfd94f4f3e9a25e144d5b0ed1eb8a9596ec36d7318a504d813412563a85", 102 | "sha256:acb673eecbae089ea3be3dcf75bfe45fc8d4dcdc951e27d8691887963cf421c7", 103 | "sha256:b15bc8d2c2848a4a7c04f76c9b3dc3561e95d4dabc6b4f24bfabe5fd81a0b14f", 104 | "sha256:b1c240d565e977d80c0083404c01e4d59c5772c977fae2c483f100567f50847b", 105 | "sha256:ce3be5d520b4d2c3e5eeb4cd2ef62b9b9ab8ac6b6fedbaa0e39cdb6f50644278", 106 | "sha256:e0f910f84b35c36a3513b96d816e6442ae138862257ae18a0019d2fc67b041dc", 107 | "sha256:ea36e19ac0a483eea239320aef0bd40702404ff8c7e42179a2d9d36c5afcb55c", 108 | "sha256:f923406e6b32c86309261b8195e24e18b6a8801df0cfc7814ac44017bfcb3939" 109 | ], 110 | "version": "==1.0.1" 111 | }, 112 | "matplotlib": { 113 | "hashes": [ 114 | "sha256:07055eb872fa109bd88f599bdb52065704b2e22d475b67675f345d75d32038a0", 115 | "sha256:0f2f253d6d51f5ed52a819921f8a0a8e054ce0daefcfbc2557e1c433f14dc77d", 116 | "sha256:1ef9fd285334bd6b0495b6de9d56a39dc95081577f27bafabcf28e0d318bed31", 117 | "sha256:3fb2db66ef98246bafc04b4ef4e9b0e73c6369f38a29716844e939d197df816a", 118 | "sha256:3fd90b407d1ab0dae686a4200030ce305526ff20b85a443dc490d194114b2dfa", 119 | "sha256:45dac8589ef1721d7f2ab0f48f986694494dfcc5d13a3e43a5cb6c816276094e", 120 | "sha256:4bb10087e09629ba3f9b25b6c734fd3f99542f93d71c5b9c023f28cb377b43a9", 121 | "sha256:4dc7ef528aad21f22be85e95725234c5178c0f938e2228ca76640e5e84d8cde8", 122 | "sha256:4f6a516d5ef39128bb16af7457e73dde25c30625c4916d8fbd1cc7c14c55e691", 123 | "sha256:70f0e407fbe9e97f16597223269c849597047421af5eb8b60dbaca0382037e78", 124 | "sha256:7b3d03c876684618e2a2be6abeb8d3a033c3a1bb38a786f751199753ef6227e6", 125 | "sha256:8944d311ce37bee1ba0e41a9b58dcf330ffe0cf29d7654c3d07c572215da68ac", 126 | "sha256:8ff08eaa25c66383fe3b6c7eb288da3c22dcedc4b110a0b592b35f68d0e093b2", 127 | "sha256:9d12378d6a236aa38326e27f3a29427b63edce4ce325745785aec1a7535b1f85", 128 | "sha256:abfd3d9390eb4f2d82cbcaa3a5c2834c581329b64eccb7a071ed9d5df27424f7", 129 | "sha256:bc4d7481f0e8ec94cb1afc4a59905d6274b3b4c389aba7a2539e071766671735", 130 | "sha256:dc0ba2080fd0cfdd07b3458ee4324d35806733feb2b080838d7094731d3f73d9", 131 | "sha256:f26fba7fc68994ab2805d77e0695417f9377a00d36ba4248b5d0f1e5adb08d24" 132 | ], 133 | "index": "pypi", 134 | "version": "==2.2.2" 135 | }, 136 | "numpy": { 137 | "hashes": [ 138 | "sha256:1a784e8ff7ea2a32e393cc53eb0003eca1597c7ca628227e34ce34eb11645a0e", 139 | "sha256:2ba579dde0563f47021dcd652253103d6fd66165b18011dce1a0609215b2791e", 140 | "sha256:3537b967b350ad17633b35c2f4b1a1bbd258c018910b518c30b48c8e41272717", 141 | "sha256:3c40e6b860220ed862e8097b8f81c9af6d7405b723f4a7af24a267b46f90e461", 142 | 
"sha256:598fe100b2948465cf3ed64b1a326424b5e4be2670552066e17dfaa67246011d", 143 | "sha256:620732f42259eb2c4642761bd324462a01cdd13dd111740ce3d344992dd8492f", 144 | "sha256:709884863def34d72b183d074d8ba5cfe042bc3ff8898f1ffad0209161caaa99", 145 | "sha256:75579acbadbf74e3afd1153da6177f846212ea2a0cc77de53523ae02c9256513", 146 | "sha256:7c55407f739f0bfcec67d0df49103f9333edc870061358ac8a8c9e37ea02fcd2", 147 | "sha256:a1f2fb2da242568af0271455b89aee0f71e4e032086ee2b4c5098945d0e11cf6", 148 | "sha256:a290989cd671cd0605e9c91a70e6df660f73ae87484218e8285c6522d29f6e38", 149 | "sha256:ac4fd578322842dbda8d968e3962e9f22e862b6ec6e3378e7415625915e2da4d", 150 | "sha256:ad09f55cc95ed8d80d8ab2052f78cc21cb231764de73e229140d81ff49d8145e", 151 | "sha256:b9205711e5440954f861ceeea8f1b415d7dd15214add2e878b4d1cf2bcb1a914", 152 | "sha256:bba474a87496d96e61461f7306fba2ebba127bed7836212c360f144d1e72ac54", 153 | "sha256:bebab3eaf0641bba26039fb0b2c5bf9b99407924b53b1ea86e03c32c64ef5aef", 154 | "sha256:cc367c86eb87e5b7c9592935620f22d13b090c609f1b27e49600cd033b529f54", 155 | "sha256:ccc6c650f8700ce1e3a77668bb7c43e45c20ac06ae00d22bdf6760b38958c883", 156 | "sha256:cf680682ad0a3bef56dae200dbcbac2d57294a73e5b0f9864955e7dd7c2c2491", 157 | "sha256:d2910d0a075caed95de1a605df00ee03b599de5419d0b95d55342e9a33ad1fb3", 158 | "sha256:d5caa946a9f55511e76446e170bdad1d12d6b54e17a2afe7b189112ed4412bb8", 159 | "sha256:d89b0dc7f005090e32bb4f9bf796e1dcca6b52243caf1803fdd2b748d8561f63", 160 | "sha256:d95d16204cd51ff1a1c8d5f9958ce90ae190be81d348b514f9be39f878b8044a", 161 | "sha256:e4d5a86a5257843a18fb1220c5f1c199532bc5d24e849ed4b0289fb59fbd4d8f", 162 | "sha256:e58ddb53a7b4959932f5582ac455ff90dcb05fac3f8dcc8079498d43afbbde6c", 163 | "sha256:e80fe25cba41c124d04c662f33f6364909b985f2eb5998aaa5ae4b9587242cce", 164 | "sha256:eda2829af498946c59d8585a9fd74da3f810866e05f8df03a86f70079c7531dd", 165 | "sha256:fd0a359c1c17f00cb37de2969984a74320970e0ceef4808c32e00773b06649d9" 166 | ], 167 | "index": "pypi", 168 | "version": "==1.21.0" 169 | }, 170 | "pillow": { 171 | "hashes": [ 172 | "sha256:00633bc2ec40313f4daf351855e506d296ec3c553f21b66720d0f1225ca84c6f", 173 | "sha256:03514478db61b034fc5d38b9bf060f994e5916776e93f02e59732a8270069c61", 174 | "sha256:040144ba422216aecf7577484865ade90e1a475f867301c48bf9fbd7579efd76", 175 | "sha256:16246261ff22368e5e32ad74d5ef40403ab6895171a7fc6d34f6c17cfc0f1943", 176 | "sha256:1cb38df69362af35c14d4a50123b63c7ff18ec9a6d4d5da629a6f19d05e16ba8", 177 | "sha256:2400e122f7b21d9801798207e424cbe1f716cee7314cd0c8963fdb6fc564b5fb", 178 | "sha256:2ee6364b270b56a49e8b8a51488e847ab130adc1220c171bed6818c0d4742455", 179 | "sha256:3b4560c3891b05022c464b09121bd507c477505a4e19d703e1027a3a7c68d896", 180 | "sha256:41374a6afb3f44794410dab54a0d7175e6209a5a02d407119c81083f1a4c1841", 181 | "sha256:438a3faf5f702c8d0f80b9f9f9b8382cfa048ca6a0d64ef71b86b563b0ee0359", 182 | "sha256:472a124c640bde4d5468f6991c9fa7e30b723d84ac4195a77c6ab6aea30f2b9c", 183 | "sha256:4d32c8e3623a61d6e29ccd024066cd1ba556555abfb4cd714155020e00107e3f", 184 | "sha256:4d8077fd649ac40a5c4165f2c22fa2a4ad18c668e271ecb2f9d849d1017a9313", 185 | "sha256:62ec7ae98357fcd46002c110bb7cad15fce532776f0cbe7ca1d44c49b837d49d", 186 | "sha256:6c7cab6a05351cf61e469937c49dbf3cdf5ffb3eeac71f8d22dc9be3507598d8", 187 | "sha256:6eca36905444c4b91fe61f1b9933a47a30480738a1dd26501ff67d94fc2bc112", 188 | "sha256:74e2ebfd19c16c28ad43b8a28ff73b904ed382ea4875188838541751986e8c9a", 189 | "sha256:7673e7473a13107059377c96c563aa36f73184c29d2926882e0a0210b779a1e7", 190 | 
"sha256:81762cf5fca9a82b53b7b2d0e6b420e0f3b06167b97678c81d00470daa622d58", 191 | "sha256:8554bbeb4218d9cfb1917c69e6f2d2ad0be9b18a775d2162547edf992e1f5f1f", 192 | "sha256:9b66e968da9c4393f5795285528bc862c7b97b91251f31a08004a3c626d18114", 193 | "sha256:a00edb2dec0035e98ac3ec768086f0b06dfabb4ad308592ede364ef573692f55", 194 | "sha256:b48401752496757e95304a46213c3155bc911ac884bed2e9b275ce1c1df3e293", 195 | "sha256:b6cf18f9e653a8077522bb3aa753a776b117e3e0cc872c25811cfdf1459491c2", 196 | "sha256:bb8adab1877e9213385cbb1adc297ed8337e01872c42a30cfaa66ff8c422779c", 197 | "sha256:c8a4b39ba380b57a31a4b5449a9d257b1302d8bc4799767e645dcee25725efe1", 198 | "sha256:cee9bc75bff455d317b6947081df0824a8f118de2786dc3d74a3503fd631f4ef", 199 | "sha256:d0dc1313dff48af64517cbbd85e046d6b477fbe5e9d69712801f024dcb08c62b", 200 | "sha256:d5bf527ed83617edd1855a5c923eeeaf68bcb9ac0ceb28e3f19b575b3a424984", 201 | "sha256:df5863a21f91de5ecdf7d32a32f406dd9867ebb35d41033b8bd9607a21887599", 202 | "sha256:e39142332541ed2884c257495504858b22c078a5d781059b07aba4c3a80d7551", 203 | "sha256:e52e8f675ba0b2b417fa98579e7286a41a8e23871f17f4793772f5aa884fea79", 204 | "sha256:e6dd55d5d94b9e36929325dd0c9ab85bfde84a5fc35947c334c32af1af668944", 205 | "sha256:e87cc1acbebf263f308a8494272c2d42016aa33c32bf14d209c81e1f65e11868", 206 | "sha256:ea0091cd4100519cedfeea2c659f52291f535ac6725e2368bcf59e874f270efa", 207 | "sha256:eeb247f4f4d962942b3b555530b0c63b77473c7bfe475e51c6b75b7344b49ce3", 208 | "sha256:f0d4433adce6075efd24fc0285135248b0b50f5a58129c7e552030e04fe45c7f", 209 | "sha256:f1f3bd92f8e12dc22884935a73c9f94c4d9bd0d34410c456540713d6b7832b8c", 210 | "sha256:f42a87cbf50e905f49f053c0b1fb86c911c730624022bf44c8857244fc4cdaca", 211 | "sha256:f5f302db65e2e0ae96e26670818157640d3ca83a3054c290eff3631598dcf819", 212 | "sha256:f7634d534662bbb08976db801ba27a112aee23e597eeaf09267b4575341e45bf", 213 | "sha256:fdd374c02e8bb2d6468a85be50ea66e1c4ef9e809974c30d8576728473a6ed03", 214 | "sha256:fe6931db24716a0845bd8c8915bd096b77c2a7043e6fc59ae9ca364fe816f08b" 215 | ], 216 | "version": "==5.1.0" 217 | }, 218 | "pyglet": { 219 | "hashes": [ 220 | "sha256:8b07aea16f34ac861cffd06a0c17723ca944d172e577b57b21859b7990709a66", 221 | "sha256:b00570e7cdf6971af8953b6ece50d83d13272afa5d1f1197c58c0f478dd17743" 222 | ], 223 | "version": "==1.3.2" 224 | }, 225 | "pyopengl": { 226 | "hashes": [ 227 | "sha256:5b3a14a4e4c87cf25ca41bc5878583df0e434624b13fcc55d1a672f79ad1b76f", 228 | "sha256:9b47c5c3a094fa518ca88aeed35ae75834d53e4285512c61879f67a48c94ddaf", 229 | "sha256:efa4e39a49b906ccbe66758812ca81ced13a6f26931ab2ba2dba2750c016c0d0" 230 | ], 231 | "version": "==3.1.0" 232 | }, 233 | "pyparsing": { 234 | "hashes": [ 235 | "sha256:0832bcf47acd283788593e7a0f542407bd9550a55a8a8435214a1960e04bcb04", 236 | "sha256:281683241b25fe9b80ec9d66017485f6deff1af5cde372469134b56ca8447a07", 237 | "sha256:8f1e18d3fd36c6795bb7e02a39fd05c611ffc2596c1e0d995d34d67630426c18", 238 | "sha256:9e8143a3e15c13713506886badd96ca4b579a87fbdf49e550dbfc057d6cb218e", 239 | "sha256:b8b3117ed9bdf45e14dcc89345ce638ec7e0e29b2b579fa1ecf32ce45ebac8a5", 240 | "sha256:e4d45427c6e20a59bf4f88c639dcc03ce30d193112047f94012102f235853a58", 241 | "sha256:fee43f17a9c4087e7ed1605bd6df994c6173c1e977d7ade7b651292fab2bd010" 242 | ], 243 | "version": "==2.2.0" 244 | }, 245 | "python-dateutil": { 246 | "hashes": [ 247 | "sha256:1adb80e7a782c12e52ef9a8182bebeb73f1d7e24e374397af06fb4956c8dc5c0", 248 | "sha256:e27001de32f627c22380a688bcc43ce83504a7bc5da472209b4c70f02829f0b8" 249 | ], 250 | "version": "==2.7.3" 251 | }, 252 | "pytz": { 253 | 
"hashes": [ 254 | "sha256:65ae0c8101309c45772196b21b74c46b2e5d11b6275c45d251b150d5da334555", 255 | "sha256:c06425302f2cf668f1bba7a0a03f3c1d34d4ebeef2c72003da308b3947c7f749" 256 | ], 257 | "version": "==2018.4" 258 | }, 259 | "requests": { 260 | "hashes": [ 261 | "sha256:6a1b267aa90cac58ac3a765d067950e7dbbf75b1da07e895d1f594193a40a38b", 262 | "sha256:9c443e7324ba5b85070c4a818ade28bfabedf16ea10206da1132edaa6dda237e" 263 | ], 264 | "version": "==2.18.4" 265 | }, 266 | "scipy": { 267 | "hashes": [ 268 | "sha256:0572256c10ddd058e3d315c555538671ddb2737f27eb56189bfbc3483391403f", 269 | "sha256:2e685fdbfa5b989af4338b29c408b9157ea6addec15d661104c437980c292be5", 270 | "sha256:3595c8b64970c9e5a3f137fa1a9eb64da417e78fb7991d0b098b18a00b776d88", 271 | "sha256:3e7df79b42c3015058a5554bfeab6fd4c9906c46560c9ddebb5c652840f3e182", 272 | "sha256:4ef3d4df8af40cb6f4d4eaf7b02780109ebabeec334cda26a7899ec9d8de9176", 273 | "sha256:53116abd5060a5b4a58489cf689bee259b779e6b7ecd4ce366e7147aa7c9626e", 274 | "sha256:5a983d3cebc27294897951a494cebd78af2eae37facf75d9e4ad4f1f62229860", 275 | "sha256:5eb8f054eebb351af7490bbb57465ba9662c4e16e1786655c6c7ed530eb9a74e", 276 | "sha256:6130e22bf6ee506f7cddde7e0515296d97eb6c6c94f7ef5103c2b77aec5833a7", 277 | "sha256:7f4b89c223bd09460b52b669e2e642cab73c28855b540e6ed029692546a86f8d", 278 | "sha256:80df8af7039bce92fb4cd1ceb056258631b11b3c627384e2d29bb48d44c0cae7", 279 | "sha256:821e75f5c16cd7b0ab0ffe7eb9917e5af7b48c25306b4777287de8d792a5f7f3", 280 | "sha256:97ca4552ace1c313707058e774609af59644321e278c3a539322fab2fb09b943", 281 | "sha256:998c5e6ea649489302de2c0bc026ed34284f531df89d2bdc8df3a0d44d165739", 282 | "sha256:aef6e922aea6f2e6bbb539b413c85210a9ee32757535b84204ebd22723e69704", 283 | "sha256:b77ee5e3a9507622e7f98b16122242a3903397f98d1fe3bc269d904a9025e2bc", 284 | "sha256:bd4399d4388ca0239a4825e312b3e61b60f743dd6daf49e5870837716502a92a", 285 | "sha256:c5d012cb82cc1dcfa72609abaabb4a4ed8113e3e8ac43464508a418c146be57d", 286 | "sha256:e7b733d4d98e604109715e11f2ab9340eb45d53f803634ed730039070fc3bc11" 287 | ], 288 | "index": "pypi", 289 | "version": "==1.7.0" 290 | }, 291 | "six": { 292 | "hashes": [ 293 | "sha256:70e8a77beed4562e7f14fe23a786b54f6296e34344c23bc42f07b15018ff98e9", 294 | "sha256:832dc0e10feb1aa2c68dcc57dbb658f1c7e65b9b61af69048abc87a2db00a0eb" 295 | ], 296 | "version": "==1.11.0" 297 | }, 298 | "urllib3": { 299 | "hashes": [ 300 | "sha256:06330f386d6e4b195fbfc736b297f58c5a892e4440e54d294d7004e3a9bbea1b", 301 | "sha256:cc44da8e1145637334317feebd728bd869a35285b93cbb4cca2577da7e62db4f" 302 | ], 303 | "version": "==1.22" 304 | } 305 | }, 306 | "develop": { 307 | "memory-profiler": { 308 | "hashes": [ 309 | "sha256:e38627e66ca787f56ad2898699e07cb7ae2049a7dc075d535367cd882c417b9a" 310 | ], 311 | "index": "pypi", 312 | "version": "==0.52.0" 313 | }, 314 | "nose": { 315 | "hashes": [ 316 | "sha256:9ff7c6cc443f8c51994b34a667bbcf45afd6d945be7477b52e97516fd17c53ac", 317 | "sha256:dadcddc0aefbf99eea214e0f1232b94f2fa9bd98fa8353711dacb112bfcbbb2a", 318 | "sha256:f1bffef9cbc82628f6e7d7b40d7e255aefaa1adb6a1b1d26c69a8b79e6208a98" 319 | ], 320 | "index": "pypi", 321 | "version": "==1.3.7" 322 | }, 323 | "psutil": { 324 | "hashes": [ 325 | "sha256:325c334596ad2d8a178d0e7b4eecc91748096a87489b3701ee16986173000aaa", 326 | "sha256:33384065f0014351fa70187548e3e95952c4df4bc5c38648bd0e647d21eaaf01", 327 | "sha256:51e12aa74509832443862373a2655052b20c83cad7322f49d217452500b9a405", 328 | "sha256:52a91ba928a5e86e0249b4932d6e36972a72d1ad8dcc5b7f753a2ae14825a4ba", 329 | 
"sha256:99029b6af386b22882f0b6d537ffed5a9c3d5ff31782974aeaa1d683262d8543", 330 | "sha256:b10703a109cc9225cd588c207f7f93480a420ade35c13515ea8f20063b42a392", 331 | "sha256:ddba952ed256151844d82fb13c8fb1019fe11ecaeacbd659d67ba5661ae73d0d", 332 | "sha256:ebe293be36bb24b95cdefc5131635496e88b17fabbcf1e4bc9b5c01f5e489cfe", 333 | "sha256:f24cd52bafa06917935fe1b68c5a45593abe1f3097dc35b2dfc4718236795890" 334 | ], 335 | "version": "==5.4.5" 336 | }, 337 | "termcolor": { 338 | "hashes": [ 339 | "sha256:1d6d69ce66211143803fbc56652b41d73b4a400a2891d7bf7a1cdf4c02de613b" 340 | ], 341 | "index": "pypi", 342 | "version": "==1.1.0" 343 | } 344 | } 345 | } 346 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Reinforcement Learning from Human Preferences 2 | 3 | Reproduction of OpenAI and DeepMind's [Deep Reinforcement Learning from Human 4 | Preferences](https://blog.openai.com/deep-reinforcement-learning-from-human-preferences/), 5 | based on the paper at . 6 | 7 | 8 | ## Results 9 | 10 | The main milestones of this reproduction were: 11 | 12 | * Training an agent to move the dot to the middle in a [simple environment](https://github.com/mrahtz/gym-moving-dot) using synthetic preferences. 13 | 14 | * Training an agent to play Pong using synthetic preferences. 15 | 16 | * Training an agent to stay alongside other cars in Enduro using *human* preferences. 17 | 18 | ![](images/moving-dot.gif) 19 | ![](images/pong.gif) 20 | ![](images/enduro.gif) 21 | 22 | 23 | ## Usage 24 | 25 | ### Python setup 26 | 27 | This project uses Tensorflow 1, which needs Python 3.7 or below. 28 | 29 | To set up an isolated environment and install dependencies, install 30 | [Pipenv](https://github.com/pypa/pipenv), then just run: 31 | 32 | `$ pipenv install` 33 | 34 | However, note that TensorFlow must be installed manually. Either: 35 | 36 | `$ pipenv run pip install tensorflow==1.15` 37 | 38 | or 39 | 40 | `$ pipenv run pip install tensorflow-gpu==1.15` 41 | 42 | depending on whether you have a GPU. (If you run into problems, try installing 43 | TensorFlow 1.6.0, which was used for development.) 44 | 45 | If you want to run tests, also run: 46 | 47 | `$ pipenv install --dev` 48 | 49 | Finally, before running any of the scripts, enter the environment with: 50 | 51 | `$ pipenv shell` 52 | 53 | ### Running 54 | 55 | All training is done using [`run.py`](run.py). Basic usage is: 56 | 57 | `$ python3 run.py ` 58 | 59 | Supported environments are 60 | [`MovingDotNoFrameskip-v0`](https://github.com/mrahtz/gym-moving-dot), 61 | `PongNoFrameskip-v4`, and `EnduroNoFrameskip-v4`. 62 | 63 | ### Training with original rewards 64 | 65 | To train using the original rewards from the environment rather than rewards 66 | based on preferences, use the `train_policy_with_original_rewards` mode. 67 | 68 | For example, to train Pong: 69 | 70 | `$ python3 run.py train_policy_with_original_rewards PongNoFrameskip-v4 --n_envs 16 --million_timesteps 10` 71 | 72 | ### Training end-to-end with preferences 73 | 74 | Use the `train_policy_with_preferences` mode. 75 | 76 | For example, to train `MovingDotNoFrameskip-v0` using *synthetic* preferences: 77 | 78 | `$ python3 run.py train_policy_with_preferences MovingDotNoFrameskip-v0 --synthetic_prefs --ent_coef 0.02 --million_timesteps 0.15` 79 | 80 | On a machine with a GPU, this takes about an hour. 
TensorBoard logs (created in 81 | a new directory in `runs/` automatically) should look something like: 82 | 83 | ![](images/moving-dot-graphs.png) 84 | 85 | To train Pong using *synthetic* preferences: 86 | 87 | `$ python3 run.py train_policy_with_preferences PongNoFrameskip-v4 --synthetic_prefs --dropout 0.5 --n_envs 16 --million_timesteps 20` 88 | 89 | On a 16-core machine without GPU, this takes about 13 hours. TensorBoard logs 90 | should look something like: 91 | 92 | ![](images/pong-graphs.png) 93 | 94 | To train Enduro (a modified version with a time limit so the weather doesn't change, which the paper notes can confuse the reward predictor) using *human* preferences: 95 | 96 | `$ python3 run.py train_policy_with_preferences EnduroNoFrameskip-v4 --n_envs 16 --render_episodes` 97 | 98 | You'll see two windows: a larger one showing a pair of examples of agent 99 | behaviour, and another smaller window showing the last full episode that the 100 | agent played (so you can see how qualitative behaviour is changing). Enter 'L' 101 | in the terminal to indicate that you prefer the left example; 'R' to indicate 102 | you prefer the right example; 'E' to indicate you prefer them both equally; and 103 | just press enter if the two clips are incomparable. 104 | 105 | On an 8-core machine with GPU, it takes about 2.5 hours to reproduce the video 106 | above - about an hour to collect 500 preferences about behaviour from a random 107 | policy, then half an hour to pretrain the reward predictor using those 500 108 | preferences, then an hour to train the policy (while still collecting 109 | preferences). 110 | 111 | The bottleneck is mainly labelling speed, so if you've already saved human preferences in `runs/enduro`, you can re-use those preferences by training with: 112 | 113 | `$ python3 run.py train_policy_with_preferences EnduroNoFrameskip-v4 --n_envs 16 --render_episodes --load_prefs_dir runs/enduro --n_initial_epochs 10` 114 | 115 | This only takes about half an hour. 116 | 117 | ### Piece-by-piece runs 118 | 119 | You can also run different parts of the training process separately, saving 120 | their results for later use: 121 | * Use the `gather_initial_prefs` mode to gather the initial 500 preferences 122 | used for pretraining the reward predictor. This saves preferences to 123 | `train_initial.pkl.gz` and `val_initial.pkl.gz` in the run directory. 124 | * Use `pretrain_reward_predictor` to just pretrain the reward predictor (200 125 | epochs). Specify the run directory to load initial preferences from with 126 | `--load_prefs_dir`. 127 | * Load a pretrained reward predictor using the `--load_reward_predictor_ckpt` 128 | argument when running in `train_policy_with_preferences` mode. 129 | 130 | For example, to gather synthetic preferences for `MovingDotNoFrameskip-v0`, 131 | saving to run directory `moving_dot-initial_prefs`: 132 | 133 | `$ python run.py gather_initial_prefs MovingDotNoFrameskip-v0 --synthetic_prefs --run_name moving_dot-initial_prefs` 134 | 135 | ### Running on FloydHub 136 | 137 | To run on [FloydHub](https://www.floydhub.com) (a cloud platform for 138 | running machine learning jobs), use something like: 139 | 140 | `floyd run --follow --env tensorflow-1.5 --tensorboard 141 | 'bash floydhub_utils/floyd_wrapper.sh python run.py 142 | --log_dir /output --synthetic_prefs 143 | train_policy_with_preferences PongNoFrameskip-v4'` 144 | 145 | Check out runs reproducing the above results on 146 | FloydHub.
147 | 148 | ### Running checkpoints 149 | 150 | To run a trained policy checkpoint so you can see what the agent was doing, use 151 | [`run_checkpoint.py`](run_checkpoint.py). Basic usage is: 152 | 153 | `$ python3 run_checkpoint.py <environment> <policy checkpoint directory>` 154 | 155 | For example, to run an agent saved in `runs/pong`: 156 | 157 | `$ python3 run_checkpoint.py PongNoFrameskip-v4 runs/pong/policy_checkpoints` 158 | 159 | 160 | ## Architecture notes 161 | 162 | There are three main components: 163 | * The A2C workers ([`a2c/a2c/a2c.py`](a2c/a2c/a2c.py)) 164 | * The preference interface ([`pref_interface.py`](pref_interface.py)) 165 | * The reward predictor ([`reward_predictor.py`](reward_predictor.py)) 166 | 167 | ### Data flow 168 | 169 | The flow of data begins with the A2C workers, which generate video clips of the 170 | agent trying things in the environment. 171 | 172 | These video clips (referred to in the code as 'segments') are sent to the 173 | preference interface. The preference interface shows pairs of video clips to 174 | the user and asks through a command-line interface which clip of each pair 175 | shows more of the kind of behaviour the user wants. 176 | 177 | Preferences are sent to the reward predictor, which trains a deep neural 178 | network to predict each preference from the associated pair of video clips. 179 | Preferences are predicted based on a comparison between two penultimate scalar 180 | values in the network (one for each video clip) representing some measure of 181 | how much the user likes each of the two clips in the pair. 182 | 183 | That network can then be used to predict rewards for future video clips by 184 | feeding the clip in, running a forward pass to calculate the "how much the user 185 | likes this clip" value, then normalising the result to have zero mean and 186 | constant variance across time. 187 | 188 | This normalised value is then used directly as a reward signal to train the A2C 189 | workers according to the preferences given by the user. 190 | 191 | ### Processes 192 | 193 | All components run asynchronously in different subprocesses: 194 | * A2C workers explore the environment and train the policy. 195 | * The preference interface queries the user for preferences. 196 | * The reward predictor is trained using the preferences given. 197 | 198 | There are three tricky parts to this: 199 | * Video clips must be sent from the A2C process to the process asking for 200 | preferences using a queue. Video clips are cheap, and the A2C process should 201 | never stop, so the A2C process only puts a clip onto the queue if the queue 202 | is empty, and otherwise drops the clips. The preference interface then just 203 | gets as many clips as it can from the queue in 0.5 seconds, in between asking 204 | about each pair of clips. (Pairs to show the user are selected from the clip 205 | database internal to the preference interface into which clips from the queue 206 | are stored.) A sketch of this queue pattern follows this list. 207 | * Preferences must be sent from the preference interface to the reward 208 | predictor using a queue. Preferences should never be dropped, though, so the 209 | preference interface blocks until the preference can be added to the queue, 210 | and the reward predictor training process runs a background thread which 211 | constantly receives from the queue, storing preferences in the reward 212 | predictor process's internal database. 213 | * Both the A2C process and the reward predictor training process need to access 214 | the reward predictor network. This is done using Distributed TensorFlow: each 215 | process maintains its own copy of the network, and parameter updates from 216 | the reward predictor training process are automatically replicated to the A2C 217 | worker process's network. 218 |
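In sketch form, the clip queue works roughly like this (helper names here are illustrative rather than the actual function names; the producer side lives in `update_segment_buffer` in [`a2c/a2c/a2c.py`](a2c/a2c/a2c.py), and the consumer side in [`pref_interface.py`](pref_interface.py)):

```python
import queue
import time


def put_segment(seg_pipe, segment):
    # Producer (A2C worker) side: never block training. If the previous
    # clip hasn't been taken off the queue yet, drop this one and move on.
    try:
        seg_pipe.put(segment, block=False)
    except queue.Full:
        pass


def drain_segments(seg_pipe, wait_seconds=0.5):
    # Consumer (preference interface) side: between queries, collect
    # whatever clips arrive within a short time budget, then go back to
    # asking the user about pairs.
    segments = []
    deadline = time.time() + wait_seconds
    while True:
        remaining = deadline - time.time()
        if remaining <= 0:
            break
        try:
            segments.append(seg_pipe.get(timeout=remaining))
        except queue.Empty:
            break
    return segments
```

Here `seg_pipe` would be a `multiprocessing.Queue` created with a small `maxsize`; a capacity of one gives exactly the "only put a clip if the queue is empty" behaviour described above.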
219 | All subprocesses are started and coordinated by [`run.py`](run.py). 220 | 221 | ![](images/diagram.png) 222 | 223 | 224 | ## Changes to the paper's setup 225 | 226 | It turned out to be possible to reach the milestones in the results section 227 | above even without implementing a number of features described in the original 228 | paper. 229 | 230 | * For regularisation of the reward predictor network, the paper uses dropout, 231 | batchnorm and an adaptive L2 regularisation scheme. Here, we only use 232 | dropout. (Batchnorm is also supported. L2 regularisation is not implemented.) 233 | * In the paper's setup, the rate at which preferences are requested is 234 | gradually reduced over time. We just ask for preferences at a constant rate. 235 | * The paper selects video clips to show the user based on predicted reward 236 | uncertainty among an ensemble of reward predictors. Early experiments 237 | suggested a higher chance of successful training by just selecting video 238 | clips randomly (also noted by the paper in some situations), so we don't do 239 | any ensembling. (Ensembling code *is* implemented in 240 | [`reward_predictor.py`](reward_predictor.py), but we always operate with only 241 | a single-member ensemble, and [`pref_interface.py`](pref_interface.py) just 242 | chooses segments randomly.) 243 | * The preference for each pair of video clips is calculated based on a softmax 244 | over the predicted latent reward values for each clip. In the paper, 245 | "Rather than applying a softmax directly...we assume there is a 10% chance 246 | that the human responds uniformly at random. Conceptually this adjustment is 247 | needed because human raters have a constant probability of making an error, 248 | which doesn’t decay to 0 as the difference in reward difference becomes 249 | extreme." I wasn't sure how to implement this - at least, I couldn't see a 250 | way to implement it that would actually affect the gradients - so we just do 251 | the softmax directly. (Update: see https://github.com/mrahtz/learning-from-human-preferences/issues/8.) A sketch of the plain softmax and the 10%-random adjustment follows this list. 252 |
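Concretely, the two options look roughly like this in TensorFlow 1 (a minimal sketch with illustrative placeholder names rather than the actual variables in [`reward_predictor.py`](reward_predictor.py); `r1` and `r2` stand for the summed predicted latent rewards of the two clips in a pair, and `mu` for the labelled probability that the first clip is preferred):

```python
import tensorflow as tf

# Summed predicted latent rewards for each clip in a pair, plus the label:
# mu = 1.0 if the first clip is preferred, 0.0 if the second is,
# 0.5 if the two are marked as equally good.
r1 = tf.placeholder(tf.float32, [None])
r2 = tf.placeholder(tf.float32, [None])
mu = tf.placeholder(tf.float32, [None])

logits = tf.stack([r1, r2], axis=1)
labels = tf.stack([mu, 1.0 - mu], axis=1)

# Plain version (what the bullet above describes): cross-entropy against
# the softmax over the two latent reward sums.
p = tf.nn.softmax(logits)
loss_plain = -tf.reduce_sum(labels * tf.log(p + 1e-8), axis=1)

# One possible reading of the paper's adjustment (not used here): assume a
# 10% chance that the human responds uniformly at random, i.e. mix the
# softmax with a uniform distribution before taking the log.
p_adjusted = 0.9 * p + 0.1 * 0.5
loss_adjusted = -tf.reduce_sum(labels * tf.log(p_adjusted + 1e-8), axis=1)
```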
253 | ## Ideas for future work 254 | 255 | If you want to hack on this project to learn some deep RL, here are some ideas 256 | for extensions and things to investigate: 257 | 258 | * **Better ways of selecting video clips for query**. As mentioned above and in 259 | the paper, it looks like using variance across ensemble members to select 260 | video clips to ask the user about sometimes _harms_ performance. Why is this? 261 | Is there some inherent reason that "Ask the user about the clips we're most 262 | uncertain about" is a bad heuristic (e.g. because then we focus too much on 263 | strange examples, and don't sample enough preferences for more common 264 | situations)? Or is it a problem with the uncertainty calculation? Do we get 265 | different results using [dropout-based 266 | uncertainty](https://arxiv.org/pdf/1506.02142.pdf), or by [ensembling but 267 | with shared parameters](https://arxiv.org/pdf/1602.04621.pdf)? 268 | * **Domain randomisation for the reward predictor**. The paper notes that when 269 | training an agent to stay alongside other cars in Enduro, "the agent learns 270 | to stay almost exactly even with other moving cars for a substantial fraction 271 | of the episode, although it gets confused by changes in background". Could 272 | this be mitigated with [domain 273 | randomization](https://arxiv.org/pdf/1703.06907.pdf)? E.g. would randomly 274 | changing the colours of the frames encourage the reward predictor to be more 275 | invariant to changes in background? 276 | * **Alternative reward predictor architectures**. When training Enduro, the 277 | user ends up giving enough preferences to cover pretty much the full range of 278 | possible car positions on the track. It's therefore unclear how much success 279 | in the kinds of simple environments we're playing with here is down to the 280 | interesting generalisation capabilities of deep neural networks, and how much 281 | it's just memorisation of examples. It could be interesting to explore much 282 | simpler architectures of reward predictor - for example, one which tries to 283 | establish a ranking of video clips directly from preferences (I'm not 284 | familiar with the literature, but e.g. [Efficient Ranking from Pairwise 285 | Comparisons](http://proceedings.mlr.press/v28/wauthier13.pdf)), then gives 286 | reward corresponding to the rank of the most similar video clip. 287 | * **Automatic reward shaping**. Watching the graph of rewards predicted by the 288 | reward predictor (run [`run_checkpoint.py`](run_checkpoint.py) with a reward 289 | predictor checkpoint), it looks like the predicted rewards might be slightly 290 | better-shaped than the original rewards, even when trained with synthetic 291 | preferences based on the original rewards. Specifically, in Pong, it looks 292 | like there might be a small positive reward whenever the agent hits the ball. 293 | Could a reward predictor trained from synthetic preferences be used to 294 | automatically shape rewards for easier training? 295 | 296 | ## Code credits 297 | 298 | A2C code in [`a2c`](a2c) is based on the implementation from [OpenAI's baselines](https://github.com/openai/baselines), commit [`f8663ea`](https://github.com/openai/baselines/commit/f8663ea). 299 | -------------------------------------------------------------------------------- /a2c/LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2017 OpenAI (http://openai.com) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /a2c/a2c/a2c.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path as osp 3 | import queue 4 | import time 5 | 6 | import cloudpickle 7 | import easy_tf_log 8 | import numpy as np 9 | from numpy.testing import assert_equal 10 | import tensorflow as tf 11 | 12 | from a2c import logger 13 | from a2c.a2c.utils import (cat_entropy, discount_with_dones, 14 | find_trainable_variables, mse) 15 | from a2c.common import explained_variance, set_global_seeds 16 | from pref_db import Segment 17 | 18 | 19 | class Model(object): 20 | def __init__(self, 21 | policy, 22 | ob_space, 23 | ac_space, 24 | nenvs, 25 | nsteps, 26 | nstack, 27 | num_procs, 28 | lr_scheduler, 29 | ent_coef=0.01, 30 | vf_coef=0.5, 31 | max_grad_norm=0.5, 32 | alpha=0.99, 33 | epsilon=1e-5): 34 | config = tf.ConfigProto( 35 | allow_soft_placement=True, 36 | intra_op_parallelism_threads=num_procs, 37 | inter_op_parallelism_threads=num_procs) 38 | config.gpu_options.allow_growth = True 39 | sess = tf.Session(config=config) 40 | nbatch = nenvs * nsteps 41 | 42 | A = tf.placeholder(tf.int32, [nbatch]) 43 | ADV = tf.placeholder(tf.float32, [nbatch]) 44 | R = tf.placeholder(tf.float32, [nbatch]) 45 | LR = tf.placeholder(tf.float32, []) 46 | 47 | step_model = policy( 48 | sess, ob_space, ac_space, nenvs, 1, nstack, reuse=False) 49 | train_model = policy( 50 | sess, ob_space, ac_space, nenvs, nsteps, nstack, reuse=True) 51 | 52 | neglogpac = tf.nn.sparse_softmax_cross_entropy_with_logits( 53 | logits=train_model.pi, labels=A) 54 | pg_loss = tf.reduce_mean(ADV * neglogpac) 55 | vf_loss = tf.reduce_mean(mse(tf.squeeze(train_model.vf), R)) 56 | entropy = tf.reduce_mean(cat_entropy(train_model.pi)) 57 | loss = pg_loss - entropy * ent_coef + vf_loss * vf_coef 58 | 59 | params = find_trainable_variables("model") 60 | grads = tf.gradients(loss, params) 61 | if max_grad_norm is not None: 62 | grads, grad_norm = tf.clip_by_global_norm(grads, max_grad_norm) 63 | grads = list(zip(grads, params)) 64 | trainer = tf.train.RMSPropOptimizer( 65 | learning_rate=LR, decay=alpha, epsilon=epsilon) 66 | _train = trainer.apply_gradients(grads) 67 | 68 | def train(obs, states, rewards, masks, actions, values): 69 | advs = rewards - values 70 | n_steps = len(obs) 71 | for _ in range(n_steps): 72 | cur_lr = lr_scheduler.value() 73 | td_map = { 74 | train_model.X: obs, 75 | A: actions, 76 | ADV: advs, 77 | R: rewards, 78 | LR: cur_lr 79 | } 80 | if states != []: 81 | td_map[train_model.S] = states 82 | td_map[train_model.M] = masks 83 | policy_loss, value_loss, policy_entropy, _ = sess.run( 84 | [pg_loss, vf_loss, entropy, _train], td_map) 85 | return policy_loss, value_loss, policy_entropy, cur_lr 86 | 87 | self.train = train 88 | self.train_model = train_model 89 | self.step_model = step_model 90 | self.step = step_model.step 91 | self.value = step_model.value 92 | self.initial_state = step_model.initial_state 93 | self.sess = sess 94 | # Why var_list=params? 95 | # Otherwise we'll also save optimizer parameters, 96 | # which take up a /lot/ of space. 97 | # Why save_relative_paths=True? 
98 | # So that the plain-text 'checkpoint' file written uses relative paths, 99 | # which seems to be needed in order to avoid confusing saver.restore() 100 | # when restoring from FloydHub runs. 101 | self.saver = tf.train.Saver( 102 | max_to_keep=1, var_list=params, save_relative_paths=True) 103 | tf.global_variables_initializer().run(session=sess) 104 | 105 | def load(self, ckpt_path): 106 | self.saver.restore(self.sess, ckpt_path) 107 | 108 | def save(self, ckpt_path, step_n): 109 | saved_path = self.saver.save(self.sess, ckpt_path, step_n) 110 | print("Saved policy checkpoint to '{}'".format(saved_path)) 111 | 112 | 113 | class Runner(object): 114 | def __init__(self, 115 | env, 116 | model, 117 | nsteps, 118 | nstack, 119 | gamma, 120 | gen_segments, 121 | seg_pipe, 122 | reward_predictor, 123 | episode_vid_queue): 124 | self.env = env 125 | self.model = model 126 | nh, nw, nc = env.observation_space.shape 127 | nenv = env.num_envs 128 | self.batch_ob_shape = (nenv * nsteps, nh, nw, nc * nstack) 129 | self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8) 130 | # The first stack of 4 frames: the first 3 frames are zeros, 131 | # with the last frame coming from env.reset(). 132 | obs = env.reset() 133 | self.update_obs(obs) 134 | self.gamma = gamma 135 | self.nsteps = nsteps 136 | self.states = model.initial_state 137 | self.dones = [False for _ in range(nenv)] 138 | 139 | self.gen_segments = gen_segments 140 | self.segment = Segment() 141 | self.seg_pipe = seg_pipe 142 | 143 | self.orig_reward = [0 for _ in range(nenv)] 144 | self.reward_predictor = reward_predictor 145 | 146 | self.episode_frames = [] 147 | self.episode_vid_queue = episode_vid_queue 148 | 149 | def update_obs(self, obs): 150 | # Do frame-stacking here instead of the FrameStack wrapper to reduce 151 | # IPC overhead 152 | self.obs = np.roll(self.obs, shift=-1, axis=3) 153 | self.obs[:, :, :, -1] = obs[:, :, :, 0] 154 | 155 | def update_segment_buffer(self, mb_obs, mb_rewards, mb_dones): 156 | # Segments are only generated from the first worker. 157 | # Empirically, this seems to work fine. 158 | e0_obs = mb_obs[0] 159 | e0_rew = mb_rewards[0] 160 | e0_dones = mb_dones[0] 161 | assert_equal(e0_obs.shape, (self.nsteps, 84, 84, 4)) 162 | assert_equal(e0_rew.shape, (self.nsteps, )) 163 | assert_equal(e0_dones.shape, (self.nsteps, )) 164 | 165 | for step in range(self.nsteps): 166 | self.segment.append(np.copy(e0_obs[step]), np.copy(e0_rew[step])) 167 | if len(self.segment) == 25 or e0_dones[step]: 168 | while len(self.segment) < 25: 169 | # Pad to 25 steps long so that all segments in the batch 170 | # have the same length. 171 | # Note that the reward predictor needs the full frame 172 | # stack, so we send all frames. 173 | self.segment.append(e0_obs[step], 0) 174 | self.segment.finalise() 175 | try: 176 | self.seg_pipe.put(self.segment, block=False) 177 | except queue.Full: 178 | # If the preference interface has a backlog of segments 179 | # to deal with, don't stop training the agents. Just drop 180 | # the segment and keep on going. 181 | pass 182 | self.segment = Segment() 183 | 184 | def update_episode_frame_buffer(self, mb_obs, mb_dones): 185 | e0_obs = mb_obs[0] 186 | e0_dones = mb_dones[0] 187 | for step in range(self.nsteps): 188 | # Here we only need to send the last frame (the most recent one) 189 | # from the 4-frame stack, because we're just showing output to 190 | # the user. 
191 | self.episode_frames.append(e0_obs[step, :, :, -1]) 192 | if e0_dones[step]: 193 | self.episode_vid_queue.put(self.episode_frames) 194 | self.episode_frames = [] 195 | 196 | def run(self): 197 | nenvs = len(self.env.remotes) 198 | mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = \ 199 | [], [], [], [], [] 200 | mb_states = self.states 201 | 202 | # Run for nsteps steps in the environment 203 | for _ in range(self.nsteps): 204 | actions, values, states = self.model.step(self.obs, self.states, 205 | self.dones) 206 | mb_obs.append(np.copy(self.obs)) 207 | mb_actions.append(actions) 208 | mb_values.append(values) 209 | mb_dones.append(self.dones) 210 | # len({obs, rewards, dones}) == nenvs 211 | obs, rewards, dones, _ = self.env.step(actions) 212 | self.states = states 213 | self.dones = dones 214 | for n, done in enumerate(dones): 215 | if done: 216 | self.obs[n] = self.obs[n] * 0 217 | # SubprocVecEnv automatically resets when done 218 | self.update_obs(obs) 219 | mb_rewards.append(rewards) 220 | mb_dones.append(self.dones) 221 | # batch of steps to batch of rollouts 222 | # i.e. from nsteps, nenvs to nenvs, nsteps 223 | mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0) 224 | mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0) 225 | mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0) 226 | mb_values = np.asarray(mb_values, dtype=np.float32).swapaxes(1, 0) 227 | mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0) 228 | mb_masks = mb_dones[:, :-1] 229 | # The first entry was just the init state of 'dones' (all False), 230 | # before we'd actually run any steps, so drop it. 231 | mb_dones = mb_dones[:, 1:] 232 | 233 | # Log original rewards 234 | for env_n, (rs, dones) in enumerate(zip(mb_rewards, mb_dones)): 235 | assert_equal(rs.shape, (self.nsteps, )) 236 | assert_equal(dones.shape, (self.nsteps, )) 237 | for step_n in range(self.nsteps): 238 | self.orig_reward[env_n] += rs[step_n] 239 | if dones[step_n]: 240 | easy_tf_log.tflog( 241 | "orig_reward_{}".format(env_n), 242 | self.orig_reward[env_n]) 243 | self.orig_reward[env_n] = 0 244 | 245 | if self.env.env_id == 'MovingDotNoFrameskip-v0': 246 | # For MovingDot, reward depends on both current observation and 247 | # current action, so encode action in the observations. 248 | # (We only need to set this in the most recent frame, 249 | # because that's all that the reward predictor for MovingDot 250 | # uses.) 251 | mb_obs[:, :, 0, 0, -1] = mb_actions[:, :] 252 | 253 | # Generate segments 254 | # (For MovingDot, this has to happen _after_ we've encoded the action 255 | # in the observations.) 256 | if self.gen_segments: 257 | self.update_segment_buffer(mb_obs, mb_rewards, mb_dones) 258 | 259 | # Replace rewards with those from reward predictor 260 | # (Note that this also needs to be done _after_ we've encoded the 261 | # action.) 
262 | logging.debug("Original rewards:\n%s", mb_rewards) 263 | if self.reward_predictor: 264 | assert_equal(mb_obs.shape, (nenvs, self.nsteps, 84, 84, 4)) 265 | mb_obs_allenvs = mb_obs.reshape(nenvs * self.nsteps, 84, 84, 4) 266 | 267 | rewards_allenvs = self.reward_predictor.reward(mb_obs_allenvs) 268 | assert_equal(rewards_allenvs.shape, (nenvs * self.nsteps, )) 269 | mb_rewards = rewards_allenvs.reshape(nenvs, self.nsteps) 270 | assert_equal(mb_rewards.shape, (nenvs, self.nsteps)) 271 | 272 | logging.debug("Predicted rewards:\n%s", mb_rewards) 273 | 274 | # Save frames for episode rendering 275 | if self.episode_vid_queue is not None: 276 | self.update_episode_frame_buffer(mb_obs, mb_dones) 277 | 278 | # Discount rewards 279 | mb_obs = mb_obs.reshape(self.batch_ob_shape) 280 | last_values = self.model.value(self.obs, self.states, 281 | self.dones).tolist() 282 | # discount/bootstrap off value fn 283 | for n, (rewards, dones, value) in enumerate( 284 | zip(mb_rewards, mb_dones, last_values)): 285 | rewards = rewards.tolist() 286 | dones = dones.tolist() 287 | if dones[-1] == 0: 288 | # Make sure that the first iteration of the loop inside 289 | # discount_with_dones picks up 'value' as the initial 290 | # value of r 291 | rewards = discount_with_dones(rewards + [value], 292 | dones + [0], 293 | self.gamma)[:-1] 294 | else: 295 | rewards = discount_with_dones(rewards, dones, self.gamma) 296 | mb_rewards[n] = rewards 297 | 298 | mb_rewards = mb_rewards.flatten() 299 | mb_actions = mb_actions.flatten() 300 | mb_values = mb_values.flatten() 301 | mb_masks = mb_masks.flatten() 302 | return mb_obs, mb_states, mb_rewards, mb_masks, mb_actions, mb_values 303 | 304 | 305 | def learn(policy, 306 | env, 307 | seed, 308 | start_policy_training_pipe, 309 | ckpt_save_dir, 310 | lr_scheduler, 311 | nsteps=5, 312 | nstack=4, 313 | total_timesteps=int(80e6), 314 | vf_coef=0.5, 315 | ent_coef=0.01, 316 | max_grad_norm=0.5, 317 | epsilon=1e-5, 318 | alpha=0.99, 319 | gamma=0.99, 320 | log_interval=100, 321 | ckpt_save_interval=1000, 322 | ckpt_load_dir=None, 323 | gen_segments=False, 324 | seg_pipe=None, 325 | reward_predictor=None, 326 | episode_vid_queue=None): 327 | 328 | tf.reset_default_graph() 329 | set_global_seeds(seed) 330 | 331 | nenvs = env.num_envs 332 | ob_space = env.observation_space 333 | ac_space = env.action_space 334 | num_procs = len(env.remotes) # HACK 335 | 336 | def make_model(): 337 | return Model( 338 | policy=policy, 339 | ob_space=ob_space, 340 | ac_space=ac_space, 341 | nenvs=nenvs, 342 | nsteps=nsteps, 343 | nstack=nstack, 344 | num_procs=num_procs, 345 | ent_coef=ent_coef, 346 | vf_coef=vf_coef, 347 | max_grad_norm=max_grad_norm, 348 | lr_scheduler=lr_scheduler, 349 | alpha=alpha, 350 | epsilon=epsilon) 351 | 352 | with open(osp.join(ckpt_save_dir, 'make_model.pkl'), 'wb') as fh: 353 | fh.write(cloudpickle.dumps(make_model)) 354 | 355 | print("Initialising policy...") 356 | if ckpt_load_dir is None: 357 | model = make_model() 358 | else: 359 | with open(osp.join(ckpt_load_dir, 'make_model.pkl'), 'rb') as fh: 360 | make_model = cloudpickle.loads(fh.read()) 361 | model = make_model() 362 | 363 | ckpt_load_path = tf.train.latest_checkpoint(ckpt_load_dir) 364 | model.load(ckpt_load_path) 365 | print("Loaded policy from checkpoint '{}'".format(ckpt_load_path)) 366 | 367 | ckpt_save_path = osp.join(ckpt_save_dir, 'policy.ckpt') 368 | 369 | runner = Runner(env=env, 370 | model=model, 371 | nsteps=nsteps, 372 | nstack=nstack, 373 | gamma=gamma, 374 | gen_segments=gen_segments, 375 | 
seg_pipe=seg_pipe, 376 | reward_predictor=reward_predictor, 377 | episode_vid_queue=episode_vid_queue) 378 | 379 | # nsteps: e.g. 5 380 | # nenvs: e.g. 16 381 | nbatch = nenvs * nsteps 382 | fps_tstart = time.time() 383 | fps_nsteps = 0 384 | 385 | print("Starting workers") 386 | 387 | # Before we're told to start training the policy itself, 388 | # just generate segments for the reward predictor to be trained with 389 | while True: 390 | runner.run() 391 | try: 392 | start_policy_training_pipe.get(block=False) 393 | except queue.Empty: 394 | continue 395 | else: 396 | break 397 | 398 | print("Starting policy training") 399 | 400 | for update in range(1, total_timesteps // nbatch + 1): 401 | # Run for nsteps 402 | obs, states, rewards, masks, actions, values = runner.run() 403 | 404 | policy_loss, value_loss, policy_entropy, cur_lr = model.train( 405 | obs, states, rewards, masks, actions, values) 406 | 407 | fps_nsteps += nbatch 408 | 409 | if update % log_interval == 0 and update != 0: 410 | fps = fps_nsteps / (time.time() - fps_tstart) 411 | fps_nsteps = 0 412 | fps_tstart = time.time() 413 | 414 | print("Trained policy for {} time steps".format(update * nbatch)) 415 | 416 | ev = explained_variance(values, rewards) 417 | logger.record_tabular("nupdates", update) 418 | logger.record_tabular("total_timesteps", update * nbatch) 419 | logger.record_tabular("fps", fps) 420 | logger.record_tabular("policy_entropy", float(policy_entropy)) 421 | logger.record_tabular("value_loss", float(value_loss)) 422 | logger.record_tabular("explained_variance", float(ev)) 423 | logger.record_tabular("learning_rate", cur_lr) 424 | logger.dump_tabular() 425 | 426 | if update != 0 and update % ckpt_save_interval == 0: 427 | model.save(ckpt_save_path, update) 428 | 429 | model.save(ckpt_save_path, update) 430 | -------------------------------------------------------------------------------- /a2c/a2c/policies.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from a2c.a2c.utils import conv, fc, conv_to_fc, batch_to_seq, seq_to_batch, lstm, lnlstm, sample 4 | 5 | class LnLstmPolicy(object): 6 | def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, nlstm=256, reuse=False): 7 | nbatch = nenv*nsteps 8 | nh, nw, nc = ob_space.shape 9 | ob_shape = (nbatch, nh, nw, nc*nstack) 10 | nact = ac_space.n 11 | X = tf.placeholder(tf.uint8, ob_shape) #obs 12 | M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) 13 | S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states 14 | with tf.variable_scope("model", reuse=reuse): 15 | h = conv(tf.cast(X, tf.float32)/255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2)) 16 | h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2)) 17 | h3 = conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2)) 18 | h3 = conv_to_fc(h3) 19 | h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)) 20 | xs = batch_to_seq(h4, nenv, nsteps) 21 | ms = batch_to_seq(M, nenv, nsteps) 22 | h5, snew = lnlstm(xs, ms, S, 'lstm1', nh=nlstm) 23 | h5 = seq_to_batch(h5) 24 | pi = fc(h5, 'pi', nact, act=lambda x:x) 25 | vf = fc(h5, 'v', 1, act=lambda x:x) 26 | 27 | v0 = vf[:, 0] 28 | a0 = sample(pi) 29 | self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32) 30 | 31 | def step(ob, state, mask): 32 | a, v, s = sess.run([a0, v0, snew], {X:ob, S:state, M:mask}) 33 | return a, v, s 34 | 35 | def value(ob, state, mask): 36 | return sess.run(v0, {X:ob, S:state, M:mask}) 37 | 38 | self.X = X 39 | 
self.M = M 40 | self.S = S 41 | self.pi = pi 42 | self.vf = vf 43 | self.step = step 44 | self.value = value 45 | 46 | class LstmPolicy(object): 47 | 48 | def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, nlstm=256, reuse=False): 49 | nbatch = nenv*nsteps 50 | nh, nw, nc = ob_space.shape 51 | ob_shape = (nbatch, nh, nw, nc*nstack) 52 | nact = ac_space.n 53 | X = tf.placeholder(tf.uint8, ob_shape) #obs 54 | M = tf.placeholder(tf.float32, [nbatch]) #mask (done t-1) 55 | S = tf.placeholder(tf.float32, [nenv, nlstm*2]) #states 56 | with tf.variable_scope("model", reuse=reuse): 57 | h = conv(tf.cast(X, tf.float32)/255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2)) 58 | h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2)) 59 | h3 = conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2)) 60 | h3 = conv_to_fc(h3) 61 | h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)) 62 | xs = batch_to_seq(h4, nenv, nsteps) 63 | ms = batch_to_seq(M, nenv, nsteps) 64 | h5, snew = lstm(xs, ms, S, 'lstm1', nh=nlstm) 65 | h5 = seq_to_batch(h5) 66 | pi = fc(h5, 'pi', nact, act=lambda x:x) 67 | vf = fc(h5, 'v', 1, act=lambda x:x) 68 | 69 | v0 = vf[:, 0] 70 | a0 = sample(pi) 71 | self.initial_state = np.zeros((nenv, nlstm*2), dtype=np.float32) 72 | 73 | def step(ob, state, mask): 74 | a, v, s = sess.run([a0, v0, snew], {X:ob, S:state, M:mask}) 75 | return a, v, s 76 | 77 | def value(ob, state, mask): 78 | return sess.run(v0, {X:ob, S:state, M:mask}) 79 | 80 | self.X = X 81 | self.M = M 82 | self.S = S 83 | self.pi = pi 84 | self.vf = vf 85 | self.step = step 86 | self.value = value 87 | 88 | class CnnPolicy(object): 89 | 90 | def __init__(self, sess, ob_space, ac_space, nenv, nsteps, nstack, reuse=False): 91 | nbatch = nenv*nsteps 92 | nh, nw, nc = ob_space.shape 93 | ob_shape = (nbatch, nh, nw, nc*nstack) 94 | nact = ac_space.n 95 | X = tf.placeholder(tf.uint8, ob_shape) #obs 96 | with tf.variable_scope("model", reuse=reuse): 97 | h = conv(tf.cast(X, tf.float32)/255., 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2)) 98 | h2 = conv(h, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2)) 99 | h3 = conv(h2, 'c3', nf=64, rf=3, stride=1, init_scale=np.sqrt(2)) 100 | h3 = conv_to_fc(h3) 101 | h4 = fc(h3, 'fc1', nh=512, init_scale=np.sqrt(2)) 102 | pi = fc(h4, 'pi', nact, act=lambda x:x) 103 | vf = fc(h4, 'v', 1, act=lambda x:x) 104 | 105 | v0 = vf[:, 0] 106 | a0 = sample(pi) 107 | self.initial_state = [] #not stateful 108 | 109 | def step(ob, *_args, **_kwargs): 110 | a, v = sess.run([a0, v0], {X:ob}) 111 | return a, v, [] #dummy state 112 | 113 | def value(ob, *_args, **_kwargs): 114 | return sess.run(v0, {X:ob}) 115 | 116 | self.X = X 117 | self.pi = pi 118 | self.vf = vf 119 | self.step = step 120 | self.value = value 121 | 122 | 123 | class MlpPolicy(object): 124 | 125 | def __init__(self, 126 | sess, 127 | ob_space, 128 | ac_space, 129 | nenv, 130 | nsteps, 131 | nstack, 132 | reuse=False): 133 | nbatch = nenv*nsteps 134 | nh, nw, nc = ob_space.shape 135 | ob_shape = (nbatch, nh, nw, nc*nstack) 136 | nact = ac_space.n 137 | X = tf.placeholder(tf.uint8, ob_shape) # obs 138 | with tf.variable_scope("model", reuse=reuse): 139 | x = tf.cast(X, tf.float32)/255. 
140 | 141 | # Only look at the most recent frame 142 | x = x[:, :, :, -1] 143 | 144 | w, h = x.get_shape()[1:] 145 | x = tf.reshape(x, [-1, int(w * h)]) 146 | x = fc(x, 'fc1', nh=2048, init_scale=np.sqrt(2)) 147 | x = fc(x, 'fc2', nh=1024, init_scale=np.sqrt(2)) 148 | x = fc(x, 'fc3', nh=512, init_scale=np.sqrt(2)) 149 | pi = fc(x, 'pi', nact, act=lambda x: x) 150 | vf = fc(x, 'v', 1, act=lambda x: x) 151 | 152 | v0 = vf[:, 0] 153 | a0 = sample(pi) 154 | self.initial_state = [] # not stateful 155 | 156 | def step(ob, *_args, **_kwargs): 157 | a, v = sess.run([a0, v0], {X: ob}) 158 | return a, v, [] # dummy state 159 | 160 | def value(ob, *_args, **_kwargs): 161 | return sess.run(v0, {X: ob}) 162 | 163 | self.X = X 164 | self.pi = pi 165 | self.vf = vf 166 | self.step = step 167 | self.value = value 168 | -------------------------------------------------------------------------------- /a2c/a2c/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | from collections import deque 5 | 6 | def sample(logits): 7 | noise = tf.random_uniform(tf.shape(logits)) 8 | return tf.argmax(logits - tf.log(-tf.log(noise)), 1) 9 | 10 | def cat_entropy(logits): 11 | a0 = logits - tf.reduce_max(logits, 1, keepdims=True) 12 | ea0 = tf.exp(a0) 13 | z0 = tf.reduce_sum(ea0, 1, keepdims=True) 14 | p0 = ea0 / z0 15 | return tf.reduce_sum(p0 * (tf.log(z0) - a0), 1) 16 | 17 | def cat_entropy_softmax(p0): 18 | return - tf.reduce_sum(p0 * tf.log(p0 + 1e-6), axis = 1) 19 | 20 | def mse(pred, target): 21 | return tf.square(pred-target)/2. 22 | 23 | def ortho_init(scale=1.0): 24 | def _ortho_init(shape, dtype, partition_info=None): 25 | #lasagne ortho init for tf 26 | shape = tuple(shape) 27 | if len(shape) == 2: 28 | flat_shape = shape 29 | elif len(shape) == 4: # assumes NHWC 30 | flat_shape = (np.prod(shape[:-1]), shape[-1]) 31 | else: 32 | raise NotImplementedError 33 | a = np.random.normal(0.0, 1.0, flat_shape) 34 | u, _, v = np.linalg.svd(a, full_matrices=False) 35 | q = u if u.shape == flat_shape else v # pick the one with the correct shape 36 | q = q.reshape(shape) 37 | return (scale * q[:shape[0], :shape[1]]).astype(np.float32) 38 | return _ortho_init 39 | 40 | def conv(x, scope, nf, rf, stride, pad='VALID', act=tf.nn.relu, init_scale=1.0): 41 | with tf.variable_scope(scope): 42 | nin = x.get_shape()[3].value 43 | w = tf.get_variable("w", [rf, rf, nin, nf], initializer=ortho_init(init_scale)) 44 | b = tf.get_variable("b", [nf], initializer=tf.constant_initializer(0.0)) 45 | z = tf.nn.conv2d(x, w, strides=[1, stride, stride, 1], padding=pad)+b 46 | h = act(z) 47 | return h 48 | 49 | def fc(x, scope, nh, act=tf.nn.relu, init_scale=1.0): 50 | with tf.variable_scope(scope): 51 | nin = x.get_shape()[1].value 52 | w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale)) 53 | b = tf.get_variable("b", [nh], initializer=tf.constant_initializer(0.0)) 54 | z = tf.matmul(x, w)+b 55 | h = act(z) 56 | return h 57 | 58 | def batch_to_seq(h, nbatch, nsteps, flat=False): 59 | if flat: 60 | h = tf.reshape(h, [nbatch, nsteps]) 61 | else: 62 | h = tf.reshape(h, [nbatch, nsteps, -1]) 63 | return [tf.squeeze(v, [1]) for v in tf.split(axis=1, num_or_size_splits=nsteps, value=h)] 64 | 65 | def seq_to_batch(h, flat = False): 66 | shape = h[0].get_shape().as_list() 67 | if not flat: 68 | assert(len(shape) > 1) 69 | nh = h[0].get_shape()[-1].value 70 | return tf.reshape(tf.concat(axis=1, values=h), [-1, nh]) 71 | else: 72 | 
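        # flat=True: each step is a 1-D tensor, so stack along the time axis
        # and flatten back into a single batch dimension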
return tf.reshape(tf.stack(values=h, axis=1), [-1]) 73 | 74 | def lstm(xs, ms, s, scope, nh, init_scale=1.0): 75 | nbatch, nin = [v.value for v in xs[0].get_shape()] 76 | nsteps = len(xs) 77 | with tf.variable_scope(scope): 78 | wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale)) 79 | wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale)) 80 | b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0)) 81 | 82 | c, h = tf.split(axis=1, num_or_size_splits=2, value=s) 83 | for idx, (x, m) in enumerate(zip(xs, ms)): 84 | c = c*(1-m) 85 | h = h*(1-m) 86 | z = tf.matmul(x, wx) + tf.matmul(h, wh) + b 87 | i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z) 88 | i = tf.nn.sigmoid(i) 89 | f = tf.nn.sigmoid(f) 90 | o = tf.nn.sigmoid(o) 91 | u = tf.tanh(u) 92 | c = f*c + i*u 93 | h = o*tf.tanh(c) 94 | xs[idx] = h 95 | s = tf.concat(axis=1, values=[c, h]) 96 | return xs, s 97 | 98 | def _ln(x, g, b, e=1e-5, axes=[1]): 99 | u, s = tf.nn.moments(x, axes=axes, keep_dims=True) 100 | x = (x-u)/tf.sqrt(s+e) 101 | x = x*g+b 102 | return x 103 | 104 | def lnlstm(xs, ms, s, scope, nh, init_scale=1.0): 105 | nbatch, nin = [v.value for v in xs[0].get_shape()] 106 | nsteps = len(xs) 107 | with tf.variable_scope(scope): 108 | wx = tf.get_variable("wx", [nin, nh*4], initializer=ortho_init(init_scale)) 109 | gx = tf.get_variable("gx", [nh*4], initializer=tf.constant_initializer(1.0)) 110 | bx = tf.get_variable("bx", [nh*4], initializer=tf.constant_initializer(0.0)) 111 | 112 | wh = tf.get_variable("wh", [nh, nh*4], initializer=ortho_init(init_scale)) 113 | gh = tf.get_variable("gh", [nh*4], initializer=tf.constant_initializer(1.0)) 114 | bh = tf.get_variable("bh", [nh*4], initializer=tf.constant_initializer(0.0)) 115 | 116 | b = tf.get_variable("b", [nh*4], initializer=tf.constant_initializer(0.0)) 117 | 118 | gc = tf.get_variable("gc", [nh], initializer=tf.constant_initializer(1.0)) 119 | bc = tf.get_variable("bc", [nh], initializer=tf.constant_initializer(0.0)) 120 | 121 | c, h = tf.split(axis=1, num_or_size_splits=2, value=s) 122 | for idx, (x, m) in enumerate(zip(xs, ms)): 123 | c = c*(1-m) 124 | h = h*(1-m) 125 | z = _ln(tf.matmul(x, wx), gx, bx) + _ln(tf.matmul(h, wh), gh, bh) + b 126 | i, f, o, u = tf.split(axis=1, num_or_size_splits=4, value=z) 127 | i = tf.nn.sigmoid(i) 128 | f = tf.nn.sigmoid(f) 129 | o = tf.nn.sigmoid(o) 130 | u = tf.tanh(u) 131 | c = f*c + i*u 132 | h = o*tf.tanh(_ln(c, gc, bc)) 133 | xs[idx] = h 134 | s = tf.concat(axis=1, values=[c, h]) 135 | return xs, s 136 | 137 | def conv_to_fc(x): 138 | nh = np.prod([v.value for v in x.get_shape()[1:]]) 139 | x = tf.reshape(x, [-1, nh]) 140 | return x 141 | 142 | def discount_with_dones(rewards, dones, gamma): 143 | discounted = [] 144 | r = 0 145 | for reward, done in zip(rewards[::-1], dones[::-1]): 146 | r = reward + gamma * r * (1. - done) # fixed off by one bug 147 | discounted.append(r) 148 | return discounted[::-1] 149 | 150 | def find_trainable_variables(key): 151 | with tf.variable_scope(key): 152 | return tf.trainable_variables() 153 | 154 | def make_path(f): 155 | return os.makedirs(f, exist_ok=True) 156 | 157 | def constant(p): 158 | return 1 159 | 160 | def linear(p): 161 | return 1-p 162 | 163 | schedules = { 164 | 'linear':linear, 165 | 'constant':constant 166 | } 167 | 168 | class Scheduler(object): 169 | 170 | def __init__(self, v, nvalues, schedule): 171 | self.n = 0. 
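        # n counts how many times value() has been called; the schedule is
        # evaluated at the fraction n / nvalues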
172 | self.v = v 173 | self.nvalues = nvalues 174 | self.schedule = schedules[schedule] 175 | 176 | def value(self): 177 | current_value = self.v*self.schedule(self.n/self.nvalues) 178 | self.n += 1. 179 | return current_value 180 | 181 | def value_steps(self, steps): 182 | return self.v*self.schedule(steps/self.nvalues) 183 | 184 | 185 | class EpisodeStats: 186 | def __init__(self, nsteps, nenvs): 187 | self.episode_rewards = [] 188 | for i in range(nenvs): 189 | self.episode_rewards.append([]) 190 | self.lenbuffer = deque(maxlen=40) # rolling buffer for episode lengths 191 | self.rewbuffer = deque(maxlen=40) # rolling buffer for episode rewards 192 | self.nsteps = nsteps 193 | self.nenvs = nenvs 194 | 195 | def feed(self, rewards, masks): 196 | rewards = np.reshape(rewards, [self.nenvs, self.nsteps]) 197 | masks = np.reshape(masks, [self.nenvs, self.nsteps]) 198 | for i in range(0, self.nenvs): 199 | for j in range(0, self.nsteps): 200 | self.episode_rewards[i].append(rewards[i][j]) 201 | if masks[i][j]: 202 | l = len(self.episode_rewards[i]) 203 | s = sum(self.episode_rewards[i]) 204 | self.lenbuffer.append(l) 205 | self.rewbuffer.append(s) 206 | self.episode_rewards[i] = [] 207 | 208 | def mean_length(self): 209 | if self.lenbuffer: 210 | return np.mean(self.lenbuffer) 211 | else: 212 | return 0 # on the first params dump, no episodes are finished 213 | 214 | def mean_reward(self): 215 | if self.rewbuffer: 216 | return np.mean(self.rewbuffer) 217 | else: 218 | return 0 219 | 220 | 221 | # For ACER 222 | def get_by_index(x, idx): 223 | assert(len(x.get_shape()) == 2) 224 | assert(len(idx.get_shape()) == 1) 225 | idx_flattened = tf.range(0, x.shape[0]) * x.shape[1] + idx 226 | y = tf.gather(tf.reshape(x, [-1]), # flatten input 227 | idx_flattened) # use flattened indices 228 | return y 229 | 230 | def check_shape(ts,shapes): 231 | i = 0 232 | for (t,shape) in zip(ts,shapes): 233 | assert t.get_shape().as_list()==shape, "id " + str(i) + " shape " + str(t.get_shape()) + str(shape) 234 | i += 1 235 | 236 | def avg_norm(t): 237 | return tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(t), axis=-1))) 238 | 239 | def myadd(g1, g2, param): 240 | print([g1, g2, param.name]) 241 | assert (not (g1 is None and g2 is None)), param.name 242 | if g1 is None: 243 | return g2 244 | elif g2 is None: 245 | return g1 246 | else: 247 | return g1 + g2 248 | 249 | def my_explained_variance(qpred, q): 250 | _, vary = tf.nn.moments(q, axes=[0, 1]) 251 | _, varpred = tf.nn.moments(q - qpred, axes=[0, 1]) 252 | check_shape([vary, varpred], [[]] * 2) 253 | return 1.0 - (varpred / vary) 254 | -------------------------------------------------------------------------------- /a2c/common/__init__.py: -------------------------------------------------------------------------------- 1 | from a2c.common.math_util import * 2 | from a2c.common.misc_util import * 3 | -------------------------------------------------------------------------------- /a2c/common/atari_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import deque 3 | from PIL import Image 4 | import gym 5 | from gym import spaces 6 | 7 | 8 | class NoopResetEnv(gym.Wrapper): 9 | def __init__(self, env, noop_max=30): 10 | """Sample initial states by taking random number of no-ops on reset. 11 | No-op is assumed to be action 0. 
12 | """ 13 | gym.Wrapper.__init__(self, env) 14 | self.noop_max = noop_max 15 | self.override_num_noops = None 16 | assert env.unwrapped.get_action_meanings()[0] == 'NOOP' 17 | 18 | def _reset(self): 19 | """ Do no-op action for a number of steps in [1, noop_max].""" 20 | self.env.reset() 21 | if self.override_num_noops is not None: 22 | noops = self.override_num_noops 23 | else: 24 | noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) #pylint: disable=E1101 25 | assert noops > 0 26 | obs = None 27 | for _ in range(noops): 28 | obs, _, done, _ = self.env.step(0) 29 | if done: 30 | obs = self.env.reset() 31 | return obs 32 | 33 | class FireResetEnv(gym.Wrapper): 34 | def __init__(self, env): 35 | """Take action on reset for environments that are fixed until firing.""" 36 | gym.Wrapper.__init__(self, env) 37 | assert env.unwrapped.get_action_meanings()[1] == 'FIRE' 38 | assert len(env.unwrapped.get_action_meanings()) >= 3 39 | 40 | def _reset(self): 41 | self.env.reset() 42 | obs, _, done, _ = self.env.step(1) 43 | if done: 44 | self.env.reset() 45 | obs, _, done, _ = self.env.step(2) 46 | if done: 47 | self.env.reset() 48 | return obs 49 | 50 | class EpisodicLifeEnv(gym.Wrapper): 51 | def __init__(self, env): 52 | """Make end-of-life == end-of-episode, but only reset on true game over. 53 | Done by DeepMind for the DQN and co. since it helps value estimation. 54 | """ 55 | gym.Wrapper.__init__(self, env) 56 | self.lives = 0 57 | self.was_real_done = True 58 | 59 | def _step(self, action): 60 | obs, reward, done, info = self.env.step(action) 61 | self.was_real_done = done 62 | # check current lives, make loss of life terminal, 63 | # then update lives to handle bonus lives 64 | lives = self.env.unwrapped.ale.lives() 65 | if lives < self.lives and lives > 0: 66 | # for Qbert somtimes we stay in lives == 0 condtion for a few frames 67 | # so its important to keep lives > 0, so that we only reset once 68 | # the environment advertises done. 69 | done = True 70 | self.lives = lives 71 | return obs, reward, done, info 72 | 73 | def _reset(self): 74 | """Reset only when lives are exhausted. 75 | This way all states are still reachable even though lives are episodic, 76 | and the learner need not know about any of this behind-the-scenes. 77 | """ 78 | if self.was_real_done: 79 | obs = self.env.reset() 80 | else: 81 | # no-op step to advance from terminal/lost life state 82 | obs, _, _, _ = self.env.step(0) 83 | self.lives = self.env.unwrapped.ale.lives() 84 | return obs 85 | 86 | class MaxAndSkipEnv(gym.Wrapper): 87 | def __init__(self, env, skip=4): 88 | """Return only every `skip`-th frame""" 89 | gym.Wrapper.__init__(self, env) 90 | # most recent raw observations (for max pooling across time steps) 91 | self._obs_buffer = deque(maxlen=2) 92 | self._skip = skip 93 | 94 | def _step(self, action): 95 | """Repeat action, sum reward, and max over last observations.""" 96 | total_reward = 0.0 97 | done = None 98 | for _ in range(self._skip): 99 | obs, reward, done, info = self.env.step(action) 100 | self._obs_buffer.append(obs) 101 | total_reward += reward 102 | if done: 103 | break 104 | max_frame = np.max(np.stack(self._obs_buffer), axis=0) 105 | 106 | return max_frame, total_reward, done, info 107 | 108 | def _reset(self): 109 | """Clear past frame buffer and init. to first obs. 
from inner env.""" 110 | self._obs_buffer.clear() 111 | obs = self.env.reset() 112 | self._obs_buffer.append(obs) 113 | return obs 114 | 115 | class ClipRewardEnv(gym.RewardWrapper): 116 | def _reward(self, reward): 117 | """Bin reward to {+1, 0, -1} by its sign.""" 118 | return np.sign(reward) 119 | 120 | class WarpFrame(gym.ObservationWrapper): 121 | def __init__(self, env): 122 | """Warp frames to 84x84 as done in the Nature paper and later work.""" 123 | gym.ObservationWrapper.__init__(self, env) 124 | self.res = 84 125 | self.observation_space = spaces.Box(low=0, high=255, shape=(self.res, self.res, 1)) 126 | 127 | def _observation(self, obs): 128 | frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32')) 129 | frame = np.array(Image.fromarray(frame).resize((self.res, self.res), 130 | resample=Image.BILINEAR), dtype=np.uint8) 131 | return frame.reshape((self.res, self.res, 1)) 132 | 133 | class FrameStack(gym.Wrapper): 134 | def __init__(self, env, k): 135 | """Buffer observations and stack across channels (last axis).""" 136 | gym.Wrapper.__init__(self, env) 137 | self.k = k 138 | self.frames = deque([], maxlen=k) 139 | shp = env.observation_space.shape 140 | assert shp[2] == 1 # can only stack 1-channel frames 141 | self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k)) 142 | 143 | def _reset(self): 144 | """Clear buffer and re-fill by duplicating the first observation.""" 145 | ob = self.env.reset() 146 | for _ in range(self.k): self.frames.append(ob) 147 | return self._observation() 148 | 149 | def _step(self, action): 150 | ob, reward, done, info = self.env.step(action) 151 | self.frames.append(ob) 152 | return self._observation(), reward, done, info 153 | 154 | def _observation(self): 155 | assert len(self.frames) == self.k 156 | return np.concatenate(self.frames, axis=2) 157 | 158 | def wrap_deepmind(env, episode_life=True, clip_rewards=True): 159 | """Configure environment for DeepMind-style Atari. 160 | 161 | Note: this does not include frame stacking!""" 162 | assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip 163 | if episode_life: 164 | env = EpisodicLifeEnv(env) 165 | env = NoopResetEnv(env, noop_max=30) 166 | env = MaxAndSkipEnv(env, skip=4) 167 | if 'FIRE' in env.unwrapped.get_action_meanings(): 168 | env = FireResetEnv(env) 169 | env = WarpFrame(env) 170 | if clip_rewards: 171 | env = ClipRewardEnv(env) 172 | return env 173 | -------------------------------------------------------------------------------- /a2c/common/math_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.signal 3 | 4 | 5 | def discount(x, gamma): 6 | """ 7 | computes discounted sums along 0th dimension of x. 8 | 9 | inputs 10 | ------ 11 | x: ndarray 12 | gamma: float 13 | 14 | outputs 15 | ------- 16 | y: ndarray with same shape as x, satisfying 17 | 18 | y[t] = x[t] + gamma*x[t+1] + gamma^2*x[t+2] + ... + gamma^k x[t+k], 19 | where k = len(x) - t - 1 20 | 21 | """ 22 | assert x.ndim >= 1 23 | return scipy.signal.lfilter([1],[1,-gamma],x[::-1], axis=0)[::-1] 24 | 25 | def explained_variance(ypred,y): 26 | """ 27 | Computes fraction of variance that ypred explains about y. 
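    Both ypred and y are 1-D arrays; in a2c.py this is called as
    explained_variance(values, rewards).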
28 | Returns 1 - Var[y-ypred] / Var[y] 29 | 30 | interpretation: 31 | ev=0 => might as well have predicted zero 32 | ev=1 => perfect prediction 33 | ev<0 => worse than just predicting zero 34 | 35 | """ 36 | assert y.ndim == 1 and ypred.ndim == 1 37 | vary = np.var(y) 38 | return np.nan if vary==0 else 1 - np.var(y-ypred)/vary 39 | 40 | def explained_variance_2d(ypred, y): 41 | assert y.ndim == 2 and ypred.ndim == 2 42 | vary = np.var(y, axis=0) 43 | out = 1 - np.var(y-ypred)/vary 44 | out[vary < 1e-10] = 0 45 | return out 46 | 47 | def ncc(ypred, y): 48 | return np.corrcoef(ypred, y)[1,0] 49 | 50 | def flatten_arrays(arrs): 51 | return np.concatenate([arr.flat for arr in arrs]) 52 | 53 | def unflatten_vector(vec, shapes): 54 | i=0 55 | arrs = [] 56 | for shape in shapes: 57 | size = np.prod(shape) 58 | arr = vec[i:i+size].reshape(shape) 59 | arrs.append(arr) 60 | i += size 61 | return arrs 62 | 63 | def discount_with_boundaries(X, New, gamma): 64 | """ 65 | X: 2d array of floats, time x features 66 | New: 2d array of bools, indicating when a new episode has started 67 | """ 68 | Y = np.zeros_like(X) 69 | T = X.shape[0] 70 | Y[T-1] = X[T-1] 71 | for t in range(T-2, -1, -1): 72 | Y[t] = X[t] + gamma * Y[t+1] * (1 - New[t+1]) 73 | return Y 74 | 75 | def test_discount_with_boundaries(): 76 | gamma=0.9 77 | x = np.array([1.0, 2.0, 3.0, 4.0], 'float32') 78 | starts = [1.0, 0.0, 0.0, 1.0] 79 | y = discount_with_boundaries(x, starts, gamma) 80 | assert np.allclose(y, [ 81 | 1 + gamma * 2 + gamma**2 * 3, 82 | 2 + gamma * 3, 83 | 3, 84 | 4 85 | ]) 86 | -------------------------------------------------------------------------------- /a2c/common/misc_util.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import os 4 | import pickle 5 | import random 6 | import tempfile 7 | import time 8 | import zipfile 9 | 10 | 11 | def zipsame(*seqs): 12 | L = len(seqs[0]) 13 | assert all(len(seq) == L for seq in seqs[1:]) 14 | return zip(*seqs) 15 | 16 | 17 | def unpack(seq, sizes): 18 | """ 19 | Unpack 'seq' into a sequence of lists, with lengths specified by 'sizes'. 20 | None = just one bare element, not a list 21 | 22 | Example: 23 | unpack([1,2,3,4,5,6], [3,None,2]) -> ([1,2,3], 4, [5,6]) 24 | """ 25 | seq = list(seq) 26 | it = iter(seq) 27 | assert sum(1 if s is None else s for s in sizes) == len(seq), "Trying to unpack %s into %s" % (seq, sizes) 28 | for size in sizes: 29 | if size is None: 30 | yield it.__next__() 31 | else: 32 | li = [] 33 | for _ in range(size): 34 | li.append(it.__next__()) 35 | yield li 36 | 37 | 38 | class EzPickle(object): 39 | """Objects that are pickled and unpickled via their constructor 40 | arguments. 41 | 42 | Example usage: 43 | 44 | class Dog(Animal, EzPickle): 45 | def __init__(self, furcolor, tailkind="bushy"): 46 | Animal.__init__() 47 | EzPickle.__init__(furcolor, tailkind) 48 | ... 49 | 50 | When this object is unpickled, a new Dog will be constructed by passing the provided 51 | furcolor and tailkind into the constructor. However, philosophers are still not sure 52 | whether it is still the same dog. 53 | 54 | This is generally needed only for environments which wrap C/C++ code, such as MuJoCo 55 | and Atari. 
56 | """ 57 | 58 | def __init__(self, *args, **kwargs): 59 | self._ezpickle_args = args 60 | self._ezpickle_kwargs = kwargs 61 | 62 | def __getstate__(self): 63 | return {"_ezpickle_args": self._ezpickle_args, "_ezpickle_kwargs": self._ezpickle_kwargs} 64 | 65 | def __setstate__(self, d): 66 | out = type(self)(*d["_ezpickle_args"], **d["_ezpickle_kwargs"]) 67 | self.__dict__.update(out.__dict__) 68 | 69 | 70 | def set_global_seeds(i): 71 | try: 72 | import tensorflow as tf 73 | except ImportError: 74 | pass 75 | else: 76 | tf.set_random_seed(i) 77 | np.random.seed(i) 78 | random.seed(i) 79 | 80 | 81 | def pretty_eta(seconds_left): 82 | """Print the number of seconds in human readable format. 83 | 84 | Examples: 85 | 2 days 86 | 2 hours and 37 minutes 87 | less than a minute 88 | 89 | Paramters 90 | --------- 91 | seconds_left: int 92 | Number of seconds to be converted to the ETA 93 | Returns 94 | ------- 95 | eta: str 96 | String representing the pretty ETA. 97 | """ 98 | minutes_left = seconds_left // 60 99 | seconds_left %= 60 100 | hours_left = minutes_left // 60 101 | minutes_left %= 60 102 | days_left = hours_left // 24 103 | hours_left %= 24 104 | 105 | def helper(cnt, name): 106 | return "{} {}{}".format(str(cnt), name, ('s' if cnt > 1 else '')) 107 | 108 | if days_left > 0: 109 | msg = helper(days_left, 'day') 110 | if hours_left > 0: 111 | msg += ' and ' + helper(hours_left, 'hour') 112 | return msg 113 | if hours_left > 0: 114 | msg = helper(hours_left, 'hour') 115 | if minutes_left > 0: 116 | msg += ' and ' + helper(minutes_left, 'minute') 117 | return msg 118 | if minutes_left > 0: 119 | return helper(minutes_left, 'minute') 120 | return 'less than a minute' 121 | 122 | 123 | class RunningAvg(object): 124 | def __init__(self, gamma, init_value=None): 125 | """Keep a running estimate of a quantity. This is a bit like mean 126 | but more sensitive to recent changes. 127 | 128 | Parameters 129 | ---------- 130 | gamma: float 131 | Must be between 0 and 1, where 0 is the most sensitive to recent 132 | changes. 133 | init_value: float or None 134 | Initial value of the estimate. If None, it will be set on the first update. 135 | """ 136 | self._value = init_value 137 | self._gamma = gamma 138 | 139 | def update(self, new_val): 140 | """Update the estimate. 141 | 142 | Parameters 143 | ---------- 144 | new_val: float 145 | new observated value of estimated quantity. 146 | """ 147 | if self._value is None: 148 | self._value = new_val 149 | else: 150 | self._value = self._gamma * self._value + (1.0 - self._gamma) * new_val 151 | 152 | def __float__(self): 153 | """Get the current estimate""" 154 | return self._value 155 | 156 | 157 | class SimpleMonitor(gym.Wrapper): 158 | def __init__(self, env): 159 | """Adds two qunatities to info returned by every step: 160 | 161 | num_steps: int 162 | Number of steps takes so far 163 | rewards: [float] 164 | All the cumulative rewards for the episodes completed so far. 
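        _step() exposes these via info['steps'] and info['rewards'].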
165 | """ 166 | super().__init__(env) 167 | # current episode state 168 | self._current_reward = None 169 | self._num_steps = None 170 | # temporary monitor state that we do not save 171 | self._time_offset = None 172 | self._total_steps = None 173 | # monitor state 174 | self._episode_rewards = [] 175 | self._episode_lengths = [] 176 | self._episode_end_times = [] 177 | 178 | def _reset(self): 179 | obs = self.env.reset() 180 | # recompute temporary state if needed 181 | if self._time_offset is None: 182 | self._time_offset = time.time() 183 | if len(self._episode_end_times) > 0: 184 | self._time_offset -= self._episode_end_times[-1] 185 | if self._total_steps is None: 186 | self._total_steps = sum(self._episode_lengths) 187 | # update monitor state 188 | if self._current_reward is not None: 189 | self._episode_rewards.append(self._current_reward) 190 | self._episode_lengths.append(self._num_steps) 191 | self._episode_end_times.append(time.time() - self._time_offset) 192 | # reset episode state 193 | self._current_reward = 0 194 | self._num_steps = 0 195 | 196 | return obs 197 | 198 | def _step(self, action): 199 | obs, rew, done, info = self.env.step(action) 200 | self._current_reward += rew 201 | self._num_steps += 1 202 | self._total_steps += 1 203 | info['steps'] = self._total_steps 204 | info['rewards'] = self._episode_rewards 205 | return (obs, rew, done, info) 206 | 207 | def get_state(self): 208 | return { 209 | 'env_id': self.env.unwrapped.spec.id, 210 | 'episode_data': { 211 | 'episode_rewards': self._episode_rewards, 212 | 'episode_lengths': self._episode_lengths, 213 | 'episode_end_times': self._episode_end_times, 214 | 'initial_reset_time': 0, 215 | } 216 | } 217 | 218 | def set_state(self, state): 219 | assert state['env_id'] == self.env.unwrapped.spec.id 220 | ed = state['episode_data'] 221 | self._episode_rewards = ed['episode_rewards'] 222 | self._episode_lengths = ed['episode_lengths'] 223 | self._episode_end_times = ed['episode_end_times'] 224 | 225 | 226 | def boolean_flag(parser, name, default=False, help=None): 227 | """Add a boolean flag to argparse parser. 
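
    For example, boolean_flag(parser, 'render') adds both --render (which sets
    render to True) and --no-render (which sets it to False).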
228 | 229 | Parameters 230 | ---------- 231 | parser: argparse.Parser 232 | parser to add the flag to 233 | name: str 234 | -- will enable the flag, while --no- will disable it 235 | default: bool or None 236 | default value of the flag 237 | help: str 238 | help string for the flag 239 | """ 240 | dest = name.replace('-', '_') 241 | parser.add_argument("--" + name, action="store_true", default=default, dest=dest, help=help) 242 | parser.add_argument("--no-" + name, action="store_false", dest=dest) 243 | 244 | 245 | def get_wrapper_by_name(env, classname): 246 | """Given an a gym environment possibly wrapped multiple times, returns a wrapper 247 | of class named classname or raises ValueError if no such wrapper was applied 248 | 249 | Parameters 250 | ---------- 251 | env: gym.Env of gym.Wrapper 252 | gym environment 253 | classname: str 254 | name of the wrapper 255 | 256 | Returns 257 | ------- 258 | wrapper: gym.Wrapper 259 | wrapper named classname 260 | """ 261 | currentenv = env 262 | while True: 263 | if classname == currentenv.class_name(): 264 | return currentenv 265 | elif isinstance(currentenv, gym.Wrapper): 266 | currentenv = currentenv.env 267 | else: 268 | raise ValueError("Couldn't find wrapper named %s" % classname) 269 | 270 | 271 | def relatively_safe_pickle_dump(obj, path, compression=False): 272 | """This is just like regular pickle dump, except from the fact that failure cases are 273 | different: 274 | 275 | - It's never possible that we end up with a pickle in corrupted state. 276 | - If a there was a different file at the path, that file will remain unchanged in the 277 | even of failure (provided that filesystem rename is atomic). 278 | - it is sometimes possible that we end up with useless temp file which needs to be 279 | deleted manually (it will be removed automatically on the next function call) 280 | 281 | The indended use case is periodic checkpoints of experiment state, such that we never 282 | corrupt previous checkpoints if the current one fails. 283 | 284 | Parameters 285 | ---------- 286 | obj: object 287 | object to pickle 288 | path: str 289 | path to the output file 290 | compression: bool 291 | if true pickle will be compressed 292 | """ 293 | temp_storage = path + ".relatively_safe" 294 | if compression: 295 | # Using gzip here would be simpler, but the size is limited to 2GB 296 | with tempfile.NamedTemporaryFile() as uncompressed_file: 297 | pickle.dump(obj, uncompressed_file) 298 | with zipfile.ZipFile(temp_storage, "w", compression=zipfile.ZIP_DEFLATED) as myzip: 299 | myzip.write(uncompressed_file.name, "data") 300 | else: 301 | with open(temp_storage, "wb") as f: 302 | pickle.dump(obj, f) 303 | os.rename(temp_storage, path) 304 | 305 | 306 | def pickle_load(path, compression=False): 307 | """Unpickle a possible compressed pickle. 308 | 309 | Parameters 310 | ---------- 311 | path: str 312 | path to the output file 313 | compression: bool 314 | if true assumes that pickle was compressed when created and attempts decompression. 
315 | 316 | Returns 317 | ------- 318 | obj: object 319 | the unpickled object 320 | """ 321 | 322 | if compression: 323 | with zipfile.ZipFile(path, "r", compression=zipfile.ZIP_DEFLATED) as myzip: 324 | with myzip.open("data") as f: 325 | return pickle.load(f) 326 | else: 327 | with open(path, "rb") as f: 328 | return pickle.load(f) 329 | -------------------------------------------------------------------------------- /a2c/common/schedules.py: -------------------------------------------------------------------------------- 1 | """This file is used for specifying various schedules that evolve over 2 | time throughout the execution of the algorithm, such as: 3 | - learning rate for the optimizer 4 | - exploration epsilon for the epsilon greedy exploration strategy 5 | - beta parameter for beta parameter in prioritized replay 6 | 7 | Each schedule has a function `value(t)` which returns the current value 8 | of the parameter given the timestep t of the optimization procedure. 9 | """ 10 | 11 | 12 | class Schedule(object): 13 | def value(self, t): 14 | """Value of the schedule at time t""" 15 | raise NotImplementedError() 16 | 17 | 18 | class ConstantSchedule(object): 19 | def __init__(self, value): 20 | """Value remains constant over time. 21 | 22 | Parameters 23 | ---------- 24 | value: float 25 | Constant value of the schedule 26 | """ 27 | self._v = value 28 | 29 | def value(self, t): 30 | """See Schedule.value""" 31 | return self._v 32 | 33 | 34 | def linear_interpolation(l, r, alpha): 35 | return l + alpha * (r - l) 36 | 37 | 38 | class PiecewiseSchedule(object): 39 | def __init__(self, endpoints, interpolation=linear_interpolation, outside_value=None): 40 | """Piecewise schedule. 41 | 42 | endpoints: [(int, int)] 43 | list of pairs `(time, value)` meanining that schedule should output 44 | `value` when `t==time`. All the values for time must be sorted in 45 | an increasing order. When t is between two times, e.g. `(time_a, value_a)` 46 | and `(time_b, value_b)`, such that `time_a <= t < time_b` then value outputs 47 | `interpolation(value_a, value_b, alpha)` where alpha is a fraction of 48 | time passed between `time_a` and `time_b` for time `t`. 49 | interpolation: lambda float, float, float: float 50 | a function that takes value to the left and to the right of t according 51 | to the `endpoints`. Alpha is the fraction of distance from left endpoint to 52 | right endpoint that t has covered. See linear_interpolation for example. 53 | outside_value: float 54 | if the value is requested outside of all the intervals sepecified in 55 | `endpoints` this value is returned. If None then AssertionError is 56 | raised when outside value is requested. 57 | """ 58 | idxes = [e[0] for e in endpoints] 59 | assert idxes == sorted(idxes) 60 | self._interpolation = interpolation 61 | self._outside_value = outside_value 62 | self._endpoints = endpoints 63 | 64 | def value(self, t): 65 | """See Schedule.value""" 66 | for (l_t, l), (r_t, r) in zip(self._endpoints[:-1], self._endpoints[1:]): 67 | if l_t <= t and t < r_t: 68 | alpha = float(t - l_t) / (r_t - l_t) 69 | return self._interpolation(l, r, alpha) 70 | 71 | # t does not belong to any of the pieces, so doom. 72 | assert self._outside_value is not None 73 | return self._outside_value 74 | 75 | 76 | class LinearSchedule(object): 77 | def __init__(self, schedule_timesteps, final_p, initial_p=1.0): 78 | """Linear interpolation between initial_p and final_p over 79 | schedule_timesteps. After this many timesteps pass final_p is 80 | returned. 
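        For example, LinearSchedule(schedule_timesteps=1000, initial_p=1.0,
        final_p=0.1).value(500) returns 0.55, and value(2000) returns 0.1.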
81 | 82 | Parameters 83 | ---------- 84 | schedule_timesteps: int 85 | Number of timesteps for which to linearly anneal initial_p 86 | to final_p 87 | initial_p: float 88 | initial output value 89 | final_p: float 90 | final output value 91 | """ 92 | self.schedule_timesteps = schedule_timesteps 93 | self.final_p = final_p 94 | self.initial_p = initial_p 95 | 96 | def value(self, t): 97 | """See Schedule.value""" 98 | fraction = min(float(t) / self.schedule_timesteps, 1.0) 99 | return self.initial_p + fraction * (self.final_p - self.initial_p) 100 | -------------------------------------------------------------------------------- /a2c/common/tests/test_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from a2c.common.schedules import ConstantSchedule, PiecewiseSchedule 4 | 5 | 6 | def test_piecewise_schedule(): 7 | ps = PiecewiseSchedule([(-5, 100), (5, 200), (10, 50), (100, 50), (200, -50)], outside_value=500) 8 | 9 | assert np.isclose(ps.value(-10), 500) 10 | assert np.isclose(ps.value(0), 150) 11 | assert np.isclose(ps.value(5), 200) 12 | assert np.isclose(ps.value(9), 80) 13 | assert np.isclose(ps.value(50), 50) 14 | assert np.isclose(ps.value(80), 50) 15 | assert np.isclose(ps.value(150), 0) 16 | assert np.isclose(ps.value(175), -25) 17 | assert np.isclose(ps.value(201), 500) 18 | assert np.isclose(ps.value(500), 500) 19 | 20 | assert np.isclose(ps.value(200 - 1e-10), -50) 21 | 22 | 23 | def test_constant_schedule(): 24 | cs = ConstantSchedule(5) 25 | for i in range(-100, 100): 26 | assert np.isclose(cs.value(i), 5) 27 | -------------------------------------------------------------------------------- /a2c/common/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | class VecEnv(object): 2 | """ 3 | Vectorized environment base class 4 | """ 5 | def step(self, vac): 6 | """ 7 | Apply sequence of actions to sequence of environments 8 | actions -> (observations, rewards, news) 9 | 10 | where 'news' is a boolean vector indicating whether each element is new. 
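        Each returned value has one entry per environment; see
        SubprocVecEnv.step for a concrete implementation.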
11 | """ 12 | raise NotImplementedError 13 | def reset(self): 14 | """ 15 | Reset all environments 16 | """ 17 | raise NotImplementedError 18 | def close(self): 19 | pass 20 | -------------------------------------------------------------------------------- /a2c/common/vec_env/subproc_vec_env.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from multiprocessing import Process, Pipe 3 | from a2c.common.vec_env import VecEnv 4 | 5 | def worker(remote, env_fn_wrapper): 6 | env = env_fn_wrapper.x() 7 | while True: 8 | cmd, data = remote.recv() 9 | if cmd == 'step': 10 | ob, reward, done, info = env.step(data) 11 | if done: 12 | ob = env.reset() 13 | remote.send((ob, reward, done, info)) 14 | elif cmd == 'reset': 15 | ob = env.reset() 16 | remote.send(ob) 17 | elif cmd == 'close': 18 | remote.close() 19 | break 20 | elif cmd == 'get_spaces': 21 | remote.send((env.action_space, env.observation_space)) 22 | elif cmd == 'get_action_meanings': 23 | remote.send(env.unwrapped.get_action_meanings()) 24 | else: 25 | raise NotImplementedError 26 | 27 | class CloudpickleWrapper(object): 28 | """ 29 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 30 | """ 31 | def __init__(self, x): 32 | self.x = x 33 | def __getstate__(self): 34 | import cloudpickle 35 | return cloudpickle.dumps(self.x) 36 | def __setstate__(self, ob): 37 | import pickle 38 | self.x = pickle.loads(ob) 39 | 40 | class SubprocVecEnv(VecEnv): 41 | def __init__(self, env_id, env_fns): 42 | """ 43 | envs: list of gym environments to run in subprocesses 44 | """ 45 | nenvs = len(env_fns) 46 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 47 | self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn))) 48 | for (work_remote, env_fn) in zip(self.work_remotes, env_fns)] 49 | for p in self.ps: 50 | p.start() 51 | 52 | self.remotes[0].send(('get_spaces', None)) 53 | self.action_space, self.observation_space = self.remotes[0].recv() 54 | 55 | self.remotes[0].send(('get_action_meanings', None)) 56 | self.action_meanings = self.remotes[0].recv() 57 | 58 | self.env_id = env_id 59 | 60 | 61 | def step(self, actions): 62 | for remote, action in zip(self.remotes, actions): 63 | remote.send(('step', action)) 64 | results = [remote.recv() for remote in self.remotes] 65 | obs, rews, dones, infos = zip(*results) 66 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 67 | 68 | def reset(self): 69 | for remote in self.remotes: 70 | remote.send(('reset', None)) 71 | return np.stack([remote.recv() for remote in self.remotes]) 72 | 73 | def close(self): 74 | for remote in self.remotes: 75 | remote.send(('close', None)) 76 | for p in self.ps: 77 | p.join() 78 | 79 | @property 80 | def num_envs(self): 81 | return len(self.remotes) 82 | -------------------------------------------------------------------------------- /a2c/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import shutil 4 | import os.path as osp 5 | import json 6 | import time 7 | import datetime 8 | import tempfile 9 | 10 | LOG_OUTPUT_FORMATS = ['stdout', 'log', 'json'] 11 | 12 | DEBUG = 10 13 | INFO = 20 14 | WARN = 30 15 | ERROR = 40 16 | 17 | DISABLED = 50 18 | 19 | class OutputFormat(object): 20 | def writekvs(self, kvs): 21 | """ 22 | Write key-value pairs 23 | """ 24 | raise NotImplementedError 25 | 26 | def writeseq(self, args): 27 | """ 28 | Write a sequence of 
other data (e.g. a logging message) 29 | """ 30 | pass 31 | 32 | def close(self): 33 | return 34 | 35 | 36 | class HumanOutputFormat(OutputFormat): 37 | def __init__(self, file): 38 | self.file = file 39 | 40 | def writekvs(self, kvs): 41 | # Create strings for printing 42 | key2str = {} 43 | for (key, val) in sorted(kvs.items()): 44 | if isinstance(val, float): 45 | valstr = '%-8.3g' % (val,) 46 | else: 47 | valstr = str(val) 48 | key2str[self._truncate(key)] = self._truncate(valstr) 49 | 50 | # Find max widths 51 | keywidth = max(map(len, key2str.keys())) 52 | valwidth = max(map(len, key2str.values())) 53 | 54 | # Write out the data 55 | dashes = '-' * (keywidth + valwidth + 7) 56 | lines = [dashes] 57 | for (key, val) in sorted(key2str.items()): 58 | lines.append('| %s%s | %s%s |' % ( 59 | key, 60 | ' ' * (keywidth - len(key)), 61 | val, 62 | ' ' * (valwidth - len(val)), 63 | )) 64 | lines.append(dashes) 65 | self.file.write('\n'.join(lines) + '\n') 66 | 67 | # Flush the output to the file 68 | self.file.flush() 69 | 70 | def _truncate(self, s): 71 | return s[:20] + '...' if len(s) > 23 else s 72 | 73 | def writeseq(self, args): 74 | for arg in args: 75 | self.file.write(arg) 76 | self.file.write('\n') 77 | self.file.flush() 78 | 79 | class JSONOutputFormat(OutputFormat): 80 | def __init__(self, file): 81 | self.file = file 82 | 83 | def writekvs(self, kvs): 84 | for k, v in sorted(kvs.items()): 85 | if hasattr(v, 'dtype'): 86 | v = v.tolist() 87 | kvs[k] = float(v) 88 | self.file.write(json.dumps(kvs) + '\n') 89 | self.file.flush() 90 | 91 | class TensorBoardOutputFormat(OutputFormat): 92 | """ 93 | Dumps key/value pairs into TensorBoard's numeric format. 94 | """ 95 | def __init__(self, dir): 96 | os.makedirs(dir, exist_ok=True) 97 | self.dir = dir 98 | self.step = 1 99 | prefix = 'events' 100 | path = osp.join(osp.abspath(dir), prefix) 101 | import tensorflow as tf 102 | from tensorflow.python import pywrap_tensorflow 103 | from tensorflow.core.util import event_pb2 104 | from tensorflow.python.util import compat 105 | self.tf = tf 106 | self.event_pb2 = event_pb2 107 | self.pywrap_tensorflow = pywrap_tensorflow 108 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 109 | 110 | def writekvs(self, kvs): 111 | def summary_val(k, v): 112 | kwargs = {'tag': k, 'simple_value': float(v)} 113 | return self.tf.Summary.Value(**kwargs) 114 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 115 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 116 | event.step = self.step # is there any reason why you'd want to specify the step? 
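        # TensorBoard uses the step as the x-axis when plotting scalars, so a
        # simple per-writer counter is sufficient here.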
117 | self.writer.WriteEvent(event) 118 | self.writer.Flush() 119 | self.step += 1 120 | 121 | def close(self): 122 | if self.writer: 123 | self.writer.Close() 124 | self.writer = None 125 | 126 | 127 | def make_output_format(format, ev_dir): 128 | os.makedirs(ev_dir, exist_ok=True) 129 | if format == 'stdout': 130 | return HumanOutputFormat(sys.stdout) 131 | elif format == 'log': 132 | log_file = open(osp.join(ev_dir, 'log.txt'), 'wt') 133 | return HumanOutputFormat(log_file) 134 | elif format == 'json': 135 | json_file = open(osp.join(ev_dir, 'progress.json'), 'wt') 136 | return JSONOutputFormat(json_file) 137 | elif format == 'tensorboard': 138 | return TensorBoardOutputFormat(osp.join(ev_dir, 'tb')) 139 | else: 140 | raise ValueError('Unknown format specified: %s' % (format,)) 141 | 142 | # ================================================================ 143 | # API 144 | # ================================================================ 145 | 146 | def logkv(key, val): 147 | """ 148 | Log a value of some diagnostic 149 | Call this once for each diagnostic quantity, each iteration 150 | """ 151 | Logger.CURRENT.logkv(key, val) 152 | 153 | def logkvs(d): 154 | """ 155 | Log a dictionary of key-value pairs 156 | """ 157 | for (k, v) in d.items(): 158 | logkv(k, v) 159 | 160 | def dumpkvs(): 161 | """ 162 | Write all of the diagnostics from the current iteration 163 | 164 | level: int. (see logger.py docs) If the global logger level is higher than 165 | the level argument here, don't print to stdout. 166 | """ 167 | Logger.CURRENT.dumpkvs() 168 | 169 | def getkvs(): 170 | return Logger.CURRENT.name2val 171 | 172 | 173 | def log(*args, level=INFO): 174 | """ 175 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 176 | """ 177 | Logger.CURRENT.log(*args, level=level) 178 | 179 | 180 | def debug(*args): 181 | log(*args, level=DEBUG) 182 | 183 | 184 | def info(*args): 185 | log(*args, level=INFO) 186 | 187 | 188 | def warn(*args): 189 | log(*args, level=WARN) 190 | 191 | 192 | def error(*args): 193 | log(*args, level=ERROR) 194 | 195 | 196 | def set_level(level): 197 | """ 198 | Set logging threshold on current logger. 199 | """ 200 | Logger.CURRENT.set_level(level) 201 | 202 | def get_dir(): 203 | """ 204 | Get directory that log files are being written to. 205 | will be None if there is no output directory (i.e., if you didn't call start) 206 | """ 207 | return Logger.CURRENT.get_dir() 208 | 209 | record_tabular = logkv 210 | dump_tabular = dumpkvs 211 | 212 | # ================================================================ 213 | # Backend 214 | # ================================================================ 215 | 216 | class Logger(object): 217 | DEFAULT = None # A logger with no output files. 
(See right below class definition) 218 | # So that you can still log to the terminal without setting up any output files 219 | CURRENT = None # Current logger being used by the free functions above 220 | 221 | def __init__(self, dir, output_formats): 222 | self.name2val = {} # values this iteration 223 | self.level = INFO 224 | self.dir = dir 225 | self.output_formats = output_formats 226 | 227 | # Logging API, forwarded 228 | # ---------------------------------------- 229 | def logkv(self, key, val): 230 | self.name2val[key] = val 231 | 232 | def dumpkvs(self): 233 | if self.level == DISABLED: return 234 | for fmt in self.output_formats: 235 | fmt.writekvs(self.name2val) 236 | self.name2val.clear() 237 | 238 | def log(self, *args, level=INFO): 239 | if self.level <= level: 240 | self._do_log(args) 241 | 242 | # Configuration 243 | # ---------------------------------------- 244 | def set_level(self, level): 245 | self.level = level 246 | 247 | def get_dir(self): 248 | return self.dir 249 | 250 | def close(self): 251 | for fmt in self.output_formats: 252 | fmt.close() 253 | 254 | # Misc 255 | # ---------------------------------------- 256 | def _do_log(self, args): 257 | for fmt in self.output_formats: 258 | fmt.writeseq(args) 259 | 260 | Logger.DEFAULT = Logger.CURRENT = Logger(dir=None, output_formats=[HumanOutputFormat(sys.stdout)]) 261 | 262 | def configure(dir=None, format_strs=None): 263 | assert Logger.CURRENT is Logger.DEFAULT,\ 264 | "Only call logger.configure() when it's in the default state. Try calling logger.reset() first." 265 | prevlogger = Logger.CURRENT 266 | if dir is None: 267 | dir = os.getenv('OPENAI_LOGDIR') 268 | if dir is None: 269 | dir = osp.join(tempfile.gettempdir(), 270 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f")) 271 | if format_strs is None: 272 | format_strs = LOG_OUTPUT_FORMATS 273 | output_formats = [make_output_format(f, dir) for f in format_strs] 274 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats) 275 | log('Logging to %s'%dir) 276 | 277 | if os.getenv('OPENAI_LOGDIR'): 278 | # if OPENAI_LOGDIR is set, configure the logger on import 279 | # this kind of nasty (unexpected to user), but I don't know how else to inject the logger 280 | # to a script that's getting run in a subprocess 281 | configure(dir=os.getenv('OPENAI_LOGDIR')) 282 | 283 | def reset(): 284 | Logger.CURRENT = Logger.DEFAULT 285 | log('Reset logger') 286 | 287 | # ================================================================ 288 | 289 | def _demo(): 290 | info("hi") 291 | debug("shouldn't appear") 292 | set_level(DEBUG) 293 | debug("should appear") 294 | dir = "/tmp/testlogging" 295 | if os.path.exists(dir): 296 | shutil.rmtree(dir) 297 | with session(dir=dir): 298 | logkv("a", 3) 299 | logkv("b", 2.5) 300 | dumpkvs() 301 | logkv("b", -2.5) 302 | logkv("a", 5.5) 303 | dumpkvs() 304 | info("^^^ should see a = 5.5") 305 | 306 | logkv("b", -2.5) 307 | dumpkvs() 308 | 309 | logkv("a", "longasslongasslongasslongasslongasslongassvalue") 310 | dumpkvs() 311 | 312 | 313 | if __name__ == "__main__": 314 | _demo() 315 | -------------------------------------------------------------------------------- /enduro_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | An environment wrapper for Enduro which blanks out the speedometer (so that the 3 | agent doesn't inadvertently learn reward-related information from it) and 4 | signals 'done' once weather begins to change (so that the observations don't 5 | change 
so much and therefore the reward predictor can learn more easily). 6 | """ 7 | 8 | from gym import Wrapper 9 | 10 | 11 | class EnduroWrapper(Wrapper): 12 | def __init__(self, env): 13 | super(EnduroWrapper, self).__init__(env) 14 | assert str(env) == '>>' 15 | self._steps = None 16 | 17 | def step(self, action): 18 | observation, reward, done, info = self.env.step(action) 19 | # Blank out all the speedometer stuff 20 | observation[160:] = 0 21 | self._steps += 1 22 | # Done once the weather starts to change 23 | if self._steps == 3000: 24 | done = True 25 | return observation, reward, done, info 26 | 27 | def reset(self): 28 | self._steps = 0 29 | return self.env.reset() 30 | -------------------------------------------------------------------------------- /floydhub_utils/create_floyd_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Create a FloydHub dataset consisting of files changed by the setup wrapper 4 | # (e.g. packages dependencies) for quicker launching of jobs 5 | 6 | set -o errexit 7 | 8 | touch before_file 9 | 10 | bash floyd_wrapper.sh 11 | 12 | echo "Copying changed files..." 13 | find / -type f -newer before_file | grep -v -e '^/proc' -e '^/sys' -e '^/output' -e '^/code' -e '^/floydlocaldata' -e '^/root' | xargs -i cp --parents {} /output 14 | echo "Done!" 15 | -------------------------------------------------------------------------------- /floydhub_utils/floyd_wrapper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install dependencies, then run whatever was specified in the arguments 4 | 5 | set -o errexit 6 | 7 | # By default, FloydHub copies files with a timestamp of 0 seconds since epoch, 8 | # which breaks pip sometimes 9 | find . | xargs touch 10 | 11 | pip install pipenv 12 | # --site-packages so that we pick up the system TensorFlow 13 | pipenv --site-packages install 14 | 15 | pipenv run $* 16 | -------------------------------------------------------------------------------- /floydhub_utils/floyd_wrapper_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copy files from base dataset (created by create_floyd_base.sh), 4 | # then run whatever was specified in the arguments 5 | 6 | set -o errexit 7 | 8 | echo "Copying base files..." 9 | cp -r /base_files/* / 10 | echo "done!" 11 | 12 | $* 13 | -------------------------------------------------------------------------------- /floydhub_utils/get_dir.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Download the specified directory from the specified job's output files. 
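
Usage: get_dir.py <job_id> <dir>

Each file in the directory is downloaded in its own process.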
5 | """ 6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | import subprocess 11 | import time 12 | from multiprocessing import Process 13 | 14 | 15 | def getfile(f): 16 | dirname = osp.dirname(f) 17 | os.makedirs(dirname, exist_ok=True) 18 | print("Downloading {}...".format(f)) 19 | cmd = "floyd data getfile {}/output {}".format(args.job_id, f) 20 | subprocess.check_output(cmd.split()) 21 | os.rename(osp.basename(f), osp.join(dirname, osp.basename(f))) 22 | 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("job_id") 26 | parser.add_argument("dir") 27 | args = parser.parse_args() 28 | 29 | print("Listing files...") 30 | cmd = "floyd data listfiles {}/output".format(args.job_id) 31 | allfiles = subprocess.check_output(cmd.split()).decode().split('\n') 32 | dirfiles = [f for f in allfiles if f.startswith(args.dir + '/')] 33 | dirfiles = [f for f in dirfiles if not f.endswith('/')] 34 | 35 | for f in dirfiles: 36 | Process(target=getfile, args=(f, )).start() 37 | time.sleep(0.5) 38 | -------------------------------------------------------------------------------- /floydhub_utils/get_events.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Download all TensorFlow event files from the specified jobs' output files. 5 | """ 6 | 7 | import argparse 8 | import os 9 | import os.path as osp 10 | import subprocess 11 | import time 12 | from multiprocessing import Process 13 | 14 | 15 | def get(job_id, out_dir): 16 | print("Listing files...") 17 | cmd = "floyd data listfiles {}/output".format(job_id) 18 | files = subprocess.check_output(cmd.split()).decode().split('\n') 19 | event_files = [f for f in files if 'events.out.tfevents' in f] 20 | download_dir = osp.join(out_dir, job_id) 21 | os.makedirs(download_dir, exist_ok=True) 22 | for event_file in event_files: 23 | print("Downloading {}...".format(event_file)) 24 | cmd = "floyd data getfile {}/output {}".format(job_id, event_file) 25 | subprocess.call(cmd.split()) 26 | 27 | path = os.path.dirname(event_file) 28 | fname = os.path.basename(event_file) 29 | full_dir = osp.join(download_dir, path) 30 | os.makedirs(full_dir, exist_ok=True) 31 | os.rename(fname, osp.join(full_dir, fname)) 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("dir") 37 | parser.add_argument("job_ids", nargs='*') 38 | args = parser.parse_args() 39 | 40 | for job_id in args.job_ids: 41 | Process(target=get, args=(job_id, args.dir)).start() 42 | time.sleep(0.5) 43 | 44 | 45 | if __name__ == '__main__': 46 | main() 47 | -------------------------------------------------------------------------------- /floydhub_utils/monitor_jobs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Monitor running FloydHub jobs and send a macOS notification whenever 5 | one finishes 6 | """ 7 | 8 | import re 9 | import subprocess 10 | import time 11 | import sys 12 | 13 | 14 | def display_notification(title, text): 15 | osa_cmd = 'display notification "{}" with title "{}"'.format(text, title) 16 | subprocess.call(['osascript', '-e', osa_cmd]) 17 | 18 | 19 | def main(): 20 | running_jobs = set() 21 | while True: 22 | try: 23 | output = subprocess.check_output(['floyd', 'status']) 24 | output = output.decode().strip() 25 | except subprocess.CalledProcessError: 26 | continue 27 | output = output.split('\n') 28 | output = output[2:] # Skip header lines 29 | 30 | for line in 
output: 31 | # Consider two or more spaces the field separator 32 | fields = re.sub(r" *", '\t', line).split('\t') 33 | job_name = fields[0] 34 | job_id = job_name.split('/')[-1] 35 | status = fields[2] 36 | 37 | if status == 'running' and job_id not in running_jobs: 38 | print("Found new running job {}".format(job_id)) 39 | running_jobs.add(job_id) 40 | elif ((status == 'shutdown' or status == 'success') 41 | and job_id in running_jobs): 42 | print("Job {} finished".format(job_id)) 43 | running_jobs.remove(job_id) 44 | display_notification("FloydHub job finished", 45 | "Job {} finished".format(job_id)) 46 | 47 | time.sleep(1) 48 | 49 | 50 | if __name__ == '__main__': 51 | try: 52 | main() 53 | except KeyboardInterrupt: 54 | sys.exit(0) 55 | -------------------------------------------------------------------------------- /images/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrahtz/learning-from-human-preferences/3fca07c4c3fd20bec307f4405684461437d9e215/images/diagram.png -------------------------------------------------------------------------------- /images/enduro.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrahtz/learning-from-human-preferences/3fca07c4c3fd20bec307f4405684461437d9e215/images/enduro.gif -------------------------------------------------------------------------------- /images/moving-dot-graphs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrahtz/learning-from-human-preferences/3fca07c4c3fd20bec307f4405684461437d9e215/images/moving-dot-graphs.png -------------------------------------------------------------------------------- /images/moving-dot.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrahtz/learning-from-human-preferences/3fca07c4c3fd20bec307f4405684461437d9e215/images/moving-dot.gif -------------------------------------------------------------------------------- /images/pong-graphs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrahtz/learning-from-human-preferences/3fca07c4c3fd20bec307f4405684461437d9e215/images/pong-graphs.png -------------------------------------------------------------------------------- /images/pong.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mrahtz/learning-from-human-preferences/3fca07c4c3fd20bec307f4405684461437d9e215/images/pong.gif -------------------------------------------------------------------------------- /mem_utils/plot_mems.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Plot process memory usage graphs recorded by utils.profile_memory. 
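
Usage: plot_mems.py <dir>

One subplot is drawn per mem_*.log file found in <dir>.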
5 | """ 6 | 7 | import argparse 8 | from glob import glob 9 | from os.path import join 10 | 11 | from pylab import * 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('dir') 15 | args = parser.parse_args() 16 | 17 | 18 | files = glob(join(args.dir, 'mem_*.log')) 19 | for i, log in enumerate(files): 20 | with open(log) as f: 21 | lines = f.read().rstrip().split('\n') 22 | mems = [float(l.split()[1]) for l in lines] 23 | times = [float(l.split()[2]) for l in lines] 24 | rtimes = [t - times[0] for t in times] 25 | subplot(len(files), 1, i + 1) 26 | title(log) 27 | plot(rtimes, mems) 28 | 29 | tight_layout() 30 | show() 31 | -------------------------------------------------------------------------------- /nn_layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | """ 4 | Wrappers for TensorFlow's layers integrating batchnorm in the right place. 5 | """ 6 | 7 | 8 | def conv_layer(x, filters, kernel_size, strides, batchnorm, training, name, 9 | reuse, activation='relu'): 10 | x = tf.layers.conv2d( 11 | x, 12 | filters, 13 | kernel_size, 14 | strides, 15 | activation=None, 16 | name=name, 17 | reuse=reuse) 18 | 19 | if batchnorm: 20 | batchnorm_name = name + "_batchnorm" 21 | x = tf.layers.batch_normalization( 22 | x, training=training, reuse=reuse, name=batchnorm_name) 23 | 24 | if activation == 'relu': 25 | x = tf.nn.leaky_relu(x, alpha=0.01) 26 | else: 27 | raise Exception("Unknown activation for conv_layer", activation) 28 | 29 | return x 30 | 31 | 32 | def dense_layer(x, 33 | units, 34 | name, 35 | reuse, 36 | activation=None): 37 | x = tf.layers.dense(x, units, activation=None, name=name, reuse=reuse) 38 | 39 | if activation is None: 40 | pass 41 | elif activation == 'relu': 42 | x = tf.nn.leaky_relu(x, alpha=0.01) 43 | else: 44 | raise Exception("Unknown activation for dense_layer", activation) 45 | 46 | return x 47 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import subprocess 5 | import sys 6 | import time 7 | 8 | from a2c.a2c.utils import Scheduler 9 | 10 | 11 | def parse_args(): 12 | parser = argparse.ArgumentParser( 13 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 14 | add_general_args(parser) 15 | add_pref_interface_args(parser) 16 | add_reward_predictor_args(parser) 17 | add_a2c_args(parser) 18 | args = parser.parse_args() 19 | 20 | log_dir = get_log_dir(args) 21 | if args.mode == 'pretrain_reward_predictor' and args.load_prefs_dir is None: 22 | raise Exception("Error: please specify preference databases to train with (--load_prefs_dir)") 23 | general_args = { 24 | 'mode': args.mode, 25 | 'run_name': args.run_name, 26 | 'test_mode': args.test_mode, 27 | 'render_episodes': args.render_episodes, 28 | 'n_initial_prefs': args.n_initial_prefs, 29 | 'max_prefs': args.max_prefs, 30 | 'log_dir': log_dir, 31 | 'prefs_dir': args.load_prefs_dir, 32 | 'debug': args.debug 33 | } 34 | 35 | num_timesteps = int(args.million_timesteps * 1e6) 36 | if args.lr_zero_million_timesteps is None: 37 | schedule = 'constant' 38 | nvalues = 1 # ignored 39 | else: 40 | schedule = 'linear' 41 | nvalues = int(args.lr_zero_million_timesteps * 1e6) 42 | lr_scheduler = Scheduler(v=args.lr, nvalues=nvalues, schedule=schedule) 43 | a2c_args = { 44 | 'env_id': args.env, 45 | 'ent_coef': args.ent_coef, 46 | 'n_envs': 
args.n_envs, 47 | 'seed': args.seed, 48 | 'ckpt_load_dir': args.load_policy_ckpt_dir, 49 | 'ckpt_save_interval': args.policy_ckpt_interval, 50 | 'total_timesteps': num_timesteps, 51 | 'lr_scheduler': lr_scheduler 52 | } 53 | 54 | pref_interface_args = { 55 | 'synthetic_prefs': args.synthetic_prefs, 56 | 'max_segs': args.max_segs 57 | } 58 | 59 | reward_predictor_training_args = { 60 | 'n_initial_epochs': args.n_initial_epochs, 61 | 'dropout': args.dropout, 62 | 'batchnorm': args.batchnorm, 63 | 'load_ckpt_dir': args.load_reward_predictor_ckpt_dir, 64 | 'ckpt_interval': args.reward_predictor_ckpt_interval, 65 | 'lr': args.reward_predictor_learning_rate, 66 | 'val_interval': 50 67 | } 68 | 69 | with open(osp.join(log_dir, 'args.txt'), 'w') as args_file: 70 | args_file.write(' '.join(sys.argv)) 71 | args_file.write('\n') 72 | args_file.write(str(args)) 73 | 74 | return general_args, a2c_args, pref_interface_args, reward_predictor_training_args 75 | 76 | 77 | def get_log_dir(args): 78 | if args.log_dir is not None: 79 | log_dir = args.log_dir 80 | else: 81 | git_rev = get_git_rev() 82 | run_name = args.run_name + '_' + git_rev 83 | log_dir = osp.join('runs', run_name) 84 | if osp.exists(log_dir): 85 | raise Exception("Log directory '%s' already exists" % log_dir) 86 | os.makedirs(log_dir) 87 | return log_dir 88 | 89 | 90 | def get_git_rev(): 91 | if not osp.exists('.git'): 92 | git_rev = "unkrev" 93 | else: 94 | git_rev = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).decode().rstrip() 95 | return git_rev 96 | 97 | 98 | def add_general_args(parser): 99 | parser.add_argument('mode', choices=['gather_initial_prefs', 'pretrain_reward_predictor', 100 | 'train_policy_with_preferences', 'train_policy_with_original_rewards']) 101 | parser.add_argument('env') 102 | 103 | parser.add_argument('--test_mode', action='store_true') 104 | parser.add_argument('--debug', action='store_true') 105 | parser.add_argument('--render_episodes', action='store_true') 106 | parser.add_argument('--load_prefs_dir') 107 | parser.add_argument('--n_initial_prefs', type=int, default=500, 108 | help='How many preferences to collect from a random ' 109 | 'policy before starting reward predictor ' 110 | 'training') 111 | # Page 15: "We maintain a buffer of only the last 3,000 labels" 112 | parser.add_argument('--max_prefs', type=int, default=3000) 113 | 114 | group = parser.add_mutually_exclusive_group(); 115 | group.add_argument('--log_dir') 116 | seconds_since_epoch = str(int(time.time())) 117 | group.add_argument('--run_name', default=seconds_since_epoch) 118 | 119 | 120 | def add_a2c_args(parser): 121 | parser.add_argument('--log_interval', type=int, default=100) 122 | parser.add_argument('--ent_coef', type=float, default=0.01) 123 | parser.add_argument('--n_envs', type=int, default=1) 124 | parser.add_argument('--seed', help='RNG seed', type=int, default=0) 125 | 126 | parser.add_argument("--lr_zero_million_timesteps", 127 | type=float, default=None, 128 | help='If set, decay learning rate linearly, reaching ' 129 | ' zero at this many timesteps') 130 | parser.add_argument('--lr', type=float, default=7e-4) 131 | 132 | parser.add_argument('--load_policy_ckpt_dir', 133 | help='Load a policy checkpoint from this directory.') 134 | parser.add_argument('--policy_ckpt_interval', type=int, default=100, 135 | help="No. updates between policy checkpoints") 136 | 137 | parser.add_argument('--million_timesteps', 138 | type=float, default=10., 139 | help='How many million timesteps to train for. 
' 140 | '(The number of frames trained for is this ' 141 | 'multiplied by 4 due to frameskip.)') 142 | 143 | 144 | def add_reward_predictor_args(parser): 145 | parser.add_argument('--reward_predictor_learning_rate', type=float, default=2e-4) 146 | parser.add_argument('--n_initial_epochs', type=int, default=200) 147 | parser.add_argument('--dropout', type=float, default=0.0) 148 | parser.add_argument('--batchnorm', action='store_true') 149 | parser.add_argument('--load_reward_predictor_ckpt_dir', 150 | help='Directory to load reward predictor checkpoint from ' 151 | '(loads latest checkpoint in the specified directory)') 152 | parser.add_argument('--reward_predictor_ckpt_interval', 153 | type=int, default=1, 154 | help='No. training epochs between reward ' 155 | 'predictor checkpoints') 156 | 157 | 158 | def add_pref_interface_args(parser): 159 | parser.add_argument('--synthetic_prefs', action='store_true') 160 | # Maximum number of segments to store from which to generate preferences. 161 | # This isn't a parameter specified in the paper; 162 | # I'm just guessing that 1,000 is a reasonable figure. 163 | # 1,000 corresponds to about 25,000 frames. 164 | 165 | # How much memory does this use? 166 | # 84 x 84 (pixels per frame) x 167 | # 4 (frames per stack) x 168 | # 25 (stacks per segment) x 169 | # 1000 170 | # = ~ 700 MB 171 | parser.add_argument('--max_segs', type=int, default=1000) 172 | 173 | 174 | # Fraction of preferences that should be used for reward predictor validation 175 | # accuracy tests 176 | PREFS_VAL_FRACTION = 0.2 177 | -------------------------------------------------------------------------------- /pref_db.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import copy 3 | import gzip 4 | import pickle 5 | import queue 6 | import time 7 | import zlib 8 | from threading import Lock, Thread 9 | 10 | import easy_tf_log 11 | import numpy as np 12 | 13 | 14 | class Segment: 15 | """ 16 | A short recording of agent's behaviour in the environment, 17 | consisting of a number of video frames and the rewards it received 18 | during those frames. 19 | """ 20 | 21 | def __init__(self): 22 | self.frames = [] 23 | self.rewards = [] 24 | self.hash = None 25 | 26 | def append(self, frame, reward): 27 | self.frames.append(frame) 28 | self.rewards.append(reward) 29 | 30 | def finalise(self, seg_id=None): 31 | if seg_id is not None: 32 | self.hash = seg_id 33 | else: 34 | # This looks expensive, but don't worry - 35 | # it only takes about 0.5 ms. 36 | self.hash = hash(np.array(self.frames).tostring()) 37 | 38 | def __len__(self): 39 | return len(self.frames) 40 | 41 | 42 | class CompressedDict(collections.MutableMapping): 43 | 44 | def __init__(self): 45 | self.store = dict() 46 | 47 | def __getitem__(self, key): 48 | return pickle.loads(zlib.decompress(self.store[key])) 49 | 50 | def __setitem__(self, key, value): 51 | self.store[key] = zlib.compress(pickle.dumps(value)) 52 | 53 | def __delitem__(self, key): 54 | del self.store[key] 55 | 56 | def __iter__(self): 57 | return iter(self.store) 58 | 59 | def __len__(self): 60 | return len(self.store) 61 | 62 | def __keytransform__(self, key): 63 | return key 64 | 65 | 66 | class PrefDB: 67 | """ 68 | A circular database of preferences about pairs of segments. 69 | 70 | For each preference, we store the preference itself 71 | (mu in the paper) and the two segments the preference refers to. 
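The preference itself is a length-2 distribution over the pair, e.g. (1.0, 0.0) if the first segment was preferred, (0.0, 1.0) if the second was, or (0.5, 0.5) for no clear preference (these are the values produced by pref_interface.py).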
72 | Segments are stored with deduplication - so that if multiple 73 | preferences refer to the same segment, the segment is only stored once. 74 | """ 75 | 76 | def __init__(self, maxlen): 77 | self.segments = CompressedDict() 78 | self.seg_refs = {} 79 | self.prefs = [] 80 | self.maxlen = maxlen 81 | 82 | def append(self, s1, s2, pref): 83 | k1 = hash(np.array(s1).tostring()) 84 | k2 = hash(np.array(s2).tostring()) 85 | 86 | for k, s in zip([k1, k2], [s1, s2]): 87 | if k not in self.segments.keys(): 88 | self.segments[k] = s 89 | self.seg_refs[k] = 1 90 | else: 91 | self.seg_refs[k] += 1 92 | 93 | tup = (k1, k2, pref) 94 | self.prefs.append(tup) 95 | 96 | if len(self.prefs) > self.maxlen: 97 | self.del_first() 98 | 99 | def del_first(self): 100 | self.del_pref(0) 101 | 102 | def del_pref(self, n): 103 | if n >= len(self.prefs): 104 | raise IndexError("Preference {} doesn't exist".format(n)) 105 | k1, k2, _ = self.prefs[n] 106 | for k in [k1, k2]: 107 | if self.seg_refs[k] == 1: 108 | del self.segments[k] 109 | del self.seg_refs[k] 110 | else: 111 | self.seg_refs[k] -= 1 112 | del self.prefs[n] 113 | 114 | def __len__(self): 115 | return len(self.prefs) 116 | 117 | def save(self, path): 118 | with gzip.open(path, 'wb') as pkl_file: 119 | pickle.dump(self, pkl_file) 120 | 121 | @staticmethod 122 | def load(path): 123 | with gzip.open(path, 'rb') as pkl_file: 124 | pref_db = pickle.load(pkl_file) 125 | return pref_db 126 | 127 | 128 | class PrefBuffer: 129 | """ 130 | A helper class to manage asynchronous receiving of preferences on a 131 | background thread. 132 | """ 133 | def __init__(self, db_train, db_val): 134 | self.train_db = db_train 135 | self.val_db = db_val 136 | self.lock = Lock() 137 | self.stop_recv = False 138 | 139 | def start_recv_thread(self, pref_pipe): 140 | self.stop_recv = False 141 | Thread(target=self.recv_prefs, args=(pref_pipe, )).start() 142 | 143 | def stop_recv_thread(self): 144 | self.stop_recv = True 145 | 146 | def recv_prefs(self, pref_pipe): 147 | n_recvd = 0 148 | while not self.stop_recv: 149 | try: 150 | s1, s2, pref = pref_pipe.get(block=True, timeout=1) 151 | except queue.Empty: 152 | continue 153 | n_recvd += 1 154 | 155 | val_fraction = self.val_db.maxlen / (self.val_db.maxlen + 156 | self.train_db.maxlen) 157 | 158 | self.lock.acquire(blocking=True) 159 | if np.random.rand() < val_fraction: 160 | self.val_db.append(s1, s2, pref) 161 | easy_tf_log.tflog('val_db_len', len(self.val_db)) 162 | else: 163 | self.train_db.append(s1, s2, pref) 164 | easy_tf_log.tflog('train_db_len', len(self.train_db)) 165 | self.lock.release() 166 | 167 | easy_tf_log.tflog('n_prefs_recvd', n_recvd) 168 | 169 | def train_db_len(self): 170 | return len(self.train_db) 171 | 172 | def val_db_len(self): 173 | return len(self.val_db) 174 | 175 | def get_dbs(self): 176 | self.lock.acquire(blocking=True) 177 | train_copy = copy.deepcopy(self.train_db) 178 | val_copy = copy.deepcopy(self.val_db) 179 | self.lock.release() 180 | return train_copy, val_copy 181 | 182 | def wait_until_len(self, min_len): 183 | while True: 184 | self.lock.acquire() 185 | train_len = len(self.train_db) 186 | val_len = len(self.val_db) 187 | self.lock.release() 188 | if train_len >= min_len and val_len != 0: 189 | break 190 | print("Waiting for preferences; {} so far".format(train_len)) 191 | time.sleep(5.0) 192 | -------------------------------------------------------------------------------- /pref_db_test.py: -------------------------------------------------------------------------------- 1 | 
#!/usr/bin/env python3 2 | 3 | import unittest 4 | import numpy as np 5 | from pref_db import PrefDB 6 | 7 | 8 | class TestPrefDB(unittest.TestCase): 9 | 10 | def test_similar_segs(self): 11 | """ 12 | Test that the preference database really distinguishes 13 | between similar segments 14 | (i.e. check that its hash function is working as it's supposed to). 15 | """ 16 | p = PrefDB(maxlen=5) 17 | s1 = np.ones((25, 84, 84, 4)) 18 | s2 = np.ones((25, 84, 84, 4)) 19 | s2[12][24][24][2] = 0 20 | p.append(s1, s2, [1.0, 0.0]) 21 | self.assertEqual(len(p.segments), 2) 22 | 23 | def test_append_delete(self): 24 | """ 25 | Do a number of appends/deletes and check that the number of 26 | preferences and segments is as expected at all times. 27 | """ 28 | p = PrefDB(maxlen=10) 29 | 30 | s1 = np.random.randint(low=-10, high=10, size=(25, 84, 84, 4)) 31 | s2 = np.random.randint(low=-10, high=10, size=(25, 84, 84, 4)) 32 | p.append(s1, s2, [1.0, 0.0]) 33 | self.assertEqual(len(p.segments), 2) 34 | self.assertEqual(len(p.prefs), 1) 35 | 36 | p.append(s1, s2, [0.0, 1.0]) 37 | self.assertEqual(len(p.segments), 2) 38 | self.assertEqual(len(p.prefs), 2) 39 | 40 | s1 = np.random.randint(low=-10, high=10, size=(25, 84, 84, 4)) 41 | p.append(s1, s2, [1.0, 0.0]) 42 | self.assertEqual(len(p.segments), 3) 43 | self.assertEqual(len(p.prefs), 3) 44 | 45 | s2 = np.random.randint(low=-10, high=10, size=(25, 84, 84, 4)) 46 | p.append(s1, s2, [1.0, 0.0]) 47 | self.assertEqual(len(p.segments), 4) 48 | self.assertEqual(len(p.prefs), 4) 49 | 50 | s1 = np.random.randint(low=-10, high=10, size=(25, 84, 84, 4)) 51 | s2 = np.random.randint(low=-10, high=10, size=(25, 84, 84, 4)) 52 | p.append(s1, s2, [1.0, 0.0]) 53 | self.assertEqual(len(p.segments), 6) 54 | self.assertEqual(len(p.prefs), 5) 55 | 56 | prefs_pre = list(p.prefs) 57 | p.del_first() 58 | self.assertEqual(len(p.prefs), 4) 59 | self.assertEqual(p.prefs, prefs_pre[1:]) 60 | # These segments were also used by the second preference, 61 | # so the number of segments shouldn't have decreased 62 | self.assertEqual(len(p.segments), 6) 63 | 64 | p.del_first() 65 | self.assertEqual(len(p.prefs), 3) 66 | # One of the segments just deleted was only used by the first two 67 | # preferences, so the length should have shrunk by one 68 | self.assertEqual(len(p.segments), 5) 69 | 70 | p.del_first() 71 | self.assertEqual(len(p.prefs), 2) 72 | # Another one should bite the dust... 
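# (Specifically, the segment shared by the first three preferences is dropped, since no remaining preference references it any more.)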
73 | self.assertEqual(len(p.segments), 4) 74 | 75 | p.del_first() 76 | self.assertEqual(len(p.prefs), 1) 77 | self.assertEqual(len(p.segments), 2) 78 | 79 | p.del_first() 80 | self.assertEqual(len(p.prefs), 0) 81 | self.assertEqual(len(p.segments), 0) 82 | 83 | def test_circular(self): 84 | p = PrefDB(maxlen=2) 85 | 86 | p.append(0, 1, 10) 87 | self.assertEqual(len(p), 1) 88 | p.append(2, 3, 11) 89 | self.assertEqual(len(p), 2) 90 | p.append(4, 5, 12) 91 | self.assertEqual(len(p), 2) 92 | 93 | self.assertEqual(len(p.segments), 4) 94 | self.assertIn(2, p.segments.values()) 95 | self.assertIn(3, p.segments.values()) 96 | self.assertIn(4, p.segments.values()) 97 | self.assertIn(5, p.segments.values()) 98 | 99 | self.assertEqual(p.prefs[0][2], 11) 100 | self.assertEqual(p.prefs[1][2], 12) 101 | 102 | 103 | if __name__ == '__main__': 104 | unittest.main() 105 | -------------------------------------------------------------------------------- /pref_interface.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | A simple CLI-based interface for querying the user about segment preferences. 5 | """ 6 | 7 | import logging 8 | import queue 9 | import time 10 | from itertools import combinations 11 | from multiprocessing import Queue 12 | from random import shuffle 13 | 14 | import easy_tf_log 15 | import numpy as np 16 | 17 | from utils import VideoRenderer 18 | 19 | 20 | class PrefInterface: 21 | 22 | def __init__(self, synthetic_prefs, max_segs, log_dir): 23 | self.vid_q = Queue() 24 | if not synthetic_prefs: 25 | self.renderer = VideoRenderer(vid_queue=self.vid_q, 26 | mode=VideoRenderer.restart_on_get_mode, 27 | zoom=4) 28 | else: 29 | self.renderer = None 30 | self.synthetic_prefs = synthetic_prefs 31 | self.seg_idx = 0 32 | self.segments = [] 33 | self.tested_pairs = set() # For O(1) lookup 34 | self.max_segs = max_segs 35 | easy_tf_log.set_dir(log_dir) 36 | 37 | def stop_renderer(self): 38 | if self.renderer: 39 | self.renderer.stop() 40 | 41 | def run(self, seg_pipe, pref_pipe): 42 | while len(self.segments) < 2: 43 | print("Preference interface waiting for segments") 44 | time.sleep(5.0) 45 | self.recv_segments(seg_pipe) 46 | 47 | while True: 48 | seg_pair = None 49 | while seg_pair is None: 50 | try: 51 | seg_pair = self.sample_seg_pair() 52 | except IndexError: 53 | print("Preference interface ran out of untested segments;" 54 | "waiting...") 55 | # If we've tested all possible pairs of segments so far, 56 | # we'll have to wait for more segments 57 | time.sleep(1.0) 58 | self.recv_segments(seg_pipe) 59 | s1, s2 = seg_pair 60 | 61 | logging.debug("Querying preference for segments %s and %s", 62 | s1.hash, s2.hash) 63 | 64 | if not self.synthetic_prefs: 65 | pref = self.ask_user(s1, s2) 66 | else: 67 | if sum(s1.rewards) > sum(s2.rewards): 68 | pref = (1.0, 0.0) 69 | elif sum(s1.rewards) < sum(s2.rewards): 70 | pref = (0.0, 1.0) 71 | else: 72 | pref = (0.5, 0.5) 73 | 74 | if pref is not None: 75 | # We don't need the rewards from this point on, so just send 76 | # the frames 77 | pref_pipe.put((s1.frames, s2.frames, pref)) 78 | # If pref is None, the user answered "incomparable" for the segment 79 | # pair. The pair has been marked as tested; we just drop it. 80 | 81 | self.recv_segments(seg_pipe) 82 | 83 | def recv_segments(self, seg_pipe): 84 | """ 85 | Receive segments from `seg_pipe` into circular buffer `segments`. 
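Once the buffer is full, the oldest segments are overwritten first: for example, with max_segs=3, receiving segments 1 to 5 leaves the buffer holding [4, 5, 3].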
86 | """ 87 | max_wait_seconds = 0.5 88 | start_time = time.time() 89 | n_recvd = 0 90 | while time.time() - start_time < max_wait_seconds: 91 | try: 92 | segment = seg_pipe.get(block=True, timeout=max_wait_seconds) 93 | except queue.Empty: 94 | return 95 | if len(self.segments) < self.max_segs: 96 | self.segments.append(segment) 97 | else: 98 | self.segments[self.seg_idx] = segment 99 | self.seg_idx = (self.seg_idx + 1) % self.max_segs 100 | n_recvd += 1 101 | easy_tf_log.tflog('segment_idx', self.seg_idx) 102 | easy_tf_log.tflog('n_segments_rcvd', n_recvd) 103 | easy_tf_log.tflog('n_segments', len(self.segments)) 104 | 105 | def sample_seg_pair(self): 106 | """ 107 | Sample a random pair of segments which hasn't yet been tested. 108 | """ 109 | segment_idxs = list(range(len(self.segments))) 110 | shuffle(segment_idxs) 111 | possible_pairs = combinations(segment_idxs, 2) 112 | for i1, i2 in possible_pairs: 113 | s1, s2 = self.segments[i1], self.segments[i2] 114 | if ((s1.hash, s2.hash) not in self.tested_pairs) and \ 115 | ((s2.hash, s1.hash) not in self.tested_pairs): 116 | self.tested_pairs.add((s1.hash, s2.hash)) 117 | self.tested_pairs.add((s2.hash, s1.hash)) 118 | return s1, s2 119 | raise IndexError("No segment pairs yet untested") 120 | 121 | def ask_user(self, s1, s2): 122 | vid = [] 123 | seg_len = len(s1) 124 | for t in range(seg_len): 125 | border = np.zeros((84, 10), dtype=np.uint8) 126 | # -1 => show only the most recent frame of the 4-frame stack 127 | frame = np.hstack((s1.frames[t][:, :, -1], 128 | border, 129 | s2.frames[t][:, :, -1])) 130 | vid.append(frame) 131 | n_pause_frames = 7 132 | for _ in range(n_pause_frames): 133 | vid.append(np.copy(vid[-1])) 134 | self.vid_q.put(vid) 135 | 136 | while True: 137 | print("Segments {} and {}: ".format(s1.hash, s2.hash)) 138 | choice = input() 139 | # L = "I prefer the left segment" 140 | # R = "I prefer the right segment" 141 | # E = "I don't have a clear preference between the two segments" 142 | # "" = "The segments are incomparable" 143 | if choice == "L" or choice == "R" or choice == "E" or choice == "": 144 | break 145 | else: 146 | print("Invalid choice '{}'".format(choice)) 147 | 148 | if choice == "L": 149 | pref = (1.0, 0.0) 150 | elif choice == "R": 151 | pref = (0.0, 1.0) 152 | elif choice == "E": 153 | pref = (0.5, 0.5) 154 | elif choice == "": 155 | pref = None 156 | 157 | self.vid_q.put([np.zeros(vid[0].shape, dtype=np.uint8)]) 158 | 159 | return pref 160 | -------------------------------------------------------------------------------- /pref_interface_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import unittest 3 | from itertools import combinations 4 | from multiprocessing import Queue 5 | 6 | import numpy as np 7 | import termcolor 8 | 9 | from pref_db import Segment 10 | from pref_interface import PrefInterface 11 | 12 | 13 | def send_segments(n_segments, seg_pipe): 14 | frame_stack = np.zeros((84, 84, 4)) 15 | for i in range(n_segments): 16 | segment = Segment() 17 | for _ in range(25): 18 | segment.append(frame=frame_stack, reward=0) 19 | segment.finalise(seg_id=i) 20 | seg_pipe.put(segment) 21 | 22 | 23 | class TestPrefInterface(unittest.TestCase): 24 | def setUp(self): 25 | self.p = PrefInterface(synthetic_prefs=True, max_segs=1000, 26 | log_dir='/tmp') 27 | termcolor.cprint(self._testMethodName, 'red') 28 | 29 | def test_sample_pair(self): 30 | seg_pipe = Queue() 31 | n_segments = 5 32 | send_segments(n_segments, seg_pipe) 33 | 
self.p.recv_segments(seg_pipe) 34 | 35 | # Check that we get exactly the right number of unique pairs back 36 | n_possible_pairs = len(list(combinations(range(n_segments), 2))) 37 | tested_pairs = set() 38 | for _ in range(n_possible_pairs): 39 | s1, s2 = self.p.sample_seg_pair() 40 | tested_pairs.add((s1.hash, s2.hash)) 41 | tested_pairs.add((s2.hash, s1.hash)) 42 | self.assertEqual(len(tested_pairs), 2 * n_possible_pairs) 43 | 44 | # Check that if we try to get just one more, we get an exception 45 | # indicating that there are no more unique pairs available 46 | with self.assertRaises(IndexError): 47 | self.p.sample_seg_pair() 48 | 49 | def test_recv_segments(self): 50 | """ 51 | Check that segments are stored correctly in the circular buffer. 52 | """ 53 | pi = PrefInterface(synthetic_prefs=True, max_segs=5, log_dir='/tmp') 54 | pipe = Queue() 55 | for i in range(5): 56 | pipe.put(i) 57 | pi.recv_segments(pipe) 58 | np.testing.assert_array_equal(pi.segments, [0, 1, 2, 3, 4]) 59 | for i in range(5, 8): 60 | pipe.put(i) 61 | pi.recv_segments(pipe) 62 | np.testing.assert_array_equal(pi.segments, [5, 6, 7, 3, 4]) 63 | for i in range(8, 11): 64 | pipe.put(i) 65 | pi.recv_segments(pipe) 66 | np.testing.assert_array_equal(pi.segments, [10, 6, 7, 8, 9]) 67 | 68 | 69 | if __name__ == '__main__': 70 | unittest.main() 71 | -------------------------------------------------------------------------------- /reward_predictor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os.path as osp 3 | import time 4 | 5 | import easy_tf_log 6 | import numpy as np 7 | from numpy.testing import assert_equal 8 | import tensorflow as tf 9 | 10 | from utils import RunningStat, batch_iter 11 | 12 | 13 | class RewardPredictorEnsemble: 14 | """ 15 | An ensemble of reward predictors and associated helper functions. 16 | """ 17 | 18 | def __init__(self, 19 | cluster_job_name, 20 | core_network, 21 | lr=1e-4, 22 | cluster_dict=None, 23 | batchnorm=False, 24 | dropout=0.0, 25 | n_preds=1, 26 | log_dir=None): 27 | self.n_preds = n_preds 28 | graph, self.sess = self.init_sess(cluster_dict, cluster_job_name) 29 | # Why not just use soft device placement? With soft placement, 30 | # if we have a bug which prevents an operation being placed on the GPU 31 | # (e.g. we're using uint8s for operations that the GPU can't do), 32 | # then TensorFlow will be silent and just place the operation on a CPU. 33 | # Instead, we want to say: if there's a GPU present, definitely try and 34 | # put things on the GPU. If it fails, tell us! 35 | if tf.test.gpu_device_name(): 36 | worker_device = "/job:{}/task:0/gpu:0".format(cluster_job_name) 37 | else: 38 | worker_device = "/job:{}/task:0".format(cluster_job_name) 39 | device_setter = tf.train.replica_device_setter( 40 | cluster=cluster_dict, 41 | ps_device="/job:ps/task:0", 42 | worker_device=worker_device) 43 | self.rps = [] 44 | with graph.as_default(): 45 | for pred_n in range(n_preds): 46 | with tf.device(device_setter): 47 | with tf.variable_scope("pred_{}".format(pred_n)): 48 | rp = RewardPredictorNetwork( 49 | core_network=core_network, 50 | dropout=dropout, 51 | batchnorm=batchnorm, 52 | lr=lr) 53 | self.rps.append(rp) 54 | self.init_op = tf.global_variables_initializer() 55 | # Why save_relative_paths=True? 56 | # So that the plain-text 'checkpoint' file written uses relative paths, 57 | # which seems to be needed in order to avoid confusing saver.restore() 58 | # when restoring from FloydHub runs. 
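# (max_to_keep=1 means only the most recent checkpoint file is kept on disk.)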
59 | self.saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True) 60 | self.summaries = self.add_summary_ops() 61 | 62 | self.checkpoint_file = osp.join(log_dir, 63 | 'reward_predictor_checkpoints', 64 | 'reward_predictor.ckpt') 65 | self.train_writer = tf.summary.FileWriter( 66 | osp.join(log_dir, 'reward_predictor', 'train'), flush_secs=5) 67 | self.test_writer = tf.summary.FileWriter( 68 | osp.join(log_dir, 'reward_predictor', 'test'), flush_secs=5) 69 | 70 | self.n_steps = 0 71 | self.r_norm = RunningStat(shape=n_preds) 72 | 73 | misc_logs_dir = osp.join(log_dir, 'reward_predictor', 'misc') 74 | easy_tf_log.set_dir(misc_logs_dir) 75 | 76 | @staticmethod 77 | def init_sess(cluster_dict, cluster_job_name): 78 | graph = tf.Graph() 79 | cluster = tf.train.ClusterSpec(cluster_dict) 80 | config = tf.ConfigProto(gpu_options={'allow_growth': True}) 81 | server = tf.train.Server(cluster, job_name=cluster_job_name, config=config) 82 | sess = tf.Session(server.target, graph) 83 | return graph, sess 84 | 85 | def add_summary_ops(self): 86 | summary_ops = [] 87 | 88 | for pred_n, rp in enumerate(self.rps): 89 | name = 'reward_predictor_accuracy_{}'.format(pred_n) 90 | op = tf.summary.scalar(name, rp.accuracy) 91 | summary_ops.append(op) 92 | name = 'reward_predictor_loss_{}'.format(pred_n) 93 | op = tf.summary.scalar(name, rp.loss) 94 | summary_ops.append(op) 95 | 96 | mean_accuracy = tf.reduce_mean([rp.accuracy for rp in self.rps]) 97 | op = tf.summary.scalar('reward_predictor_accuracy_mean', mean_accuracy) 98 | summary_ops.append(op) 99 | 100 | mean_loss = tf.reduce_mean([rp.loss for rp in self.rps]) 101 | op = tf.summary.scalar('reward_predictor_loss_mean', mean_loss) 102 | summary_ops.append(op) 103 | 104 | summaries = tf.summary.merge(summary_ops) 105 | 106 | return summaries 107 | 108 | def init_network(self, load_ckpt_dir=None): 109 | if load_ckpt_dir: 110 | ckpt_file = tf.train.latest_checkpoint(load_ckpt_dir) 111 | if ckpt_file is None: 112 | msg = "No reward predictor checkpoint found in '{}'".format( 113 | load_ckpt_dir) 114 | raise FileNotFoundError(msg) 115 | self.saver.restore(self.sess, ckpt_file) 116 | print("Loaded reward predictor checkpoint from '{}'".format(ckpt_file)) 117 | else: 118 | self.sess.run(self.init_op) 119 | 120 | def save(self): 121 | ckpt_name = self.saver.save(self.sess, 122 | self.checkpoint_file, 123 | self.n_steps) 124 | print("Saved reward predictor checkpoint to '{}'".format(ckpt_name)) 125 | 126 | def raw_rewards(self, obs): 127 | """ 128 | Return (unnormalized) reward for each frame of a single segment 129 | from each member of the ensemble. 130 | """ 131 | assert_equal(obs.shape[1:], (84, 84, 4)) 132 | n_steps = obs.shape[0] 133 | feed_dict = {} 134 | for rp in self.rps: 135 | feed_dict[rp.training] = False 136 | feed_dict[rp.s1] = [obs] 137 | # This will return nested lists of sizes n_preds x 1 x nsteps 138 | # (x 1 because of the batch size of 1) 139 | rs = self.sess.run([rp.r1 for rp in self.rps], feed_dict) 140 | rs = np.array(rs) 141 | # Get rid of the extra x 1 dimension 142 | rs = rs[:, 0, :] 143 | assert_equal(rs.shape, (self.n_preds, n_steps)) 144 | return rs 145 | 146 | def reward(self, obs): 147 | """ 148 | Return (normalized) reward for each frame of a single segment. 149 | 150 | (Normalization involves normalizing the rewards from each member of the 151 | ensemble separately, then averaging the resulting rewards across all 152 | ensemble members.) 
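As a rough sketch of what the code below does: for ensemble member i with running mean mu_i and standard deviation sigma_i, each per-frame reward r_i is mapped to 0.05 * (r_i - mu_i) / sigma_i, and the results are then averaged across ensemble members.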
153 | """ 154 | assert_equal(obs.shape[1:], (84, 84, 4)) 155 | n_steps = obs.shape[0] 156 | 157 | # Get unnormalized rewards 158 | 159 | ensemble_rs = self.raw_rewards(obs) 160 | logging.debug("Unnormalized rewards:\n%s", ensemble_rs) 161 | 162 | # Normalize rewards 163 | 164 | # Note that we implement this here instead of in the network itself 165 | # because: 166 | # * It's simpler not to do it in TensorFlow 167 | # * Preference prediction doesn't need normalized rewards. Only 168 | # rewards sent to the the RL algorithm need to be normalized. 169 | # So we can save on computation. 170 | 171 | # Page 4: 172 | # "We normalized the rewards produced by r^ to have zero mean and 173 | # constant standard deviation." 174 | # Page 15: (Atari) 175 | # "Since the reward predictor is ultimately used to compare two sums 176 | # over timesteps, its scale is arbitrary, and we normalize it to have 177 | # a standard deviation of 0.05" 178 | # Page 5: 179 | # "The estimate r^ is defined by independently normalizing each of 180 | # these predictors..." 181 | 182 | # We want to keep track of running mean/stddev for each member of the 183 | # ensemble separately, so we have to be a little careful here. 184 | assert_equal(ensemble_rs.shape, (self.n_preds, n_steps)) 185 | ensemble_rs = ensemble_rs.transpose() 186 | assert_equal(ensemble_rs.shape, (n_steps, self.n_preds)) 187 | for ensemble_rs_step in ensemble_rs: 188 | self.r_norm.push(ensemble_rs_step) 189 | ensemble_rs -= self.r_norm.mean 190 | ensemble_rs /= (self.r_norm.std + 1e-12) 191 | ensemble_rs *= 0.05 192 | ensemble_rs = ensemble_rs.transpose() 193 | assert_equal(ensemble_rs.shape, (self.n_preds, n_steps)) 194 | logging.debug("Reward mean/stddev:\n%s %s", 195 | self.r_norm.mean, 196 | self.r_norm.std) 197 | logging.debug("Normalized rewards:\n%s", ensemble_rs) 198 | 199 | # "...and then averaging the results." 200 | rs = np.mean(ensemble_rs, axis=0) 201 | assert_equal(rs.shape, (n_steps, )) 202 | logging.debug("After ensemble averaging:\n%s", rs) 203 | 204 | return rs 205 | 206 | def preferences(self, s1s, s2s): 207 | """ 208 | Predict probability of human preferring one segment over another 209 | for each segment in the supplied batch of segment pairs. 210 | """ 211 | feed_dict = {} 212 | for rp in self.rps: 213 | feed_dict[rp.s1] = s1s 214 | feed_dict[rp.s2] = s2s 215 | feed_dict[rp.training] = False 216 | preds = self.sess.run([rp.pred for rp in self.rps], feed_dict) 217 | return preds 218 | 219 | def train(self, prefs_train, prefs_val, val_interval): 220 | """ 221 | Train all ensemble members for one epoch. 
222 | """ 223 | print("Training/testing with %d/%d preferences" % (len(prefs_train), 224 | len(prefs_val))) 225 | 226 | start_steps = self.n_steps 227 | start_time = time.time() 228 | 229 | for _, batch in enumerate(batch_iter(prefs_train.prefs, 230 | batch_size=32, 231 | shuffle=True)): 232 | self.train_step(batch, prefs_train) 233 | self.n_steps += 1 234 | 235 | if self.n_steps and self.n_steps % val_interval == 0: 236 | self.val_step(prefs_val) 237 | 238 | end_time = time.time() 239 | end_steps = self.n_steps 240 | rate = (end_steps - start_steps) / (end_time - start_time) 241 | easy_tf_log.tflog('reward_predictor_training_steps_per_second', 242 | rate) 243 | 244 | def train_step(self, batch, prefs_train): 245 | s1s = [prefs_train.segments[k1] for k1, k2, pref, in batch] 246 | s2s = [prefs_train.segments[k2] for k1, k2, pref, in batch] 247 | prefs = [pref for k1, k2, pref, in batch] 248 | feed_dict = {} 249 | for rp in self.rps: 250 | feed_dict[rp.s1] = s1s 251 | feed_dict[rp.s2] = s2s 252 | feed_dict[rp.pref] = prefs 253 | feed_dict[rp.training] = True 254 | ops = [self.summaries, [rp.train for rp in self.rps]] 255 | summaries, _ = self.sess.run(ops, feed_dict) 256 | self.train_writer.add_summary(summaries, self.n_steps) 257 | 258 | def val_step(self, prefs_val): 259 | val_batch_size = 32 260 | if len(prefs_val) <= val_batch_size: 261 | batch = prefs_val.prefs 262 | else: 263 | idxs = np.random.choice(len(prefs_val.prefs), 264 | val_batch_size, 265 | replace=False) 266 | batch = [prefs_val.prefs[i] for i in idxs] 267 | s1s = [prefs_val.segments[k1] for k1, k2, pref, in batch] 268 | s2s = [prefs_val.segments[k2] for k1, k2, pref, in batch] 269 | prefs = [pref for k1, k2, pref, in batch] 270 | feed_dict = {} 271 | for rp in self.rps: 272 | feed_dict[rp.s1] = s1s 273 | feed_dict[rp.s2] = s2s 274 | feed_dict[rp.pref] = prefs 275 | feed_dict[rp.training] = False 276 | summaries = self.sess.run(self.summaries, feed_dict) 277 | self.test_writer.add_summary(summaries, self.n_steps) 278 | 279 | 280 | class RewardPredictorNetwork: 281 | """ 282 | Predict the reward that a human would assign to each frame of 283 | the input trajectory, trained using the human's preferences between 284 | pairs of trajectories. 285 | 286 | Network inputs: 287 | - s1/s2 Trajectory pairs 288 | - pref Preferences between each pair of trajectories 289 | Network outputs: 290 | - r1/r2 Reward predicted for each frame 291 | - rs1/rs2 Reward summed over all frames for each trajectory 292 | - pred Predicted preference 293 | """ 294 | 295 | def __init__(self, core_network, dropout, batchnorm, lr): 296 | training = tf.placeholder(tf.bool) 297 | # Each element of the batch is one trajectory segment. 298 | # (Dimensions are n segments x n frames per segment x ...) 299 | s1 = tf.placeholder(tf.float32, shape=(None, None, 84, 84, 4)) 300 | s2 = tf.placeholder(tf.float32, shape=(None, None, 84, 84, 4)) 301 | # For each trajectory segment, there is one human judgement. 
302 | pref = tf.placeholder(tf.float32, shape=(None, 2)) 303 | 304 | # Concatenate trajectory segments so that the first dimension is just 305 | # frames 306 | # (necessary because of conv layer's requirements on input shape) 307 | s1_unrolled = tf.reshape(s1, [-1, 84, 84, 4]) 308 | s2_unrolled = tf.reshape(s2, [-1, 84, 84, 4]) 309 | 310 | # Predict rewards for each frame in the unrolled batch 311 | _r1 = core_network( 312 | s=s1_unrolled, 313 | dropout=dropout, 314 | batchnorm=batchnorm, 315 | reuse=False, 316 | training=training) 317 | _r2 = core_network( 318 | s=s2_unrolled, 319 | dropout=dropout, 320 | batchnorm=batchnorm, 321 | reuse=True, 322 | training=training) 323 | 324 | # Shape should be 'unrolled batch size' 325 | # where 'unrolled batch size' is 'batch size' x 'n frames per segment' 326 | c1 = tf.assert_rank(_r1, 1) 327 | c2 = tf.assert_rank(_r2, 1) 328 | with tf.control_dependencies([c1, c2]): 329 | # Re-roll to 'batch size' x 'n frames per segment' 330 | __r1 = tf.reshape(_r1, tf.shape(s1)[0:2]) 331 | __r2 = tf.reshape(_r2, tf.shape(s2)[0:2]) 332 | # Shape should be 'batch size' x 'n frames per segment' 333 | c1 = tf.assert_rank(__r1, 2) 334 | c2 = tf.assert_rank(__r2, 2) 335 | with tf.control_dependencies([c1, c2]): 336 | r1 = __r1 337 | r2 = __r2 338 | 339 | # Sum rewards over all frames in each segment 340 | _rs1 = tf.reduce_sum(r1, axis=1) 341 | _rs2 = tf.reduce_sum(r2, axis=1) 342 | # Shape should be 'batch size' 343 | c1 = tf.assert_rank(_rs1, 1) 344 | c2 = tf.assert_rank(_rs2, 1) 345 | with tf.control_dependencies([c1, c2]): 346 | rs1 = _rs1 347 | rs2 = _rs2 348 | 349 | # Predict preferences for each segment 350 | _rs = tf.stack([rs1, rs2], axis=1) 351 | # Shape should be 'batch size' x 2 352 | c1 = tf.assert_rank(_rs, 2) 353 | with tf.control_dependencies([c1]): 354 | rs = _rs 355 | _pred = tf.nn.softmax(rs) 356 | # Shape should be 'batch_size' x 2 357 | c1 = tf.assert_rank(_pred, 2) 358 | with tf.control_dependencies([c1]): 359 | pred = _pred 360 | 361 | preds_correct = tf.equal(tf.argmax(pref, 1), tf.argmax(pred, 1)) 362 | accuracy = tf.reduce_mean(tf.cast(preds_correct, tf.float32)) 363 | 364 | _loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=pref, 365 | logits=rs) 366 | # Shape should be 'batch size' 367 | c1 = tf.assert_rank(_loss, 1) 368 | with tf.control_dependencies([c1]): 369 | loss = tf.reduce_sum(_loss) 370 | 371 | # Make sure that batch normalization ops are updated 372 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 373 | 374 | with tf.control_dependencies(update_ops): 375 | train = tf.train.AdamOptimizer(learning_rate=lr).minimize(loss) 376 | 377 | # Inputs 378 | self.training = training 379 | self.s1 = s1 380 | self.s2 = s2 381 | self.pref = pref 382 | 383 | # Outputs 384 | self.r1 = r1 385 | self.r2 = r2 386 | self.rs1 = rs1 387 | self.rs2 = rs2 388 | self.pred = pred 389 | 390 | self.accuracy = accuracy 391 | self.loss = loss 392 | self.train = train 393 | -------------------------------------------------------------------------------- /reward_predictor_core_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core network which predicts rewards from frames, 3 | for gym-moving-dot and Atari games. 4 | """ 5 | 6 | import tensorflow as tf 7 | 8 | from nn_layers import dense_layer, conv_layer 9 | 10 | 11 | def get_dot_position(s): 12 | """ 13 | Estimate the position of the dot in the gym-moving-dot environment. 
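The estimate below sums the most recent frame along each axis and takes the argmax, which assumes the dot is the brightest feature in an otherwise mostly uniform frame.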
14 | """ 15 | # s is (?, 84, 84, 4) 16 | s = s[..., -1] # select last frame; now (?, 84, 84) 17 | 18 | x = tf.reduce_sum(s, axis=1) # now (?, 84) 19 | x = tf.argmax(x, axis=1) 20 | 21 | y = tf.reduce_sum(s, axis=2) 22 | y = tf.argmax(y, axis=1) 23 | 24 | return x, y 25 | 26 | 27 | def net_moving_dot_features(s, batchnorm, dropout, training, reuse): 28 | # Action taken at each time step is encoded in the observations by a2c.py. 29 | a = s[:, 0, 0, -1] 30 | a = tf.cast(a, tf.float32) / 4.0 31 | 32 | xc, yc = get_dot_position(s) 33 | xc = tf.cast(xc, tf.float32) / 83.0 34 | yc = tf.cast(yc, tf.float32) / 83.0 35 | 36 | features = [a, xc, yc] 37 | x = tf.stack(features, axis=1) 38 | 39 | x = dense_layer(x, 64, "d1", reuse, activation='relu') 40 | x = dense_layer(x, 64, "d2", reuse, activation='relu') 41 | x = dense_layer(x, 64, "d3", reuse, activation='relu') 42 | x = dense_layer(x, 1, "d4", reuse, activation=None) 43 | x = x[:, 0] 44 | 45 | return x 46 | 47 | 48 | def net_cnn(s, batchnorm, dropout, training, reuse): 49 | x = s / 255.0 50 | # Page 15: (Atari) 51 | # "[The] input is fed through 4 convolutional layers of size 7x7, 5x5, 3x3, 52 | # and 3x3 with strides 3, 2, 1, 1, each having 16 filters, with leaky ReLU 53 | # nonlinearities (α = 0.01). This is followed by a fully connected layer of 54 | # size 64 and then a scalar output. All convolutional layers use batch norm 55 | # and dropout with α = 0.5 to prevent predictor overfitting" 56 | x = conv_layer(x, 16, 7, 3, batchnorm, training, "c1", reuse, 'relu') 57 | x = tf.layers.dropout(x, dropout, training=training) 58 | x = conv_layer(x, 16, 5, 2, batchnorm, training, "c2", reuse, 'relu') 59 | x = tf.layers.dropout(x, dropout, training=training) 60 | x = conv_layer(x, 16, 3, 1, batchnorm, training, "c3", reuse, 'relu') 61 | x = tf.layers.dropout(x, dropout, training=training) 62 | x = conv_layer(x, 16, 3, 1, batchnorm, training, "c4", reuse, 'relu') 63 | 64 | w, h, c = x.get_shape()[1:] 65 | x = tf.reshape(x, [-1, int(w * h * c)]) 66 | 67 | x = dense_layer(x, 64, "d1", reuse, activation='relu') 68 | x = dense_layer(x, 1, "d2", reuse, activation=None) 69 | x = x[:, 0] 70 | 71 | return x 72 | -------------------------------------------------------------------------------- /reward_predictor_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import unittest 3 | 4 | import tensorflow as tf 5 | import termcolor 6 | import numpy as np 7 | from numpy import exp, log 8 | from numpy.testing import (assert_allclose, assert_approx_equal, 9 | assert_array_equal, assert_raises) 10 | 11 | from reward_predictor import RewardPredictorNetwork 12 | from reward_predictor_core_network import net_cnn 13 | 14 | 15 | class TestRewardPredictor(unittest.TestCase): 16 | 17 | def setUp(self): 18 | self.create_reward_predictor(dropout=0.5, batchnorm=True) 19 | termcolor.cprint(self._testMethodName, 'red') 20 | 21 | def create_reward_predictor(self, dropout, batchnorm): 22 | tf.reset_default_graph() 23 | self.sess = tf.Session() 24 | self.rpn = RewardPredictorNetwork(batchnorm=batchnorm, dropout=dropout, 25 | lr=1e-3, 26 | core_network=net_cnn) 27 | self.sess.run(tf.global_variables_initializer()) 28 | 29 | def test_weight_sharing(self): 30 | """ 31 | Check that both legs of the network give the same reward output 32 | for the same segment input. 33 | """ 34 | 35 | # We deliberately /don't/ use the same dropout for each leg of the 36 | # network. 
(If we do use the same dropout, without batchnorm, 37 | # Pong doesn't train successfully. If we use different dropout, Pong 38 | # does train successfully. I haven't tried training Pong with 39 | # batchnorm.) So we disable dropout for this test. 40 | self.create_reward_predictor(dropout=0.0, batchnorm=True) 41 | 42 | s = 255 * np.random.rand(100, 84, 84, 4) 43 | feed_dict_nontraining = { 44 | self.rpn.s1: [s], 45 | self.rpn.s2: [s], 46 | self.rpn.training: True 47 | } 48 | feed_dict_training = { 49 | self.rpn.s1: [s], 50 | self.rpn.s2: [s], 51 | self.rpn.training: False 52 | } 53 | for feed_dict in [feed_dict_nontraining, feed_dict_training]: 54 | for _ in range(3): # to check different dropouts 55 | [rs1], [rs2] = self.sess.run([self.rpn.rs1, self.rpn.rs2], feed_dict) 56 | # Check rs1 != 0.0 57 | assert_raises(AssertionError, assert_array_equal, rs1, 0.0) 58 | assert_allclose(rs1, rs2) 59 | 60 | def test_batchnorm_sharing(self): 61 | """ 62 | Check that batchnorm statistics are the same between the two legs of 63 | the network. 64 | """ 65 | n_frames = 20 66 | s1 = 255 * np.random.normal(loc=1.0, size=(n_frames, 84, 84, 4)) 67 | s2 = 255 * np.random.normal(loc=-1.0, size=(n_frames, 84, 84, 4)) 68 | feed_dict = { 69 | self.rpn.s1: [s1], 70 | self.rpn.s2: [s2], 71 | self.rpn.pref: [[0.0, 1.0]], 72 | self.rpn.training: True} 73 | self.sess.run(self.rpn.train, feed_dict) 74 | 75 | feed_dict = {self.rpn.s1: [s1], self.rpn.s2: [s1], self.rpn.training: False} 76 | [rs1], [rs2] = self.sess.run([self.rpn.rs1, self.rpn.rs2], feed_dict) 77 | # Check rs1 != 0.0 78 | assert_raises(AssertionError, assert_array_equal, rs1, 0.0) 79 | assert_allclose(rs1, rs2) 80 | 81 | def test_loss(self): 82 | """ 83 | Check that the loss is calculated correctly. 84 | """ 85 | # hack to ensure numerical stability 86 | rs1 = rs2 = 100 87 | n_frames = 20 88 | while rs1 > 50 or rs2 > 50: 89 | s1 = 255 * np.random.normal(loc=1.0, size=(n_frames, 84, 84, 4)) 90 | s2 = 255 * np.random.normal(loc=-1.0, size=(n_frames, 84, 84, 4)) 91 | feed_dict = { 92 | self.rpn.s1: [s1], 93 | self.rpn.s2: [s2], 94 | self.rpn.training: True 95 | } 96 | [rs1], [rs2] = self.sess.run([self.rpn.rs1, self.rpn.rs2], 97 | feed_dict) 98 | 99 | prefs = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]] 100 | for pref in prefs: 101 | feed_dict[self.rpn.pref] = [pref] 102 | [rs1], [rs2], loss = self.sess.run( 103 | [self.rpn.rs1, self.rpn.rs2, self.rpn.loss], feed_dict) 104 | 105 | p_s1_s2 = exp(rs1) / (exp(rs1) + exp(rs2)) 106 | p_s2_s1 = exp(rs2) / (exp(rs1) + exp(rs2)) 107 | 108 | expected = -(pref[0] * log(p_s1_s2) + pref[1] * log(p_s2_s1)) 109 | assert_approx_equal(loss, expected, significant=3) 110 | 111 | def test_batches(self): 112 | """ 113 | Present a batch of two trajectories and check that we get the same 114 | results as if we'd presented the trajectories individually. 
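(In other words, batching should be purely a performance optimisation: the rewards, predictions and total loss must not depend on which other pairs happen to be in the batch.)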
115 | """ 116 | n_segs = 2 117 | n_frames = 20 118 | prefs = [[0., 1.], [1., 0.]] 119 | s1s = [] 120 | s2s = [] 121 | for _ in range(n_segs): 122 | s1 = 255 * np.random.normal(loc=1.0, size=(n_frames, 84, 84, 4)) 123 | s2 = 255 * np.random.normal(loc=-1.0, size=(n_frames, 84, 84, 4)) 124 | s1s.append(s1) 125 | s2s.append(s2) 126 | 127 | # Step 1: present all trajectories as one big batch 128 | feed_dict = { 129 | self.rpn.s1: s1s, 130 | self.rpn.s2: s2s, 131 | self.rpn.pref: prefs, 132 | self.rpn.training: False 133 | } 134 | rs1_batch, rs2_batch, pred_batch, loss_batch = self.sess.run( 135 | [self.rpn.rs1, self.rpn.rs2, self.rpn.pred, self.rpn.loss], 136 | feed_dict) 137 | 138 | # Step 2: present trajectories individually 139 | rs1_nobatch = [] 140 | rs2_nobatch = [] 141 | pred_nobatch = [] 142 | loss_nobatch = 0 143 | for i in range(n_segs): 144 | feed_dict = { 145 | self.rpn.s1: [s1s[i]], 146 | self.rpn.s2: [s2s[i]], 147 | self.rpn.pref: [prefs[i]], 148 | self.rpn.training: False 149 | } 150 | [rs1], [rs2], [pred], loss = self.sess.run( 151 | [self.rpn.rs1, self.rpn.rs2, self.rpn.pred, self.rpn.loss], 152 | feed_dict) 153 | rs1_nobatch.append(rs1) 154 | rs2_nobatch.append(rs2) 155 | pred_nobatch.append(pred) 156 | loss_nobatch += loss 157 | 158 | # Compare 159 | assert_allclose(rs1_batch, rs1_nobatch, atol=1e-5) 160 | assert_allclose(rs2_batch, rs2_nobatch, atol=1e-5) 161 | assert_allclose(pred_batch, pred_nobatch, atol=1e-5) 162 | assert_approx_equal(loss_batch, loss_nobatch, significant=4) 163 | 164 | def test_training(self): 165 | """ 166 | Present two trajectories with different preferences and see whether 167 | training really does work (whether the reward predicted by the network 168 | matches the preferences after a few loops of running the training 169 | operation). 170 | 171 | Note: because of variations in training, this test does not always pass. 172 | """ 173 | n_frames = 20 174 | s1 = 255 * np.random.normal(loc=1.0, size=(n_frames, 84, 84, 4)) 175 | s2 = 255 * np.random.normal(loc=-1.0, size=(n_frames, 84, 84, 4)) 176 | 177 | feed_dict = { 178 | self.rpn.s1: [s1], 179 | self.rpn.s2: [s2] 180 | } 181 | 182 | prefs = [[0.0, 1.0], [1.0, 0.0], [0.5, 0.5]] 183 | for pref in prefs: 184 | print("Preference", pref) 185 | feed_dict[self.rpn.pref] = [pref] 186 | # Important to reset batch normalization statistics 187 | self.sess.run(tf.global_variables_initializer()) 188 | for _ in range(150): 189 | feed_dict[self.rpn.training] = True 190 | self.sess.run(self.rpn.train, feed_dict) 191 | # Uncomment these for more thorough manual testing. 192 | # (For the first case, rs1 should become higher 193 | # than rs2, and the distance between them should increase; 194 | # for the second case, rs2 should become higher; 195 | # for the third case, they should become approximately the 196 | # same.) 
197 | """ 198 | feed_dict[self.rpn.training] = False 199 | ops = [self.rpn.rs1, self.rpn.rs2, self.rpn.loss] 200 | [rs1], [rs2], loss = self.sess.run(ops, feed_dict) 201 | print(" ".join(3 * ["{:>8.3f}"]).format(rs1, rs2, loss)) 202 | print() 203 | """ 204 | 205 | feed_dict[self.rpn.training] = False 206 | [rs1], [rs2] = self.sess.run([self.rpn.rs1, self.rpn.rs2], feed_dict) 207 | 208 | if pref[0] > pref[1]: 209 | self.assertGreater(rs1 - rs2, 10) 210 | elif pref[1] > pref[0]: 211 | self.assertGreater(rs2 - rs1, 10) 212 | elif pref[0] == pref[1]: 213 | self.assertLess(abs(rs2 - rs1), 2) 214 | 215 | def test_training_batches(self): 216 | """ 217 | Check that after training with a batch of 4 segments, each with their own preferences, 218 | the predicted preference for each of the segments is as expected. 219 | """ 220 | n_frames = 20 221 | s1s = 255 * np.random.normal(loc=1.0, size=(4, n_frames, 84, 84, 4)) 222 | s2s = 255 * np.random.normal(loc=-1.0, size=(4, n_frames, 84, 84, 4)) 223 | prefs = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]] 224 | feed_dict = { 225 | self.rpn.s1: s1s, 226 | self.rpn.s2: s2s, 227 | self.rpn.pref: prefs, 228 | self.rpn.training: True 229 | } 230 | 231 | for i in range(100): 232 | if i % 10 == 0: 233 | print("Training {}/100".format(i)) 234 | self.sess.run(self.rpn.train, feed_dict) 235 | 236 | feed_dict[self.rpn.training] = False 237 | preds = self.sess.run(self.rpn.pred, feed_dict) 238 | assert_allclose(preds[0], [1., 0.], atol=1e-1) 239 | assert_allclose(preds[1], [1., 0.], atol=1e-1) 240 | assert_allclose(preds[2], [0., 1.], atol=1e-1) 241 | assert_allclose(preds[3], [0., 1.], atol=1e-1) 242 | 243 | def test_accuracy(self): 244 | """ 245 | Test accuracy op. 246 | """ 247 | n_frames = 20 248 | batch_n = 16 249 | s1s = 255 * np.random.normal(loc=1.0, size=(batch_n, n_frames, 84, 84, 4)) 250 | s2s = 255 * np.random.normal(loc=-1.0, size=(batch_n, n_frames, 84, 84, 4)) 251 | possible_prefs = [[1.0, 0.0], [0.0, 1.0]] 252 | possible_prefs = np.array(possible_prefs) 253 | prefs = possible_prefs[np.random.choice([0, 1], size=batch_n)] 254 | 255 | feed_dict = { 256 | self.rpn.s1: s1s, 257 | self.rpn.s2: s2s, 258 | self.rpn.pref: prefs, 259 | self.rpn.training: True 260 | } 261 | 262 | # Steer away from chance performance 263 | for _ in range(5): 264 | self.sess.run(self.rpn.train, feed_dict) 265 | 266 | feed_dict[self.rpn.training] = False 267 | preds = self.sess.run(self.rpn.pred, feed_dict) 268 | n_correct = 0 269 | for pref, pred in zip(prefs, preds): 270 | if pref[0] == 1.0 and pred[0] > pred[1] or \ 271 | pref[1] == 1.0 and pred[1] > pred[0]: 272 | n_correct += 1 273 | accuracy_expected = n_correct / batch_n 274 | 275 | accuracy_actual = self.sess.run(self.rpn.accuracy, feed_dict) 276 | 277 | assert_approx_equal(accuracy_actual, accuracy_expected) 278 | 279 | 280 | if __name__ == '__main__': 281 | unittest.main() 282 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import logging 4 | import os 5 | from os import path as osp 6 | import sys 7 | import time 8 | from multiprocessing import Process, Queue 9 | 10 | import cloudpickle 11 | import easy_tf_log 12 | from a2c import logger 13 | from a2c.a2c.a2c import learn 14 | from a2c.a2c.policies import CnnPolicy, MlpPolicy 15 | from a2c.common import set_global_seeds 16 | from a2c.common.vec_env.subproc_vec_env import SubprocVecEnv 17 | from params import 
parse_args, PREFS_VAL_FRACTION 18 | from pref_db import PrefDB, PrefBuffer 19 | from pref_interface import PrefInterface 20 | from reward_predictor import RewardPredictorEnsemble 21 | from reward_predictor_core_network import net_cnn, net_moving_dot_features 22 | from utils import VideoRenderer, get_port_range, make_env 23 | 24 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' # filter out INFO messages 25 | 26 | 27 | def main(): 28 | general_params, a2c_params, \ 29 | pref_interface_params, rew_pred_training_params = parse_args() 30 | 31 | if general_params['debug']: 32 | logging.getLogger().setLevel(logging.DEBUG) 33 | 34 | run(general_params, 35 | a2c_params, 36 | pref_interface_params, 37 | rew_pred_training_params) 38 | 39 | 40 | def run(general_params, 41 | a2c_params, 42 | pref_interface_params, 43 | rew_pred_training_params): 44 | seg_pipe = Queue(maxsize=1) 45 | pref_pipe = Queue(maxsize=1) 46 | start_policy_training_flag = Queue(maxsize=1) 47 | 48 | if general_params['render_episodes']: 49 | episode_vid_queue, episode_renderer = start_episode_renderer() 50 | else: 51 | episode_vid_queue = episode_renderer = None 52 | 53 | if a2c_params['env_id'] in ['MovingDot-v0', 'MovingDotNoFrameskip-v0']: 54 | reward_predictor_network = net_moving_dot_features 55 | elif a2c_params['env_id'] in ['PongNoFrameskip-v4', 'EnduroNoFrameskip-v4']: 56 | reward_predictor_network = net_cnn 57 | else: 58 | raise Exception("Unsure about reward predictor network for {}".format( 59 | a2c_params['env_id'])) 60 | 61 | def make_reward_predictor(name, cluster_dict): 62 | return RewardPredictorEnsemble( 63 | cluster_job_name=name, 64 | cluster_dict=cluster_dict, 65 | log_dir=general_params['log_dir'], 66 | batchnorm=rew_pred_training_params['batchnorm'], 67 | dropout=rew_pred_training_params['dropout'], 68 | lr=rew_pred_training_params['lr'], 69 | core_network=reward_predictor_network) 70 | 71 | save_make_reward_predictor(general_params['log_dir'], 72 | make_reward_predictor) 73 | 74 | if general_params['mode'] == 'gather_initial_prefs': 75 | env, a2c_proc = start_policy_training( 76 | cluster_dict=None, 77 | make_reward_predictor=None, 78 | gen_segments=True, 79 | start_policy_training_pipe=start_policy_training_flag, 80 | seg_pipe=seg_pipe, 81 | episode_vid_queue=episode_vid_queue, 82 | log_dir=general_params['log_dir'], 83 | a2c_params=a2c_params) 84 | pi, pi_proc = start_pref_interface( 85 | seg_pipe=seg_pipe, 86 | pref_pipe=pref_pipe, 87 | log_dir=general_params['log_dir'], 88 | **pref_interface_params) 89 | 90 | n_train = general_params['max_prefs'] * (1 - PREFS_VAL_FRACTION) 91 | n_val = general_params['max_prefs'] * PREFS_VAL_FRACTION 92 | pref_db_train = PrefDB(maxlen=n_train) 93 | pref_db_val = PrefDB(maxlen=n_val) 94 | pref_buffer = PrefBuffer(db_train=pref_db_train, db_val=pref_db_val) 95 | pref_buffer.start_recv_thread(pref_pipe) 96 | pref_buffer.wait_until_len(general_params['n_initial_prefs']) 97 | pref_db_train, pref_db_val = pref_buffer.get_dbs() 98 | 99 | save_prefs(general_params['log_dir'], pref_db_train, pref_db_val) 100 | 101 | pi_proc.terminate() 102 | pi.stop_renderer() 103 | a2c_proc.terminate() 104 | pref_buffer.stop_recv_thread() 105 | 106 | env.close() 107 | elif general_params['mode'] == 'pretrain_reward_predictor': 108 | cluster_dict = create_cluster_dict(['ps', 'train']) 109 | ps_proc = start_parameter_server(cluster_dict, make_reward_predictor) 110 | rpt_proc = start_reward_predictor_training( 111 | cluster_dict=cluster_dict, 112 | make_reward_predictor=make_reward_predictor, 113 | 
just_pretrain=True, 114 | pref_pipe=pref_pipe, 115 | start_policy_training_pipe=start_policy_training_flag, 116 | max_prefs=general_params['max_prefs'], 117 | prefs_dir=general_params['prefs_dir'], 118 | load_ckpt_dir=None, 119 | n_initial_prefs=general_params['n_initial_prefs'], 120 | n_initial_epochs=rew_pred_training_params['n_initial_epochs'], 121 | val_interval=rew_pred_training_params['val_interval'], 122 | ckpt_interval=rew_pred_training_params['ckpt_interval'], 123 | log_dir=general_params['log_dir']) 124 | rpt_proc.join() 125 | ps_proc.terminate() 126 | elif general_params['mode'] == 'train_policy_with_original_rewards': 127 | env, a2c_proc = start_policy_training( 128 | cluster_dict=None, 129 | make_reward_predictor=None, 130 | gen_segments=False, 131 | start_policy_training_pipe=start_policy_training_flag, 132 | seg_pipe=seg_pipe, 133 | episode_vid_queue=episode_vid_queue, 134 | log_dir=general_params['log_dir'], 135 | a2c_params=a2c_params) 136 | start_policy_training_flag.put(True) 137 | a2c_proc.join() 138 | env.close() 139 | elif general_params['mode'] == 'train_policy_with_preferences': 140 | cluster_dict = create_cluster_dict(['ps', 'a2c', 'train']) 141 | ps_proc = start_parameter_server(cluster_dict, make_reward_predictor) 142 | env, a2c_proc = start_policy_training( 143 | cluster_dict=cluster_dict, 144 | make_reward_predictor=make_reward_predictor, 145 | gen_segments=True, 146 | start_policy_training_pipe=start_policy_training_flag, 147 | seg_pipe=seg_pipe, 148 | episode_vid_queue=episode_vid_queue, 149 | log_dir=general_params['log_dir'], 150 | a2c_params=a2c_params) 151 | pi, pi_proc = start_pref_interface( 152 | seg_pipe=seg_pipe, 153 | pref_pipe=pref_pipe, 154 | log_dir=general_params['log_dir'], 155 | **pref_interface_params) 156 | rpt_proc = start_reward_predictor_training( 157 | cluster_dict=cluster_dict, 158 | make_reward_predictor=make_reward_predictor, 159 | just_pretrain=False, 160 | pref_pipe=pref_pipe, 161 | start_policy_training_pipe=start_policy_training_flag, 162 | max_prefs=general_params['max_prefs'], 163 | prefs_dir=general_params['prefs_dir'], 164 | load_ckpt_dir=rew_pred_training_params['load_ckpt_dir'], 165 | n_initial_prefs=general_params['n_initial_prefs'], 166 | n_initial_epochs=rew_pred_training_params['n_initial_epochs'], 167 | val_interval=rew_pred_training_params['val_interval'], 168 | ckpt_interval=rew_pred_training_params['ckpt_interval'], 169 | log_dir=general_params['log_dir']) 170 | # We wait for A2C to complete the specified number of policy training 171 | # steps 172 | a2c_proc.join() 173 | rpt_proc.terminate() 174 | pi_proc.terminate() 175 | pi.stop_renderer() 176 | ps_proc.terminate() 177 | env.close() 178 | else: 179 | raise Exception("Unknown mode: {}".format(general_params['mode'])) 180 | 181 | if episode_renderer: 182 | episode_renderer.stop() 183 | 184 | 185 | def save_prefs(log_dir, pref_db_train, pref_db_val): 186 | train_path = osp.join(log_dir, 'train.pkl.gz') 187 | pref_db_train.save(train_path) 188 | print("Saved training preferences to '{}'".format(train_path)) 189 | val_path = osp.join(log_dir, 'val.pkl.gz') 190 | pref_db_val.save(val_path) 191 | print("Saved validation preferences to '{}'".format(val_path)) 192 | 193 | 194 | def save_make_reward_predictor(log_dir, make_reward_predictor): 195 | save_dir = osp.join(log_dir, 'reward_predictor_checkpoints') 196 | os.makedirs(save_dir, exist_ok=True) 197 | with open(osp.join(save_dir, 'make_reward_predictor.pkl'), 'wb') as fh: 198 | 
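# cloudpickle rather than the standard pickle module, because make_reward_predictor is a closure over run-time arguments, which plain pickle can't serialize.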
fh.write(cloudpickle.dumps(make_reward_predictor)) 199 | 200 | 201 | def create_cluster_dict(jobs): 202 | n_ports = len(jobs) + 1 203 | ports = get_port_range(start_port=2200, 204 | n_ports=n_ports, 205 | random_stagger=True) 206 | cluster_dict = {} 207 | for part, port in zip(jobs, ports): 208 | cluster_dict[part] = ['localhost:{}'.format(port)] 209 | return cluster_dict 210 | 211 | 212 | def configure_a2c_logger(log_dir): 213 | a2c_dir = osp.join(log_dir, 'a2c') 214 | os.makedirs(a2c_dir) 215 | tb = logger.TensorBoardOutputFormat(a2c_dir) 216 | logger.Logger.CURRENT = logger.Logger(dir=a2c_dir, output_formats=[tb]) 217 | 218 | 219 | def make_envs(env_id, n_envs, seed): 220 | def wrap_make_env(env_id, rank): 221 | def _thunk(): 222 | return make_env(env_id, seed + rank) 223 | return _thunk 224 | set_global_seeds(seed) 225 | env = SubprocVecEnv(env_id, [wrap_make_env(env_id, i) 226 | for i in range(n_envs)]) 227 | return env 228 | 229 | 230 | def start_parameter_server(cluster_dict, make_reward_predictor): 231 | def f(): 232 | make_reward_predictor('ps', cluster_dict) 233 | while True: 234 | time.sleep(1.0) 235 | 236 | proc = Process(target=f, daemon=True) 237 | proc.start() 238 | return proc 239 | 240 | 241 | def start_policy_training(cluster_dict, make_reward_predictor, gen_segments, 242 | start_policy_training_pipe, seg_pipe, 243 | episode_vid_queue, log_dir, a2c_params): 244 | env_id = a2c_params['env_id'] 245 | if env_id in ['MovingDotNoFrameskip-v0', 'MovingDot-v0']: 246 | policy_fn = MlpPolicy 247 | elif env_id in ['PongNoFrameskip-v4', 'EnduroNoFrameskip-v4']: 248 | policy_fn = CnnPolicy 249 | else: 250 | msg = "Unsure about policy network for {}".format(a2c_params['env_id']) 251 | raise Exception(msg) 252 | 253 | configure_a2c_logger(log_dir) 254 | 255 | # Done here because daemonic processes can't have children 256 | env = make_envs(a2c_params['env_id'], 257 | a2c_params['n_envs'], 258 | a2c_params['seed']) 259 | del a2c_params['env_id'], a2c_params['n_envs'] 260 | 261 | ckpt_dir = osp.join(log_dir, 'policy_checkpoints') 262 | os.makedirs(ckpt_dir) 263 | 264 | def f(): 265 | if make_reward_predictor: 266 | reward_predictor = make_reward_predictor('a2c', cluster_dict) 267 | else: 268 | reward_predictor = None 269 | misc_logs_dir = osp.join(log_dir, 'a2c_misc') 270 | easy_tf_log.set_dir(misc_logs_dir) 271 | learn( 272 | policy=policy_fn, 273 | env=env, 274 | seg_pipe=seg_pipe, 275 | start_policy_training_pipe=start_policy_training_pipe, 276 | episode_vid_queue=episode_vid_queue, 277 | reward_predictor=reward_predictor, 278 | ckpt_save_dir=ckpt_dir, 279 | gen_segments=gen_segments, 280 | **a2c_params) 281 | 282 | proc = Process(target=f, daemon=True) 283 | proc.start() 284 | return env, proc 285 | 286 | 287 | def start_pref_interface(seg_pipe, pref_pipe, max_segs, synthetic_prefs, 288 | log_dir): 289 | def f(): 290 | # The preference interface needs to get input from stdin. stdin is 291 | # automatically closed at the beginning of child processes in Python, 292 | # so this is a bit of a hack, but it seems to be fine. 
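# (The intent of os.fdopen(0) is to re-wrap file descriptor 0 in a fresh file object so that the child process can read keyboard input again.)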
293 |         sys.stdin = os.fdopen(0)
294 |         pi.run(seg_pipe=seg_pipe, pref_pipe=pref_pipe)
295 | 
296 |     # Needs to be done in the main process because it does GUI setup work
297 |     prefs_log_dir = osp.join(log_dir, 'pref_interface')
298 |     pi = PrefInterface(synthetic_prefs=synthetic_prefs,
299 |                        max_segs=max_segs,
300 |                        log_dir=prefs_log_dir)
301 |     proc = Process(target=f, daemon=True)
302 |     proc.start()
303 |     return pi, proc
304 | 
305 | 
306 | def start_reward_predictor_training(cluster_dict,
307 |                                     make_reward_predictor,
308 |                                     just_pretrain,
309 |                                     pref_pipe,
310 |                                     start_policy_training_pipe,
311 |                                     max_prefs,
312 |                                     n_initial_prefs,
313 |                                     n_initial_epochs,
314 |                                     prefs_dir,
315 |                                     load_ckpt_dir,
316 |                                     val_interval,
317 |                                     ckpt_interval,
318 |                                     log_dir):
319 |     def f():
320 |         rew_pred = make_reward_predictor('train', cluster_dict)
321 |         rew_pred.init_network(load_ckpt_dir)
322 | 
323 |         if prefs_dir is not None:
324 |             train_path = osp.join(prefs_dir, 'train.pkl.gz')
325 |             pref_db_train = PrefDB.load(train_path)
326 |             print("Loaded training preferences from '{}'".format(train_path))
327 |             n_prefs, n_segs = len(pref_db_train), len(pref_db_train.segments)
328 |             print("({} preferences, {} segments)".format(n_prefs, n_segs))
329 | 
330 |             val_path = osp.join(prefs_dir, 'val.pkl.gz')
331 |             pref_db_val = PrefDB.load(val_path)
332 |             print("Loaded validation preferences from '{}'".format(val_path))
333 |             n_prefs, n_segs = len(pref_db_val), len(pref_db_val.segments)
334 |             print("({} preferences, {} segments)".format(n_prefs, n_segs))
335 |         else:
336 |             n_train = max_prefs * (1 - PREFS_VAL_FRACTION)
337 |             n_val = max_prefs * PREFS_VAL_FRACTION
338 |             pref_db_train = PrefDB(maxlen=n_train)
339 |             pref_db_val = PrefDB(maxlen=n_val)
340 | 
341 |         pref_buffer = PrefBuffer(db_train=pref_db_train,
342 |                                  db_val=pref_db_val)
343 |         pref_buffer.start_recv_thread(pref_pipe)
344 |         if prefs_dir is None:
345 |             pref_buffer.wait_until_len(n_initial_prefs)
346 | 
347 |         save_prefs(log_dir, pref_db_train, pref_db_val)
348 | 
349 |         if not load_ckpt_dir:
350 |             print("Pretraining reward predictor for {} epochs".format(
351 |                 n_initial_epochs))
352 |             pref_db_train, pref_db_val = pref_buffer.get_dbs()
353 |             for i in range(n_initial_epochs):
354 |                 # Note that we deliberately don't update the preference
355 |                 # databases during pretraining, to keep the number of
356 |                 # preferences fairly small so that pretraining doesn't take
357 |                 # too long.
358 |                 print("Reward predictor training epoch {}".format(i))
359 |                 rew_pred.train(pref_db_train, pref_db_val, val_interval)
360 |                 if i and i % ckpt_interval == 0:
361 |                     rew_pred.save()
362 |             print("Reward predictor pretraining done")
363 |             rew_pred.save()
364 | 
365 |         if just_pretrain:
366 |             return
367 | 
368 |         start_policy_training_pipe.put(True)
369 | 
370 |         i = 0
371 |         while True:
372 |             pref_db_train, pref_db_val = pref_buffer.get_dbs()
373 |             save_prefs(log_dir, pref_db_train, pref_db_val)
374 |             rew_pred.train(pref_db_train, pref_db_val, val_interval)
375 |             if i and i % ckpt_interval == 0:
376 |                 rew_pred.save()
377 |             i += 1
378 |     proc = Process(target=f, daemon=True)
379 |     proc.start()
380 |     return proc
381 | 
382 | 
383 | def start_episode_renderer():
384 |     episode_vid_queue = Queue()
385 |     renderer = VideoRenderer(
386 |         episode_vid_queue,
387 |         playback_speed=2,
388 |         zoom=2,
389 |         mode=VideoRenderer.play_through_mode)
390 |     return episode_vid_queue, renderer
391 | 
392 | 
393 | if __name__ == '__main__':
394 |     main()
395 | 
--------------------------------------------------------------------------------
/run_checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | """
4 | Run a trained checkpoint to see what the agent is actually doing in the
5 | environment.
6 | """
7 | 
8 | import argparse
9 | import os.path as osp
10 | import time
11 | from collections import deque
12 | 
13 | import cloudpickle
14 | import numpy as np
15 | import tensorflow as tf
16 | 
17 | import matplotlib.pyplot as plt
18 | from matplotlib.ticker import FormatStrFormatter
19 | from utils import make_env
20 | 
21 | 
22 | def main():
23 |     args = parse_args()
24 | 
25 |     env = make_env(args.env)
26 |     model = get_model(args.policy_ckpt_dir)
27 |     if args.reward_predictor_ckpt_dir:
28 |         reward_predictor = get_reward_predictor(args.reward_predictor_ckpt_dir)
29 |     else:
30 |         reward_predictor = None
31 | 
32 |     run_agent(env, model, reward_predictor, args.frame_interval_ms)
33 | 
34 | 
35 | def run_agent(env, model, reward_predictor, frame_interval_ms):
36 |     nenvs = 1
37 |     nstack = int(model.step_model.X.shape[-1])
38 |     nh, nw, nc = env.observation_space.shape
39 |     obs = np.zeros((nenvs, nh, nw, nc * nstack), dtype=np.uint8)
40 |     model_nenvs = int(model.step_model.X.shape[0])
41 |     states = model.initial_state
42 |     if reward_predictor:
43 |         value_graph = ValueGraph()
44 |     while True:
45 |         raw_obs = env.reset()
46 |         obs = update_obs(obs, raw_obs, nc)
47 |         episode_reward = 0
48 |         done = False
49 |         while not done:
50 |             model_obs = np.vstack([obs] * model_nenvs)
51 |             actions, _, states = model.step(model_obs, states, [done])
52 |             action = actions[0]
53 |             raw_obs, reward, done, _ = env.step(action)
54 |             obs = update_obs(obs, raw_obs, nc)
55 |             episode_reward += reward
56 |             env.render()
57 |             if reward_predictor is not None:
58 |                 predicted_reward = reward_predictor.reward(obs)
59 |                 # reward_predictor.reward returns reward for each frame in the
60 |                 # supplied batch. We only supplied one frame, so get the reward
61 |                 # for that frame.
62 | value_graph.append(predicted_reward[0]) 63 | time.sleep(frame_interval_ms * 1e-3) 64 | print("Episode reward:", episode_reward) 65 | 66 | 67 | def update_obs(obs, raw_obs, nc): 68 | obs = np.roll(obs, shift=-nc, axis=3) 69 | obs[:, :, :, -nc:] = raw_obs 70 | return obs 71 | 72 | 73 | def get_reward_predictor(ckpt_dir): 74 | with open(osp.join(ckpt_dir, 'make_reward_predictor.pkl'), 'rb') as fh: 75 | make_reward_predictor = cloudpickle.loads(fh.read()) 76 | cluster_dict = {'a2c': ['localhost:2200']} 77 | print("Initialising reward predictor...") 78 | reward_predictor = make_reward_predictor(name='a2c', cluster_dict=cluster_dict) 79 | reward_predictor.init_network(ckpt_dir) 80 | return reward_predictor 81 | 82 | 83 | def get_model(ckpt_dir): 84 | model_file = osp.join(ckpt_dir, 'make_model.pkl') 85 | with open(model_file, 'rb') as fh: 86 | make_model = cloudpickle.loads(fh.read()) 87 | print("Initialising policy...") 88 | model = make_model() 89 | ckpt_file = tf.train.latest_checkpoint(ckpt_dir) 90 | print("Loading checkpoint...") 91 | model.load(ckpt_file) 92 | return model 93 | 94 | 95 | def parse_args(): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("env") 98 | parser.add_argument("policy_ckpt_dir") 99 | parser.add_argument("--reward_predictor_ckpt_dir") 100 | parser.add_argument("--frame_interval_ms", type=float, default=0.) 101 | args = parser.parse_args() 102 | return args 103 | 104 | 105 | class ValueGraph: 106 | def __init__(self): 107 | n_values = 100 108 | self.data = deque(maxlen=n_values) 109 | 110 | self.fig, self.ax = plt.subplots() 111 | self.ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f')) 112 | self.fig.set_size_inches(4, 2) 113 | self.ax.set_xlim([0, n_values - 1]) 114 | self.ax.grid(axis='y') # Draw a line at 0 reward 115 | self.y_min = float('inf') 116 | self.y_max = -float('inf') 117 | self.line, = self.ax.plot([], []) 118 | 119 | self.fig.show() 120 | self.fig.canvas.draw() 121 | 122 | def append(self, value): 123 | self.data.append(value) 124 | 125 | self.y_min = min(self.y_min, min(self.data)) 126 | self.y_max = max(self.y_max, max(self.data)) 127 | self.ax.set_ylim([self.y_min, self.y_max]) 128 | self.ax.set_yticks([self.y_min, 0, self.y_max]) 129 | plt.tight_layout() 130 | 131 | ydata = list(self.data) 132 | xdata = list(range(len(self.data))) 133 | self.line.set_data(xdata, ydata) 134 | 135 | self.ax.draw_artist(self.line) 136 | self.fig.canvas.draw() 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /run_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | Simple tests to make sure each of the main commands basically run fine without 5 | any errors. 
6 | """ 7 | 8 | import subprocess 9 | import tempfile 10 | import unittest 11 | from os.path import exists, join 12 | 13 | import termcolor 14 | 15 | 16 | def create_initial_prefs(out_dir, synthetic_prefs): 17 | cmd = ("python3 run.py gather_initial_prefs " 18 | "PongNoFrameskip-v4 " 19 | "--n_initial_prefs 1 " 20 | "--log_dir {}".format(out_dir)) 21 | if synthetic_prefs: 22 | cmd += " --synthetic_prefs" 23 | subprocess.call(cmd.split(' ')) 24 | 25 | 26 | class TestRun(unittest.TestCase): 27 | 28 | def setUp(self): 29 | termcolor.cprint(self._testMethodName, 'red') 30 | 31 | def test_end_to_end(self): 32 | with tempfile.TemporaryDirectory() as temp_dir: 33 | cmd = ("python3 run.py train_policy_with_preferences " 34 | "PongNoFrameskip-v4 " 35 | "--synthetic_prefs " 36 | "--million_timesteps 0.0001 " 37 | "--n_initial_prefs 1 " 38 | "--n_initial_epochs 1 " 39 | "--log_dir {0}".format(temp_dir)) 40 | subprocess.call(cmd.split(' ')) 41 | self.assertTrue(exists(join(temp_dir, 42 | 'policy_checkpoints', 43 | 'policy.ckpt-20.index'))) 44 | self.assertTrue(exists(join(temp_dir, 45 | 'reward_predictor_checkpoints', 46 | 'make_reward_predictor.pkl'))) 47 | 48 | def test_gather_prefs(self): 49 | for synthetic_prefs in [True, False]: 50 | if synthetic_prefs: 51 | termcolor.cprint('Synthetic preferences', 'green') 52 | else: 53 | termcolor.cprint('Human preferences', 'green') 54 | # Automatically deletes the directory afterwards :) 55 | with tempfile.TemporaryDirectory() as temp_dir: 56 | create_initial_prefs(temp_dir, synthetic_prefs) 57 | self.assertTrue(exists(join(temp_dir, 'train.pkl.gz'))) 58 | self.assertTrue(exists(join(temp_dir, 'val.pkl.gz'))) 59 | 60 | def test_pretrain_reward_predictor(self): 61 | with tempfile.TemporaryDirectory() as temp_dir: 62 | create_initial_prefs(temp_dir, synthetic_prefs=True) 63 | cmd = ("python3 run.py pretrain_reward_predictor " 64 | "PongNoFrameskip-v4 " 65 | "--n_initial_epochs 1 " 66 | "--load_prefs_dir {0} " 67 | "--log_dir {0}".format(temp_dir)) 68 | subprocess.call(cmd.split(' ')) 69 | self.assertTrue(exists(join(temp_dir, 70 | 'reward_predictor_checkpoints', 71 | 'reward_predictor.ckpt-1.index'))) 72 | 73 | 74 | if __name__ == '__main__': 75 | unittest.main() 76 | -------------------------------------------------------------------------------- /show_prefs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Display examples of the specified preference database 4 | (with the less-preferred segment on the left, 5 | and the more-preferred segment on the right) 6 | (skipping over equally-preferred segments) 7 | """ 8 | 9 | import argparse 10 | import gzip 11 | import pickle 12 | from multiprocessing import Queue 13 | 14 | import numpy as np 15 | 16 | from utils import VideoRenderer 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("prefs", help=".pkl.gz file") 22 | args = parser.parse_args() 23 | 24 | with gzip.open(args.prefs, 'rb') as pkl_file: 25 | print("Loading preferences from '{}'...".format(args.prefs), end="") 26 | prefs = pickle.load(pkl_file) 27 | print("done!") 28 | 29 | print("{} preferences found".format(len(prefs))) 30 | print("Preferred segment on the right") 31 | 32 | q = Queue() 33 | VideoRenderer(q, zoom=2, mode=VideoRenderer.restart_on_get_mode) 34 | 35 | for k1, k2, pref in prefs.prefs: 36 | if pref == (0.5, 0.5): 37 | continue 38 | 39 | if pref == (0.0, 1.0): 40 | s1 = np.array(prefs.segments[k1]) 41 | s2 = 
np.array(prefs.segments[k2]) 42 | elif pref == (1.0, 0.0): 43 | s1 = np.array(prefs.segments[k2]) 44 | s2 = np.array(prefs.segments[k1]) 45 | else: 46 | raise Exception("Unexpected preference", pref) 47 | 48 | vid = [] 49 | border = np.ones((84, 10), dtype=np.uint8) * 128 50 | for t in range(len(s1)): 51 | # -1 => select the last frame in the 4-frame stack 52 | f1 = s1[t, :, :, -1] 53 | f2 = s2[t, :, :, -1] 54 | frame = np.hstack((f1, border, f2)) 55 | vid.append(frame) 56 | n_pause_frames = 10 57 | for _ in range(n_pause_frames): 58 | vid.append(np.copy(vid[-1])) 59 | q.put(vid) 60 | input() 61 | 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import random 3 | import socket 4 | import time 5 | from multiprocessing import Process 6 | 7 | import gym 8 | import numpy as np 9 | import pyglet 10 | 11 | from a2c.common.atari_wrappers import wrap_deepmind 12 | from scipy.ndimage import zoom 13 | 14 | 15 | # https://github.com/joschu/modular_rl/blob/master/modular_rl/running_stat.py 16 | # http://www.johndcook.com/blog/standard_deviation/ 17 | class RunningStat(object): 18 | def __init__(self, shape=()): 19 | self._n = 0 20 | self._M = np.zeros(shape) 21 | self._S = np.zeros(shape) 22 | 23 | def push(self, x): 24 | x = np.asarray(x) 25 | assert x.shape == self._M.shape 26 | self._n += 1 27 | if self._n == 1: 28 | self._M[...] = x 29 | else: 30 | oldM = self._M.copy() 31 | self._M[...] = oldM + (x - oldM)/self._n 32 | self._S[...] = self._S + (x - oldM)*(x - self._M) 33 | 34 | @property 35 | def n(self): 36 | return self._n 37 | 38 | @property 39 | def mean(self): 40 | return self._M 41 | 42 | @property 43 | def var(self): 44 | if self._n >= 2: 45 | return self._S/(self._n - 1) 46 | else: 47 | return np.square(self._M) 48 | 49 | @property 50 | def std(self): 51 | return np.sqrt(self.var) 52 | 53 | @property 54 | def shape(self): 55 | return self._M.shape 56 | 57 | 58 | # Based on SimpleImageViewer in OpenAI gym 59 | class Im(object): 60 | def __init__(self, display=None): 61 | self.window = None 62 | self.isopen = False 63 | self.display = display 64 | 65 | def imshow(self, arr): 66 | if self.window is None: 67 | height, width = arr.shape 68 | self.window = pyglet.window.Window( 69 | width=width, height=height, display=self.display) 70 | self.width = width 71 | self.height = height 72 | self.isopen = True 73 | 74 | assert arr.shape == (self.height, self.width), \ 75 | "You passed in an image with the wrong number shape" 76 | 77 | image = pyglet.image.ImageData(self.width, self.height, 78 | 'L', arr.tobytes(), pitch=-self.width) 79 | self.window.clear() 80 | self.window.switch_to() 81 | self.window.dispatch_events() 82 | image.blit(0, 0) 83 | self.window.flip() 84 | 85 | def close(self): 86 | if self.isopen: 87 | self.window.close() 88 | self.isopen = False 89 | 90 | def __del__(self): 91 | self.close() 92 | 93 | 94 | class VideoRenderer: 95 | play_through_mode = 0 96 | restart_on_get_mode = 1 97 | 98 | def __init__(self, vid_queue, mode, zoom=1, playback_speed=1): 99 | assert mode == VideoRenderer.restart_on_get_mode or mode == VideoRenderer.play_through_mode 100 | self.mode = mode 101 | self.vid_queue = vid_queue 102 | self.zoom_factor = zoom 103 | self.playback_speed = playback_speed 104 | self.proc = Process(target=self.render) 105 | self.proc.start() 106 | 107 | def stop(self): 108 | 
self.proc.terminate() 109 | 110 | def render(self): 111 | v = Im() 112 | frames = self.vid_queue.get(block=True) 113 | t = 0 114 | while True: 115 | # Add a grey dot on the last line showing position 116 | width = frames[t].shape[1] 117 | fraction_played = t / len(frames) 118 | x = int(fraction_played * width) 119 | frames[t][-1][x] = 128 120 | 121 | zoomed_frame = zoom(frames[t], self.zoom_factor) 122 | v.imshow(zoomed_frame) 123 | 124 | if self.mode == VideoRenderer.play_through_mode: 125 | # Wait until having finished playing the current 126 | # set of frames. Then, stop, and get the most 127 | # recent set of frames. 128 | t += self.playback_speed 129 | if t >= len(frames): 130 | frames = self.get_queue_most_recent() 131 | t = 0 132 | else: 133 | time.sleep(1/60) 134 | elif self.mode == VideoRenderer.restart_on_get_mode: 135 | # Always try and get a new set of frames to show. 136 | # If there is a new set of frames on the queue, 137 | # restart playback with those frames immediately. 138 | # Otherwise, just keep looping with the current frames. 139 | try: 140 | frames = self.vid_queue.get(block=False) 141 | t = 0 142 | except queue.Empty: 143 | t = (t + self.playback_speed) % len(frames) 144 | time.sleep(1/60) 145 | 146 | def get_queue_most_recent(self): 147 | # Make sure we at least get something 148 | item = self.vid_queue.get(block=True) 149 | while True: 150 | try: 151 | item = self.vid_queue.get(block=True, timeout=0.1) 152 | except queue.Empty: 153 | break 154 | return item 155 | 156 | 157 | def get_port_range(start_port, n_ports, random_stagger=False): 158 | # If multiple runs try and call this function at the same time, 159 | # the function could return the same port range. 160 | # To guard against this, automatically offset the port range. 161 | if random_stagger: 162 | start_port += random.randint(0, 20) * n_ports 163 | 164 | free_range_found = False 165 | while not free_range_found: 166 | ports = [] 167 | for port_n in range(n_ports): 168 | port = start_port + port_n 169 | try: 170 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 171 | s.bind(("127.0.0.1", port)) 172 | ports.append(port) 173 | except socket.error as e: 174 | if e.errno == 98 or e.errno == 48: 175 | print("Warning: port {} already in use".format(port)) 176 | break 177 | else: 178 | raise e 179 | finally: 180 | s.close() 181 | if len(ports) < n_ports: 182 | # The last port we tried was in use 183 | # Try again, starting from the next port 184 | start_port = port + 1 185 | else: 186 | free_range_found = True 187 | 188 | return ports 189 | 190 | 191 | def profile_memory(log_path, pid): 192 | import memory_profiler 193 | def profile(): 194 | with open(log_path, 'w') as f: 195 | # timeout=99999 is necessary because for external processes, 196 | # memory_usage otherwise defaults to only returning a single sample 197 | # Note that even with interval=1, because memory_profiler only 198 | # flushes every 50 lines, we still have to wait 50 seconds before 199 | # updates. 
200 | memory_profiler.memory_usage(pid, stream=f, 201 | timeout=99999, interval=1) 202 | p = Process(target=profile, daemon=True) 203 | p.start() 204 | return p 205 | 206 | 207 | def batch_iter(data, batch_size, shuffle=False): 208 | idxs = list(range(len(data))) 209 | if shuffle: 210 | np.random.shuffle(idxs) # in-place 211 | 212 | start_idx = 0 213 | end_idx = 0 214 | while end_idx < len(data): 215 | end_idx = start_idx + batch_size 216 | if end_idx > len(data): 217 | end_idx = len(data) 218 | 219 | batch_idxs = idxs[start_idx:end_idx] 220 | batch = [] 221 | for idx in batch_idxs: 222 | batch.append(data[idx]) 223 | 224 | yield batch 225 | start_idx += batch_size 226 | 227 | 228 | def make_env(env_id, seed=0): 229 | if env_id in ['MovingDot-v0', 'MovingDotNoFrameskip-v0']: 230 | import gym_moving_dot 231 | env = gym.make(env_id) 232 | env.seed(seed) 233 | if env_id == 'EnduroNoFrameskip-v4': 234 | from enduro_wrapper import EnduroWrapper 235 | env = EnduroWrapper(env) 236 | return wrap_deepmind(env) 237 | -------------------------------------------------------------------------------- /utils_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import socket 4 | import unittest 5 | from math import ceil 6 | 7 | import numpy as np 8 | 9 | from utils import RunningStat, batch_iter, get_port_range 10 | 11 | 12 | class TestUtils(unittest.TestCase): 13 | 14 | # https://github.com/joschu/modular_rl/blob/master/modular_rl/running_stat.py 15 | def test_running_stat(self): 16 | for shp in ((), (3, ), (3, 4)): 17 | li = [] 18 | rs = RunningStat(shp) 19 | for i in range(5): 20 | val = np.random.randn(*shp) 21 | rs.push(val) 22 | li.append(val) 23 | m = np.mean(li, axis=0) 24 | assert np.allclose(rs.mean, m) 25 | if i == 0: 26 | continue 27 | # ddof=1 => calculate unbiased sample variance 28 | v = np.var(li, ddof=1, axis=0) 29 | assert np.allclose(rs.var, v) 30 | 31 | def test_get_port_range(self): 32 | # Test 1: if we ask for 3 ports starting from port 60000 33 | # (which nothing should be listening on), we should get back 34 | # 60000, 60001 and 60002 35 | ports = get_port_range(60000, 3) 36 | self.assertEqual(ports, [60000, 60001, 60002]) 37 | 38 | # Test 2: if we set something listening on port 60000 39 | # then ask for the same ports as in test 1, 40 | # the function should skip over 60000 and give us the next 41 | # three ports 42 | s1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 43 | s1.bind(("127.0.0.1", 60000)) 44 | ports = get_port_range(60000, 3) 45 | self.assertEqual(ports, [60001, 60002, 60003]) 46 | 47 | # Test 3: if we set something listening on port 60002, 48 | # the function should realise it can't allocate a continuous 49 | # range starting from 60000 and should give us a range starting 50 | # from 60003 51 | s2 = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 52 | s2.bind(("127.0.0.1", 60002)) 53 | ports = get_port_range(60000, 3) 54 | self.assertEqual(ports, [60003, 60004, 60005]) 55 | 56 | s2.close() 57 | s1.close() 58 | 59 | def test_batch_iter_1(self): 60 | """ 61 | Check that batch_iter gives us exactly the right data back. 
62 | """ 63 | l1 = list(range(16)) 64 | l2 = list(range(15)) 65 | l3 = list(range(13)) 66 | for l in [l1, l2, l3]: 67 | for shuffle in [True, False]: 68 | expected_data = l 69 | actual_data = set() 70 | expected_n_batches = ceil(len(l) / 4) 71 | actual_n_batches = 0 72 | for batch_n, x in enumerate(batch_iter(l, 73 | batch_size=4, 74 | shuffle=shuffle)): 75 | if batch_n == expected_n_batches - 1 and len(l) % 4 != 0: 76 | self.assertEqual(len(x), len(l) % 4) 77 | else: 78 | self.assertEqual(len(x), 4) 79 | self.assertEqual(len(actual_data.intersection(set(x))), 0) 80 | actual_data = actual_data.union(set(x)) 81 | actual_n_batches += 1 82 | self.assertEqual(actual_n_batches, expected_n_batches) 83 | np.testing.assert_array_equal(list(actual_data), expected_data) 84 | 85 | def test_batch_iter_2(self): 86 | """ 87 | Check that shuffle=True returns the same data but in a different order. 88 | """ 89 | expected_data = list(range(16)) 90 | actual_data = [] 91 | for x in batch_iter(expected_data, batch_size=4, shuffle=True): 92 | actual_data.extend(x) 93 | self.assertEqual(len(actual_data), len(expected_data)) 94 | self.assertEqual(set(actual_data), set(expected_data)) 95 | with self.assertRaises(AssertionError): 96 | np.testing.assert_array_equal(actual_data, expected_data) 97 | 98 | def test_batch_iter_3(self): 99 | """ 100 | Check that successive calls shuffle in a different order. 101 | """ 102 | data = list(range(16)) 103 | out1 = [] 104 | for x in batch_iter(data, batch_size=4, shuffle=True): 105 | out1.extend(x) 106 | out2 = [] 107 | for x in batch_iter(data, batch_size=4, shuffle=True): 108 | out2.extend(x) 109 | self.assertEqual(set(out1), set(out2)) 110 | with self.assertRaises(AssertionError): 111 | np.testing.assert_array_equal(out1, out2) 112 | 113 | 114 | if __name__ == '__main__': 115 | unittest.main() 116 | --------------------------------------------------------------------------------