├── .gitignore ├── LICENSE ├── Pipfile ├── Pipfile.lock ├── README.md ├── launch.py ├── lm_human_preferences ├── datasets │ ├── books.py │ ├── cnndm.py │ └── tldr.py ├── label_types.py ├── language │ ├── datasets.py │ ├── encodings.py │ ├── model.py │ ├── sample.py │ ├── test_model.py │ ├── test_sample.py │ └── trained_models.py ├── lm_tasks.py ├── policy.py ├── rewards.py ├── test_train_policy.py ├── test_train_reward.py ├── train_policy.py ├── train_reward.py └── utils │ ├── combos.py │ ├── core.py │ ├── gcs.py │ ├── hyperparams.py │ ├── launch.py │ ├── test_core_utils.py │ └── test_hyperparams.py ├── sample.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .mypy_cache 3 | *.egg-info/ 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Pipfile: -------------------------------------------------------------------------------- 1 | [[source]] 2 | name = "pypi" 3 | url = "https://pypi.org/simple" 4 | verify_ssl = true 5 | 6 | [dev-packages] 7 | 8 | [packages] 9 | cloudpickle = "==1.2.1" 10 | dataclasses = "==0.6.0" 11 | fire = "==0.1.3" 12 | ftfy = "==5.4.1" 13 | google-api-python-client = "==1.7.8" 14 | google-cloud-storage = "==1.13.0" 15 | mpi4py = "==3.0.2" 16 | mypy = "==0.580" 17 | numpy = "==1.16.2" 18 | pytest-instafail = "==0.3.0" 19 | pytest-timeout = "==1.2.0" 20 | pytest = "==3.5.0" 21 | pytz = "==2019.1" 22 | regex = "==2017.4.5" 23 | requests = "==2.18.0" 24 | tqdm = "==4.31.1" 25 | typeguard = ">=2.2.2" 26 | lm-human-preferences = {editable = true,path = "."} 27 | 28 | [requires] 29 | python_version = "3.7" 30 | -------------------------------------------------------------------------------- /Pipfile.lock: -------------------------------------------------------------------------------- 1 | { 2 | "_meta": { 3 | "hash": { 4 | "sha256": "aca3fc5344bba2aa6f9d399ce2323f3f0b72dd912e7f105e4139526101e2607a" 5 | }, 6 | "pipfile-spec": 6, 7 | "requires": { 8 | "python_version": "3.7" 9 | }, 10 | "sources": [ 11 | { 12 | "name": "pypi", 13 | "url": "https://pypi.org/simple", 14 | "verify_ssl": true 15 | } 16 | ] 17 | }, 18 | "default": { 19 | "attrs": { 20 | "hashes": [ 21 | "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79", 22 | "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399" 23 | ], 24 | "version": "==19.1.0" 25 | }, 26 | "cachetools": { 27 | "hashes": [ 28 | "sha256:428266a1c0d36dc5aca63a2d7c5942e88c2c898d72139fca0e97fdd2380517ae", 29 | "sha256:8ea2d3ce97850f31e4a08b0e2b5e6c34997d7216a9d2c98e0f3978630d4da69a" 30 | ], 31 | "version": "==3.1.1" 32 | }, 33 | "certifi": { 34 | "hashes": [ 35 | "sha256:e4f3620cfea4f83eedc95b24abd9cd56f3c4b146dd0177e83a21b4eb49e21e50", 36 | "sha256:fd7c7c74727ddcf00e9acd26bba8da604ffec95bf1c2144e67aff7a8b50e6cef" 37 | ], 38 | "version": "==2019.9.11" 39 | }, 40 | "chardet": { 41 | "hashes": [ 42 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae", 43 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691" 44 | ], 45 | "version": "==3.0.4" 46 | }, 47 | "cloudpickle": { 48 | "hashes": [ 49 | "sha256:603244e0f552b72a267d47a7d9b347b27a3430f58a0536037a290e7e0e212ecf", 50 | "sha256:b8ba7e322f2394b9bbbdc1c976e6442c2c02acc784cb9e553cee9186166a6890" 51 | ], 52 | "index": "pypi", 53 | "version": "==1.2.1" 54 | }, 55 | "dataclasses": { 56 | "hashes": [ 57 | "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f", 58 | "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84" 59 | ], 60 | "index": "pypi", 61 | "version": "==0.6.0" 62 | }, 63 | "fire": { 64 | "hashes": [ 65 | "sha256:c299d16064ff81cbb649b65988300d4a28b71ecfb789d1fb74d99ea98ae4d2eb" 66 | ], 67 | "index": "pypi", 68 | "version": "==0.1.3" 69 | }, 70 | "ftfy": { 71 | "hashes": [ 72 | "sha256:619e68f9844cadd03e0d835e9b6790b2399357100c57fddae14d93a8de81e114" 73 | ], 74 | "index": "pypi", 75 | "version": "==5.4.1" 76 | }, 77 | "google-api-core": { 78 | "hashes": [ 79 | "sha256:2c23fbc81c76b941ffb71301bb975ed66a610e9b03f918feacd1ed59cf43a6ec", 80 | "sha256:b2b91107bcc3b981633c89602b46451f6474973089febab3ee51c49cb7ae6a1f" 81 | ], 82 | "version": "==1.14.2" 83 | }, 84 | "google-api-python-client": { 85 | "hashes": [ 86 | 
"sha256:06907006ed5ce831018f03af3852d739c0b2489cdacfda6971bcc2075c762858", 87 | "sha256:937eabdc3940977f712fa648a096a5142766b6d0a0f58bc603e2ac0687397ef0" 88 | ], 89 | "index": "pypi", 90 | "version": "==1.7.8" 91 | }, 92 | "google-auth": { 93 | "hashes": [ 94 | "sha256:0f7c6a64927d34c1a474da92cfc59e552a5d3b940d3266606c6a28b72888b9e4", 95 | "sha256:20705f6803fd2c4d1cc2dcb0df09d4dfcb9a7d51fd59e94a3a28231fd93119ed" 96 | ], 97 | "version": "==1.6.3" 98 | }, 99 | "google-auth-httplib2": { 100 | "hashes": [ 101 | "sha256:098fade613c25b4527b2c08fa42d11f3c2037dda8995d86de0745228e965d445", 102 | "sha256:f1c437842155680cf9918df9bc51c1182fda41feef88c34004bd1978c8157e08" 103 | ], 104 | "version": "==0.0.3" 105 | }, 106 | "google-cloud-core": { 107 | "hashes": [ 108 | "sha256:0090df83dbc5cb2405fa90844366d13176d1c0b48181c1807ab15f53be403f73", 109 | "sha256:89e8140a288acec20c5e56159461d3afa4073570c9758c05d4e6cb7f2f8cc440" 110 | ], 111 | "version": "==0.28.1" 112 | }, 113 | "google-cloud-storage": { 114 | "hashes": [ 115 | "sha256:936c859c47f8e94fd0005e98235a10d5e75828d2c6c3a8caacae18344a572a0a", 116 | "sha256:fc32b9be41a45016ba2387e3ad23e70ccba399d626ef596409316f7cee477956" 117 | ], 118 | "index": "pypi", 119 | "version": "==1.13.0" 120 | }, 121 | "google-resumable-media": { 122 | "hashes": [ 123 | "sha256:5fd2e641f477e50be925a55bcfdf0b0cb97c2b92aacd7b15c1d339f70d55c1c7", 124 | "sha256:cdeb8fbb3551a665db921023603af2f0d6ac59ad8b48259cb510b8799505775f" 125 | ], 126 | "version": "==0.4.1" 127 | }, 128 | "googleapis-common-protos": { 129 | "hashes": [ 130 | "sha256:e61b8ed5e36b976b487c6e7b15f31bb10c7a0ca7bd5c0e837f4afab64b53a0c6" 131 | ], 132 | "version": "==1.6.0" 133 | }, 134 | "httplib2": { 135 | "hashes": [ 136 | "sha256:6901c8c0ffcf721f9ce270ad86da37bc2b4d32b8802d4a9cec38274898a64044", 137 | "sha256:cf6f9d5876d796539ec922a2c9b9a7cad9bfd90f04badcdc3bcfa537168052c3" 138 | ], 139 | "version": "==0.13.1" 140 | }, 141 | "idna": { 142 | "hashes": [ 143 | "sha256:3cb5ce08046c4e3a560fc02f138d0ac63e00f8ce5901a56b32ec8b7994082aab", 144 | "sha256:cc19709fd6d0cbfed39ea875d29ba6d4e22c0cebc510a76d6302a28385e8bb70" 145 | ], 146 | "version": "==2.5" 147 | }, 148 | "lm-human-preferences": { 149 | "editable": true, 150 | "path": "." 
151 | }, 152 | "more-itertools": { 153 | "hashes": [ 154 | "sha256:409cd48d4db7052af495b09dec721011634af3753ae1ef92d2b32f73a745f832", 155 | "sha256:92b8c4b06dac4f0611c0729b2f2ede52b2e1bac1ab48f089c7ddc12e26bb60c4" 156 | ], 157 | "version": "==7.2.0" 158 | }, 159 | "mpi4py": { 160 | "hashes": [ 161 | "sha256:014076ffa558bc8d1d82c820c94848ae5f9fe1aab3c9e0a18d80e0c339a4bbe4", 162 | "sha256:020dbf8c8d2b95b6098c6a66352907afed1c449d811fd085247d5ee244890bb1", 163 | "sha256:06514c4205e1de84d04c780ab6aa8751121203dd246a45b120817c4444bed341", 164 | "sha256:0bcd7acb12c7e830267f9d3df13da0576ccf1603fb1c9f940e600ceefbe69200", 165 | "sha256:1c83daae9a99908109200b29c9cfd93e7c0dc9cad50bef15f0ea85642c288746", 166 | "sha256:39807cca8195b0c1e43dc9a3e1d80ef4b7cdc66a9f19a184ce7c28d8b42b7f4a", 167 | "sha256:45b5674d0d630c31bbb94abd9563202ecd83e72a2c54ee719b9813d3a5938767", 168 | "sha256:4f2f6f5cdece7a95b53bfc884ff9201e270ca386f8c53b54ff2bec799e5b8e0c", 169 | "sha256:5c1b377022a43e515812f6064d7b1ec01fd61027592aa16e5ad5e14f27f8db3a", 170 | "sha256:baa8a41f5bddbf581f521fc68db1a297fe24a0256c36bf7dd22fcb3e2cc93ea1", 171 | "sha256:c105ac976e1605a6883db06a37b0dfac497b210de6d8569dc6d23af33597f145", 172 | "sha256:e452b96ff879700dcbcef19d145190d56621419e4fbc73e43998b2e692dc6eeb", 173 | "sha256:f8d629d1e3e3b7b89cb99d0e3bc5505e76cc42089829807950d5c56606ed48e0" 174 | ], 175 | "index": "pypi", 176 | "version": "==3.0.2" 177 | }, 178 | "mypy": { 179 | "hashes": [ 180 | "sha256:3bd95a1369810f7693366911d85be9f0a0bd994f6cb7162b7a994e5ded90e3d9", 181 | "sha256:7247f9948d7cdaae9408a4ee1662a01853c24e668117b4419acf025b05fbe3ce" 182 | ], 183 | "index": "pypi", 184 | "version": "==0.580" 185 | }, 186 | "numpy": { 187 | "hashes": [ 188 | "sha256:1980f8d84548d74921685f68096911585fee393975f53797614b34d4f409b6da", 189 | "sha256:22752cd809272671b273bb86df0f505f505a12368a3a5fc0aa811c7ece4dfd5c", 190 | "sha256:23cc40313036cffd5d1873ef3ce2e949bdee0646c5d6f375bf7ee4f368db2511", 191 | "sha256:2b0b118ff547fecabc247a2668f48f48b3b1f7d63676ebc5be7352a5fd9e85a5", 192 | "sha256:3a0bd1edf64f6a911427b608a894111f9fcdb25284f724016f34a84c9a3a6ea9", 193 | "sha256:3f25f6c7b0d000017e5ac55977a3999b0b1a74491eacb3c1aa716f0e01f6dcd1", 194 | "sha256:4061c79ac2230594a7419151028e808239450e676c39e58302ad296232e3c2e8", 195 | "sha256:560ceaa24f971ab37dede7ba030fc5d8fa173305d94365f814d9523ffd5d5916", 196 | "sha256:62be044cd58da2a947b7e7b2252a10b42920df9520fc3d39f5c4c70d5460b8ba", 197 | "sha256:6c692e3879dde0b67a9dc78f9bfb6f61c666b4562fd8619632d7043fb5b691b0", 198 | "sha256:6f65e37b5a331df950ef6ff03bd4136b3c0bbcf44d4b8e99135d68a537711b5a", 199 | "sha256:7a78cc4ddb253a55971115f8320a7ce28fd23a065fc33166d601f51760eecfa9", 200 | "sha256:80a41edf64a3626e729a62df7dd278474fc1726836552b67a8c6396fd7e86760", 201 | "sha256:893f4d75255f25a7b8516feb5766c6b63c54780323b9bd4bc51cdd7efc943c73", 202 | "sha256:972ea92f9c1b54cc1c1a3d8508e326c0114aaf0f34996772a30f3f52b73b942f", 203 | "sha256:9f1d4865436f794accdabadc57a8395bd3faa755449b4f65b88b7df65ae05f89", 204 | "sha256:9f4cd7832b35e736b739be03b55875706c8c3e5fe334a06210f1a61e5c2c8ca5", 205 | "sha256:adab43bf657488300d3aeeb8030d7f024fcc86e3a9b8848741ea2ea903e56610", 206 | "sha256:bd2834d496ba9b1bdda3a6cf3de4dc0d4a0e7be306335940402ec95132ad063d", 207 | "sha256:d20c0360940f30003a23c0adae2fe50a0a04f3e48dc05c298493b51fd6280197", 208 | "sha256:d3b3ed87061d2314ff3659bb73896e622252da52558f2380f12c421fbdee3d89", 209 | "sha256:dc235bf29a406dfda5790d01b998a1c01d7d37f449128c0b1b7d1c89a84fae8b", 210 | 
"sha256:fb3c83554f39f48f3fa3123b9c24aecf681b1c289f9334f8215c1d3c8e2f6e5b" 211 | ], 212 | "index": "pypi", 213 | "version": "==1.16.2" 214 | }, 215 | "pluggy": { 216 | "hashes": [ 217 | "sha256:7f8ae7f5bdf75671a718d2daf0a64b7885f74510bcd98b1a0bb420eb9a9d0cff", 218 | "sha256:d345c8fe681115900d6da8d048ba67c25df42973bda370783cd58826442dcd7c", 219 | "sha256:e160a7fcf25762bb60efc7e171d4497ff1d8d2d75a3d0df7a21b76821ecbf5c5" 220 | ], 221 | "version": "==0.6.0" 222 | }, 223 | "protobuf": { 224 | "hashes": [ 225 | "sha256:00a1b0b352dc7c809749526d1688a64b62ea400c5b05416f93cfb1b11a036295", 226 | "sha256:01acbca2d2c8c3f7f235f1842440adbe01bbc379fa1cbdd80753801432b3fae9", 227 | "sha256:0a795bca65987b62d6b8a2d934aa317fd1a4d06a6dd4df36312f5b0ade44a8d9", 228 | "sha256:0ec035114213b6d6e7713987a759d762dd94e9f82284515b3b7331f34bfaec7f", 229 | "sha256:31b18e1434b4907cb0113e7a372cd4d92c047ce7ba0fa7ea66a404d6388ed2c1", 230 | "sha256:32a3abf79b0bef073c70656e86d5bd68a28a1fbb138429912c4fc07b9d426b07", 231 | "sha256:55f85b7808766e5e3f526818f5e2aeb5ba2edcc45bcccede46a3ccc19b569cb0", 232 | "sha256:64ab9bc971989cbdd648c102a96253fdf0202b0c38f15bd34759a8707bdd5f64", 233 | "sha256:64cf847e843a465b6c1ba90fb6c7f7844d54dbe9eb731e86a60981d03f5b2e6e", 234 | "sha256:917c8662b585470e8fd42f052661fc66d59fccaae450a60044307dcbf82a3335", 235 | "sha256:afed9003d7f2be2c3df20f64220c30faec441073731511728a2cb4cab4cd46a6", 236 | "sha256:bf8e05d638b585d1752c5a84247134a0350d3a8b73d3632489a014a9f6f1e758", 237 | "sha256:d831b047bd69becaf64019a47179eb22118a50dd008340655266a906c69c6417", 238 | "sha256:de2760583ed28749ff885789c1cbc6c9c06d6de92fc825740ab99deb2f25ea4d", 239 | "sha256:eabc4cf1bc19689af8022ba52fd668564a8d96e0d08f3b4732d26a64255216a4", 240 | "sha256:fcff6086c86fb1628d94ea455c7b9de898afc50378042927a59df8065a79a549" 241 | ], 242 | "version": "==3.9.1" 243 | }, 244 | "py": { 245 | "hashes": [ 246 | "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa", 247 | "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53" 248 | ], 249 | "version": "==1.8.0" 250 | }, 251 | "pyasn1": { 252 | "hashes": [ 253 | "sha256:62cdade8b5530f0b185e09855dd422bc05c0bbff6b72ff61381c09dac7befd8c", 254 | "sha256:a9495356ca1d66ed197a0f72b41eb1823cf7ea8b5bd07191673e8147aecf8604" 255 | ], 256 | "version": "==0.4.7" 257 | }, 258 | "pyasn1-modules": { 259 | "hashes": [ 260 | "sha256:43c17a83c155229839cc5c6b868e8d0c6041dba149789b6d6e28801c64821722", 261 | "sha256:e30199a9d221f1b26c885ff3d87fd08694dbbe18ed0e8e405a2a7126d30ce4c0" 262 | ], 263 | "version": "==0.2.6" 264 | }, 265 | "pytest": { 266 | "hashes": [ 267 | "sha256:6266f87ab64692112e5477eba395cfedda53b1933ccd29478e671e73b420c19c", 268 | "sha256:fae491d1874f199537fd5872b5e1f0e74a009b979df9d53d1553fd03da1703e1" 269 | ], 270 | "index": "pypi", 271 | "version": "==3.5.0" 272 | }, 273 | "pytest-instafail": { 274 | "hashes": [ 275 | "sha256:b4d5fc3ca81e530a8d0e15a7771dc14b06fc9a0930c4b3909a7f4527040572c3" 276 | ], 277 | "index": "pypi", 278 | "version": "==0.3.0" 279 | }, 280 | "pytest-timeout": { 281 | "hashes": [ 282 | "sha256:c29e3168f10897728059bd6b8ca20b28733d7fe6b8f6c09bb9d89f6146f27cb8", 283 | "sha256:c65a80c87074c17b6dfbe91cd856f260f84fbdad5df9bd79b1cfc26fe5c163f1" 284 | ], 285 | "index": "pypi", 286 | "version": "==1.2.0" 287 | }, 288 | "pytz": { 289 | "hashes": [ 290 | "sha256:303879e36b721603cc54604edcac9d20401bdbe31e1e4fdee5b9f98d5d31dfda", 291 | "sha256:d747dd3d23d77ef44c6a3526e274af6efeb0a6f1afd5a69ba4d5be4098c8e141" 292 | ], 293 | "index": "pypi", 294 | "version": 
"==2019.1" 295 | }, 296 | "regex": { 297 | "hashes": [ 298 | "sha256:19c4b0f68dd97b7116e590f47d60d97ab9e76966acc321b1d20dd87c2b64dff2", 299 | "sha256:1af6b820bec5ca82af87447af5a6dcc23b3ddc96b0184fd71666be0c24fb2a4f", 300 | "sha256:232dbc28a2562d92d713c3c1eb2b9276f3ebcbdb6d3e96ff68d0417a71926784", 301 | "sha256:3d26ce7e605a501509b68c343fc9d9e09f76c2e9e261df8183027bdc750c97ce", 302 | "sha256:52b590a41b9677314d02d9055edc33992db758b3d5167aa1365229a6a0c26a6d", 303 | "sha256:565f9aac9cd43b2351f7fcbc0d6056f8aebf4f6d049a17982085019ab9acdf28", 304 | "sha256:656984899644d3fe2e40533724f513a21127f77162a15dd5244af3c965152c63", 305 | "sha256:689c9d17c3ba02f52e8481a5c584c8c11ba27d6cc5f939efdd838ae0d0d1af41", 306 | "sha256:8a9d9db8ef1621ae51ea12acb5e503204b4586e05c6cfd418aecb9466a71bd87", 307 | "sha256:ad2beea450d551b11b47512ce920127d7c8645e528cc56dc9502c5973e8732f3", 308 | "sha256:b39867f577bc59b2fec9209facc513c761978e4ac63f4b73b9750a2c1501729e", 309 | "sha256:b6a7725a069be8f9dd09e1e500e5b57556b301942e21c8c712627f73ec048286", 310 | "sha256:b9e9b97696e75e826adac1920b13e7bac3a6a2128c085783abd208d73a278d70", 311 | "sha256:bf4896ed1ca2017153fc6b341bc8a0da8ca5480f85eebd7bfe58bbafceb4e728", 312 | "sha256:c3c2fe1e0d90f4c93be5b588480f05defd44f64c65767a657de69c4db4429a39", 313 | "sha256:d811874ed669165fe1059a54f860db5c6ab5f48100bf4945d915fd2f877b2531", 314 | "sha256:db616380b04e29e5709bc3ec0674e827dfed3d18e7d686c09537ab01506127c9", 315 | "sha256:efa66273b49dbd7a9f6a4d02d1a7d5bf353d568a89f7cd8927812daa9f83bb84", 316 | "sha256:f8feab5b517cdc65a61a50549e7dcfa0f61ab872a0034da1f6b8d61775178b6a" 317 | ], 318 | "index": "pypi", 319 | "version": "==2017.4.5" 320 | }, 321 | "requests": { 322 | "hashes": [ 323 | "sha256:5e88d64aa56ac0fda54e77fb9762ebc65879e171b746d5479a33c4082519d6c6", 324 | "sha256:cd0189f962787284bff715fddaad478eb4d9c15aa167bd64e52ea0f661e7ea5c" 325 | ], 326 | "index": "pypi", 327 | "version": "==2.18.0" 328 | }, 329 | "rsa": { 330 | "hashes": [ 331 | "sha256:14ba45700ff1ec9eeb206a2ce76b32814958a98e372006c8fb76ba820211be66", 332 | "sha256:1a836406405730121ae9823e19c6e806c62bbad73f890574fff50efa4122c487" 333 | ], 334 | "version": "==4.0" 335 | }, 336 | "six": { 337 | "hashes": [ 338 | "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c", 339 | "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73" 340 | ], 341 | "version": "==1.12.0" 342 | }, 343 | "tqdm": { 344 | "hashes": [ 345 | "sha256:d385c95361699e5cf7622485d9b9eae2d4864b21cd5a2374a9c381ffed701021", 346 | "sha256:e22977e3ebe961f72362f6ddfb9197cc531c9737aaf5f607ef09740c849ecd05" 347 | ], 348 | "index": "pypi", 349 | "version": "==4.31.1" 350 | }, 351 | "typed-ast": { 352 | "hashes": [ 353 | "sha256:0cf0c406af2a6472a02254fe1ced40cb81a7c1215b7ceba88a3bb9c3a864f851", 354 | "sha256:1b784cd3c6778cd7b99afb41ddcaa1eb5b35a399210db7fcf24ed082670e0070", 355 | "sha256:2d7a322c1df6cccff2381c0475c1ebf82d3e9a331e48ed4ea89bbc72a8dedca6", 356 | "sha256:4304399ff89452871348f6fb7a7112454cd508fbe3eb49b5ed711cce9b99fe9e", 357 | "sha256:4658aebc30c0af80e63b579e917c04b592bdf10ef40da381b2fd179075b5d1b6", 358 | "sha256:471a7f12e55ad22f7a4bb2c3e62e39e3ab78008b24c61c48c9042e63b7359bb9", 359 | "sha256:57cb23412dac214383c6b6f0f7b0aec2d0c001a936af20f0b53542bbe4ba08a7", 360 | "sha256:5eb14e6b3aa5ff5d7e964b978a718227b5576b3965f1dd71dd055f71054233a5", 361 | "sha256:8219b6147af4d609096b6db2c797281e19fd3f7232ef35932bc74a812ff417a0", 362 | "sha256:8a7e9635cf0aaca04b2a4d4b3501c0dbc5c49a140b2e55b00e218d41ed2a69c8", 363 | 
"sha256:935157ada4aa115d61c59e759e43c5862b04d19ffe6fe5c9d735716587535cb7", 364 | "sha256:9525f4cbe3eb7b9e19a87c765ca9bbc1147ce18f75059e15138eb7fc59ce02e3", 365 | "sha256:99c140583eef6b50f3de4af44718a4fc63108671b29c468b5ff83ed383facf6d", 366 | "sha256:9e358ce6d4c43a90c15b99b76261adc852998680628c780f26fd64bc21adb9fa", 367 | "sha256:aaf63a024b54d2788cff3400de79009ee8a23594b581d4f33d90b7c67f8c05bd", 368 | "sha256:c3313b3fa1b6b722866eda370c14fd8f4962b6bcd1f6d43f42d6818a8b29d998", 369 | "sha256:c9342947e5f3480473d836754d69965a12ac2237d99ae85d1e3fdd1c1722669f", 370 | "sha256:cb1c7e5b3195103f5a784db7969fc55463cfae9b354e3b97cc219d32293d5e65", 371 | "sha256:d2d2cce74165cae2663167c921e331fb0eecfff2e93254dfdb16beb99716e519", 372 | "sha256:d6fc3b9fbf67d556223aa5493501022e1d585b9a1892fa87ba1257627763c461", 373 | "sha256:fa4eafaa57074958f065c2a6222d8f11162739f8c9db125472a1f04794a0b91d" 374 | ], 375 | "version": "==1.1.2" 376 | }, 377 | "typeguard": { 378 | "hashes": [ 379 | "sha256:5b90905662970cb47029cd5800b17b81608162ea2fcab7e5fd19bcc04a7d0b42", 380 | "sha256:5ecab47551c42a8090dcb914c550287a09caf599b4d47958445494f2822165aa" 381 | ], 382 | "index": "pypi", 383 | "version": "==2.5.0" 384 | }, 385 | "uritemplate": { 386 | "hashes": [ 387 | "sha256:01c69f4fe8ed503b2951bef85d996a9d22434d2431584b5b107b2981ff416fbd", 388 | "sha256:1b9c467a940ce9fb9f50df819e8ddd14696f89b9a8cc87ac77952ba416e0a8fd", 389 | "sha256:c02643cebe23fc8adb5e6becffe201185bf06c40bda5c0b4028a93f1527d011d" 390 | ], 391 | "version": "==3.0.0" 392 | }, 393 | "urllib3": { 394 | "hashes": [ 395 | "sha256:8ed6d5c1ff9d6ba84677310060d6a3a78ca3072ce0684cb3c645023009c114b1", 396 | "sha256:b14486978518ca0901a76ba973d7821047409d7f726f22156b24e83fd71382a5" 397 | ], 398 | "version": "==1.21.1" 399 | }, 400 | "wcwidth": { 401 | "hashes": [ 402 | "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e", 403 | "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c" 404 | ], 405 | "version": "==0.1.7" 406 | } 407 | }, 408 | "develop": {} 409 | } 410 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | **Status:** Archive (code is provided as-is, no updates expected) 2 | 3 | **Status:** All references to `gs://lm-human-preferences/` were updated to `https://openaipublic.blob.core.windows.net/lm-human-preferences`, as we migrated from GCP to Azure. The code provided as is may no longer work. Pull requests welcome 4 | 5 | # lm-human-preferences 6 | 7 | This repository contains code for the paper [Fine-Tuning Language Models from Human Preferences](https://arxiv.org/abs/1909.08593). See also our [blog post](https://openai.com/blog/fine-tuning-gpt-2/). 8 | 9 | We provide code for: 10 | - Training reward models from human labels 11 | - Fine-tuning language models using those reward models 12 | 13 | It does not contain code for generating labels. However, we have released human labels collected for our experiments, at `gs://lm-human-preferences/labels`. 14 | For those interested, the question and label schemas are simple and documented in [`label_types.py`](./lm_human_preferences/label_types.py). 15 | 16 | The code has only been tested using the smallest GPT-2 model (124M parameters). 17 | 18 | ## Instructions 19 | 20 | This code has only been tested using Python 3.7.3. Training has been tested on GCE machines with 8 V100s, running Ubuntu 16.04, but development also works on Mac OS X. 
21 | 22 | ### Installation 23 | 24 | - Install [pipenv](https://github.com/pypa/pipenv#installation). 25 | 26 | - Install [tensorflow](https://www.tensorflow.org/install/gpu): Install CUDA 10.0 and cuDNN 7.6.2, then `pipenv install tensorflow-gpu==1.13.1`. The code may technically run with tensorflow on CPU but will be very slow. 27 | 28 | - Install [`gsutil`](https://cloud.google.com/storage/docs/gsutil_install) 29 | 30 | - Clone this repo. Then: 31 | ``` 32 | pipenv install 33 | ``` 34 | 35 | - (Recommended) Install [`horovod`](https://github.com/horovod/horovod#install) to speed up the code, or otherwise substitute some fast implementation in the `mpi_allreduce_sum` function of [`core.py`](./lm_human_preferences/utils/core.py). Make sure to use pipenv for the install, e.g. `pipenv install horovod==0.18.1`. 36 | 37 | ### Running 38 | 39 | The following examples assume we are aiming to train a model to continue text in a physically descriptive way. 40 | You can read [`launch.py`](./launch.py) to see how the `descriptiveness` experiments and others are defined. 41 | 42 | Note that we provide pre-trained models, so you can skip directly to RL fine-tuning or even to sampling from a trained policy, if desired. 43 | 44 | #### Training a reward model 45 | 46 | To train a reward model, use a command such as 47 | ``` 48 | experiment=descriptiveness 49 | reward_experiment_name=testdesc-$(date +%y%m%d%H%M) 50 | pipenv run ./launch.py train_reward $experiment $reward_experiment_name 51 | ``` 52 | 53 | This will save outputs (and tensorboard event files) to the directory `/tmp/save/train_reward/$reward_experiment_name`. The directory can be changed via the `--save_dir` flag. 54 | 55 | #### Finetuning a language model 56 | 57 | Once you have trained a reward model, you can finetune against it. 58 | 59 | First, set 60 | ``` 61 | trained_reward_model=/tmp/save/train_reward/$reward_experiment_name 62 | ``` 63 | or if using our pretrained model, 64 | ``` 65 | trained_reward_model=gs://lm-human-preferences/runs/descriptiveness/reward_model 66 | ``` 67 | 68 | Then, 69 | ``` 70 | experiment=descriptiveness 71 | policy_experiment_name=testdesc-$(date +%y%m%d%H%M) 72 | pipenv run ./launch.py train_policy $experiment $policy_experiment_name --rewards.trained_model $trained_reward_model --rewards.train_new_model 'off' 73 | ``` 74 | 75 | This will save outputs (and tensorboard event files) to the directory `/tmp/save/train_policy/$policy_experiment_name`. The directory can be changed via the `--save_dir` flag. 76 | 77 | #### Both steps at once 78 | 79 | You can run a single command to train a reward model and then finetune against it 80 | ``` 81 | experiment=descriptiveness 82 | experiment_name=testdesc-$(date +%y%m%d%H%M) 83 | pipenv run ./launch.py train_policy $experiment $experiment_name 84 | ``` 85 | 86 | In this case, outputs are in the directory `/tmp/save/train_policy/$policy_experiment_name`, and the reward model is saved to a subdirectory `reward_model`. The directory can be changed via the `--save_dir` flag. 87 | 88 | #### Sampling from a trained policy 89 | 90 | Specify the policy to load: 91 | ``` 92 | save_dir=/tmp/save/train_policy/$policy_experiment_name 93 | ``` 94 | or if using our pretrained model, 95 | ``` 96 | save_dir=gs://lm-human-preferences/runs/descriptiveness 97 | ``` 98 | 99 | Then run: 100 | ``` 101 | pipenv run ./sample.py sample --save_dir $save_dir --savescope policy 102 | ``` 103 | 104 | Note that this script can run on less than 8 GPUs. 
You can pass the flag `--mpi 1`, for example, if you only have one GPU. 105 | 106 | ## LICENSE 107 | 108 | [MIT](./LICENSE) 109 | 110 | ## Citation 111 | 112 | Please cite the paper with the following bibtex entry: 113 | ``` 114 | @article{ziegler2019finetuning, 115 | title={Fine-Tuning Language Models from Human Preferences}, 116 | author={Ziegler, Daniel M. and Stiennon, Nisan and Wu, Jeffrey and Brown, Tom B. and Radford, Alec and Amodei, Dario and Christiano, Paul and Irving, Geoffrey}, 117 | journal={arXiv preprint arXiv:1909.08593}, 118 | url={https://arxiv.org/abs/1909.08593}, 119 | year={2019} 120 | } 121 | ``` 122 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from lm_human_preferences.utils import launch 4 | from lm_human_preferences.utils.combos import bind, combos, each, label, options_shortdesc, bind_nested 5 | from lm_human_preferences import train_policy, train_reward 6 | 7 | 8 | books_task = combos( 9 | bind('query_length', 64), 10 | bind('query_dataset', 'books'), 11 | bind('response_length', 24), 12 | bind('start_text', '.'), # Start the context at the beginning of a sentence 13 | bind('end_text', '.'), # End the context at the end of a sentence. 14 | bind('truncate_token', 13), # Encoding of '.' -- end completions at the end of a sentence. 15 | bind('truncate_after', 16), # Make sure completions are at least 16 tokens long. 16 | 17 | bind('policy.temperature', 0.7), 18 | bind('policy.initial_model', '124M'), 19 | ) 20 | 21 | summarize_cnndm_task = combos( 22 | bind('query_prefix', 'Article:\n\n'), 23 | bind('query_suffix', '\n\nTL;DR:'), 24 | bind('end_text', '\n'), 25 | bind('query_dataset', 'cnndm'), 26 | bind('query_length', 500), 27 | bind('response_length', 75), 28 | bind('start_text', None), 29 | bind('truncate_after', 55), 30 | bind('truncate_token', 198), # '\n' 31 | 32 | bind('policy.temperature', 0.5), 33 | bind('policy.initial_model', '124M'), 34 | ) 35 | 36 | summarize_tldr_task = combos( 37 | bind('query_suffix', '\n\nTL;DR:'), 38 | bind('query_dataset', 'tldr'), 39 | bind('query_length', 500), 40 | bind('response_length', 75), 41 | bind('start_text', None), 42 | bind('truncate_after', 55), 43 | bind('truncate_token', 198), # '\n' 44 | 45 | bind('policy.temperature', 0.7), 46 | bind('policy.initial_model', '124M'), 47 | ) 48 | 49 | def get_train_reward_experiments(): 50 | _shared = combos( 51 | bind('labels.type', 'best_of_4'), 52 | bind('normalize_after', True), 53 | bind('normalize_before', True), 54 | bind('normalize_samples', 256), 55 | ) 56 | 57 | 58 | _books_task = combos( 59 | bind_nested('task', books_task), 60 | _shared, 61 | bind('batch_size', 32), 62 | bind('lr', 5e-5), 63 | bind('rollout_batch_size', 512), 64 | ) 65 | 66 | sentiment = combos( 67 | _books_task, 68 | 69 | bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/sentiment/offline_5k.json'), 70 | bind('labels.num_train', 4_992), 71 | bind('run.seed', 1) 72 | ) 73 | 74 | 75 | descriptiveness = combos( 76 | _books_task, 77 | 78 | bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/descriptiveness/offline_5k.json'), 79 | bind('labels.num_train', 4_992), 80 | bind('run.seed', 1) 81 | ) 82 | 83 | cnndm = combos( 84 | bind_nested('task', summarize_cnndm_task), 85 | _shared, 86 | 87 | # bind('labels.source', 
'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/cnndm/offline_60k.json'), 88 | # bind('labels.num_train', 60_000), 89 | bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/cnndm/online_45k.json'), 90 | bind('labels.num_train', 46_000), 91 | 92 | bind('batch_size', 2 * 8), 93 | bind('lr', 2.5e-5), 94 | bind('rollout_batch_size', 128), 95 | bind('run.seed', 1) 96 | ) 97 | 98 | tldr = combos( 99 | bind_nested('task', summarize_tldr_task), 100 | _shared, 101 | 102 | # bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/tldr/offline_60k.json'), 103 | # bind('labels.num_train', 60_000), 104 | bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/tldr/online_45k.json'), 105 | bind('labels.num_train', 46_000), 106 | 107 | bind('batch_size', 2 * 8), 108 | bind('lr', 2.5e-5), 109 | bind('rollout_batch_size', 128), 110 | bind('run.seed', 1) 111 | ) 112 | 113 | return locals() 114 | 115 | 116 | def get_experiments(): 117 | train_reward_experiments = get_train_reward_experiments() 118 | 119 | _books_task = combos( 120 | bind_nested('task', books_task), 121 | 122 | bind('ppo.lr', 1e-5), 123 | bind('ppo.total_episodes', 1_000_000), 124 | bind('ppo.batch_size', 512), 125 | ) 126 | 127 | sentiment = combos( 128 | _books_task, 129 | bind('rewards.kl_coef', 0.15), 130 | bind('rewards.adaptive_kl', 'on'), 131 | bind('rewards.adaptive_kl.target', 6.0), 132 | 133 | bind('rewards.train_new_model', 'on'), 134 | bind_nested('rewards.train_new_model', train_reward_experiments['sentiment']), 135 | # bind('rewards.trained_model', '/your/directory/here/reward_model/'), 136 | 137 | bind('run.seed', 1) 138 | ) 139 | 140 | descriptiveness = combos( 141 | _books_task, 142 | bind('rewards.kl_coef', 0.15), 143 | bind('rewards.adaptive_kl', 'on'), 144 | bind('rewards.adaptive_kl.target', 6.0), 145 | 146 | bind('rewards.train_new_model', 'on'), 147 | bind_nested('rewards.train_new_model', train_reward_experiments['descriptiveness']), 148 | # bind('rewards.trained_model', '/your/directory/here/reward_model/'), 149 | 150 | bind('run.seed', 1) 151 | ) 152 | 153 | cnndm = combos( 154 | bind_nested('task', summarize_cnndm_task), 155 | 156 | bind('rewards.train_new_model', 'on'), 157 | bind_nested('rewards.train_new_model', train_reward_experiments['cnndm']), 158 | # bind('rewards.trained_model', '/your/directory/here/reward_model/'), 159 | 160 | bind('ppo.total_episodes', 1_000_000), 161 | bind('ppo.lr', 2e-6), 162 | bind('rewards.kl_coef', 0.01), 163 | # bind('rewards.adaptive_kl', 'on'), 164 | # bind('rewards.adaptive_kl.target', 18.0), 165 | bind('ppo.batch_size', 32), 166 | bind('rewards.whiten', False), 167 | 168 | bind('run.seed', 1) 169 | ) 170 | 171 | tldr = combos( 172 | bind_nested('task', summarize_tldr_task), 173 | 174 | bind('rewards.train_new_model', 'on'), 175 | bind_nested('rewards.train_new_model', train_reward_experiments['tldr']), 176 | # bind('rewards.trained_model', '/your/directory/here/reward_model/'), 177 | 178 | bind('ppo.total_episodes', 1_000_000), 179 | bind('ppo.lr', 2e-6), 180 | bind('rewards.kl_coef', 0.03), # 0.01 too low 181 | # bind('rewards.adaptive_kl', 'on'), 182 | # bind('rewards.adaptive_kl.target', 18.0), 183 | bind('ppo.batch_size', 32), 184 | bind('rewards.whiten', False), 185 | 186 | bind('run.seed', 1) 187 | ) 188 | 189 | return locals() 190 | 191 | 192 | def launch_train_policy(exp, name, dry_run=False, mpi=8, mode='local', 
save_dir='/tmp/save/train_policy', **extra_hparams): 193 | experiment_dict = get_experiments() 194 | try: 195 | trials = experiment_dict[exp] 196 | except KeyError: 197 | raise ValueError(f"Couldn't find experiment '{exp}'") 198 | 199 | launch.launch_trials( 200 | name, fn=train_policy.train, trials=trials, mpi=mpi, mode=mode, save_dir=save_dir, 201 | hparam_class=train_policy.HParams, extra_hparams=extra_hparams, dry_run=dry_run) 202 | 203 | 204 | def launch_train_reward(exp, name, dry_run=False, mpi=8, mode='local', save_dir='/tmp/save/train_reward', **extra_hparams): 205 | experiment_dict = get_train_reward_experiments() 206 | try: 207 | trials = experiment_dict[exp] 208 | except KeyError: 209 | raise ValueError(f"Couldn't find experiment '{exp}'") 210 | 211 | launch.launch_trials( 212 | name, fn=train_reward.train, trials=trials, mpi=mpi, mode=mode, save_dir=save_dir, 213 | hparam_class=train_reward.HParams, extra_hparams=extra_hparams, dry_run=dry_run) 214 | 215 | 216 | if __name__ == '__main__': 217 | launch.main(dict( 218 | train_policy=launch_train_policy, 219 | train_reward=launch_train_reward 220 | )) 221 | -------------------------------------------------------------------------------- /lm_human_preferences/datasets/books.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | from lm_human_preferences.utils import gcs 5 | 6 | 7 | def books_generator(mode, seed=0, shuffle=False, comm=None): 8 | datas = [ 9 | json.loads(line) for line in 10 | open(gcs.download_file_cached(f'https://openaipublic.blob.core.windows.net/lm-human-preferences/datasets/book_passages/{mode}.jsonl', comm=comm)) 11 | ] 12 | if shuffle: 13 | random.seed(seed) 14 | random.shuffle(datas) 15 | 16 | for x in datas: 17 | yield x 18 | -------------------------------------------------------------------------------- /lm_human_preferences/datasets/cnndm.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import random 4 | import re 5 | 6 | import ftfy 7 | 8 | from lm_human_preferences.utils import gcs 9 | 10 | dm_single_close_quote = u'\u2019' # unicode 11 | dm_double_close_quote = u'\u201d' 12 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence 13 | 14 | def read_text_file(text_file): 15 | lines = [] 16 | with open(text_file, "r") as f: 17 | for line in f: 18 | lines.append(line.strip()) 19 | return lines 20 | 21 | def fix_missing_period(line): 22 | """Adds a period to a line that is missing a period""" 23 | if "@highlight" in line: 24 | return line 25 | if line=="": 26 | return line 27 | if line[-1] in END_TOKENS: 28 | return line 29 | # print line[-1] 30 | return line + "." 
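# Illustrative behavior of fix_missing_period (comments added for clarity; they are not
# part of the original dataset code and simply restate the branches above):
#   fix_missing_period("He went home")   -> "He went home."   (period appended)
#   fix_missing_period("He went home.")  -> "He went home."   (already ends in END_TOKENS)
#   fix_missing_period("@highlight")     -> "@highlight"      (highlight markers left alone)
#   fix_missing_period("")               -> ""                (empty lines left alone)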
31 | 32 | def get_art_abs(story_file): 33 | lines = read_text_file(story_file) 34 | # lines = [fix_missing_period(line) for line in lines] 35 | article_lines = [] 36 | highlights = [] 37 | next_is_highlight = False 38 | for line in lines: 39 | if line == "": 40 | continue # empty line 41 | elif line.startswith("@highlight"): 42 | next_is_highlight = True 43 | elif next_is_highlight: 44 | highlights.append(line) 45 | else: 46 | article_lines.append(line) 47 | article = '\n\n'.join(article_lines) 48 | 49 | # Make abstract into a single string, putting and tags around the sentences 50 | highlights = [fix_missing_period(sent) for sent in highlights] 51 | # abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights]) 52 | # abstract = ' '.join(highlights) 53 | return article, highlights 54 | 55 | def hashhex(s): 56 | """Returns a heximal formated SHA1 hash of the input string.""" 57 | h = hashlib.sha1() 58 | h.update(s) 59 | return h.hexdigest() 60 | 61 | def get_path_of_url(url): 62 | if 'dailymail.co.uk' in url or 'mailonsunday.ie' in url or 'lib.store.yahoo.net' in url: 63 | site = 'dailymail' 64 | else: 65 | assert 'cnn.com' in url or 'cnn.hk' in url, url 66 | site = 'cnn' 67 | url_hash = hashhex(url.encode('utf-8')) 68 | return f'{site}/stories/{url_hash}.story' 69 | 70 | def clean_up_start(text): 71 | if text[:2] == 'By': 72 | text = '\n'.join(text.split('\n')[2:]) 73 | text = re.split(r'\(CNN\) +--', text)[-1] 74 | text = re.split(r"\(CNN\)", text[:100])[-1]+text[100:] 75 | text = re.sub(r"^and \w+\n", "", text) 76 | text = re.split(r".*UPDATED:\s+[0-9]{2}:[0-9]{2}.*[2011|2012|2013|2014|2015]", text)[-1] 77 | text = text.replace('’', "'") 78 | text = text.replace('‘', "'") 79 | return text.strip() 80 | 81 | def cnndm_generator(mode, seed=0, shuffle=False, comm=None): 82 | # data originally from https://github.com/abisee/cnn-dailymail 83 | if mode == 'valid': 84 | mode = 'val' 85 | with open(gcs.download_file_cached(f'https://openaipublic.blob.core.windows.net/lm-human-preferences/datasets/cnndm/url_lists/all_{mode}.txt', comm=comm)) as f: 86 | urls = [line.strip() for line in f] 87 | if shuffle: 88 | random.seed(seed) 89 | random.shuffle(urls) 90 | # if n_eval > 0: 91 | # urls = urls[:n_eval] 92 | 93 | urls_dir = gcs.download_directory_cached(f'gs://lm-human-preferences/datasets/cnndm/cache_{mode}', comm=comm) 94 | 95 | for i, url in enumerate(urls): 96 | path = os.path.join(urls_dir, get_path_of_url(url)) 97 | text = open(path).read() 98 | text = clean_up_start(text) 99 | text = ftfy.fix_text(text) 100 | 101 | text = re.sub(r"\n{3,}", "\n\n", text) 102 | text = text.split('@highlight')[0].strip() 103 | yield text 104 | # _, ref_sents = get_art_abs(path) 105 | -------------------------------------------------------------------------------- /lm_human_preferences/datasets/tldr.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import re 4 | 5 | import ftfy 6 | 7 | from lm_human_preferences.utils import gcs 8 | 9 | 10 | def tldr_generator(mode, seed=0, shuffle=False, comm=None): 11 | random.seed(seed) 12 | 13 | if mode == 'test': 14 | mode = 'valid' # validation set serves as training set, since we don't have access.. 
15 | assert mode in ['train', 'valid'] 16 | 17 | with open(gcs.download_file_cached(f'https://openaipublic.blob.core.windows.net/lm-human-preferences/tldr/{mode}-subset.json', comm=comm)) as f: 18 | datas = json.load(f) 19 | 20 | if shuffle: 21 | random.seed(seed) 22 | random.shuffle(datas) 23 | 24 | for data in datas: 25 | text = data['content'] 26 | text = ftfy.fix_text(text) 27 | text = re.sub(r"\n{3,}", "\n\n", text) 28 | text = text.strip() 29 | yield text 30 | -------------------------------------------------------------------------------- /lm_human_preferences/label_types.py: -------------------------------------------------------------------------------- 1 | """Interface and implementations of label types for a reward model.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Optional, Dict 5 | 6 | import tensorflow as tf 7 | 8 | from lm_human_preferences.utils.core import Schema, pearson_r 9 | 10 | 11 | class LabelType(ABC): 12 | @abstractmethod 13 | def label_schemas(self) -> Dict[str, Schema]: 14 | """Schema for the human annotations.""" 15 | 16 | @abstractmethod 17 | def target_scales(self, labels: Dict[str, tf.Tensor]) -> Optional[tf.Tensor]: 18 | """Extracts scalars out of labels whose scale corresponds to the reward model's output. 19 | May be none if the labels have no such information.""" 20 | 21 | @abstractmethod 22 | def loss(self, reward_model, labels: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]: 23 | """ 24 | :param labels: the questions with their labels 25 | :returns: a dict of stats, including 'loss' for the actual loss 26 | """ 27 | 28 | @abstractmethod 29 | def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]: 30 | """Schema for the questions associated with this LabelType.""" 31 | 32 | 33 | class PickBest(LabelType): 34 | """Pick best response amongst N.""" 35 | def __init__(self, num_responses): 36 | self.num_responses = num_responses 37 | 38 | def label_schemas(self): 39 | return dict(best=Schema(tf.int32, ())) 40 | 41 | def target_scales(self, labels): 42 | return None 43 | 44 | def loss(self, reward_model, labels): 45 | logits = tf.stack([reward_model(labels['query'], labels[f'sample{i}']) 46 | for i in range(self.num_responses)], axis=1) 47 | error = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits( 48 | labels=labels['best'], logits=logits)) 49 | return dict(loss=error, error=error) 50 | 51 | def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]: 52 | return dict( 53 | query=Schema(tf.int32, (query_length,)), 54 | **{f"sample{i}": Schema(tf.int32, (response_length,)) for i in range(self.num_responses)} 55 | ) 56 | 57 | 58 | class ScalarRating(LabelType): 59 | """Rate a single number with a scalar score.""" 60 | def __init__(self): 61 | pass 62 | 63 | def label_schemas(self): 64 | return dict( 65 | score=Schema(tf.float32, ())) 66 | 67 | def target_scales(self, labels): 68 | return labels['score'] 69 | 70 | def loss(self, reward_model, labels): 71 | predicted = reward_model(labels['query'], labels['sample']) 72 | labels = labels['score'] 73 | error = tf.reduce_mean((labels - predicted) ** 2, axis=0) 74 | label_mean, label_var = tf.nn.moments(labels, axes=[0]) 75 | corr = pearson_r(labels, predicted) 76 | return dict(loss=error, error=error, 77 | label_mean=label_mean, label_var=label_var, corr=corr) 78 | 79 | def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]: 80 | return dict( 81 | query=Schema(tf.int32, (query_length,)), 82 | 
sample=Schema(tf.int32, (response_length,)), 83 | ) 84 | 85 | 86 | class ScalarComparison(LabelType): 87 | """Give a scalar indicating difference between two responses.""" 88 | def label_schemas(self): 89 | return dict(difference=Schema(tf.float32, ())) 90 | 91 | def target_scales(self, labels): 92 | # Divide by two to get something with the same variance as the trained reward model output 93 | return labels['difference']/2 94 | 95 | def loss(self, reward_model, labels): 96 | outputs0 = reward_model(labels['query'], labels['sample0']) 97 | outputs1 = reward_model(labels['query'], labels['sample1']) 98 | 99 | differences = labels['difference'] 100 | predicted_differences = outputs1 - outputs0 101 | error = tf.reduce_mean((differences - predicted_differences)**2, axis=0) 102 | return dict(loss=error, error=error) 103 | 104 | def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]: 105 | return dict( 106 | query=Schema(tf.int32, (query_length,)), 107 | sample0=Schema(tf.int32, (response_length,)), 108 | sample1=Schema(tf.int32, (response_length,)), 109 | ) 110 | 111 | 112 | def get(label_type: str) -> LabelType: 113 | if label_type == 'scalar_rating': 114 | return ScalarRating() 115 | if label_type == 'scalar_compare': 116 | return ScalarComparison() 117 | if label_type.startswith('best_of_'): 118 | n = int(label_type[len('best_of_'):]) 119 | return PickBest(n) 120 | raise ValueError(f"Unexpected label type {label_type}") 121 | -------------------------------------------------------------------------------- /lm_human_preferences/language/datasets.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Dict 3 | 4 | import tensorflow as tf 5 | 6 | from lm_human_preferences.datasets.books import books_generator 7 | from lm_human_preferences.datasets.cnndm import cnndm_generator 8 | from lm_human_preferences.datasets.tldr import tldr_generator 9 | 10 | _registry: Dict[str, "Dataset"] = {} 11 | 12 | class Dataset: 13 | def __init__( 14 | self, 15 | name, 16 | *, 17 | generator=None, 18 | ): 19 | global _registry 20 | assert name not in _registry 21 | _registry[name] = self 22 | 23 | self.name = name 24 | 25 | self.generator = generator 26 | 27 | def tf_dataset( 28 | self, 29 | sequence_length, 30 | *, 31 | mode, 32 | encoder=None, 33 | seed=0, 34 | comm=None, 35 | shuffle=True, 36 | repeat_count=None, # Defaults to infinite repeat 37 | # trims so that it starts right after start token 38 | start_token=None, 39 | # trims off last end_token 40 | end_token=None, 41 | padding_token=None, 42 | ): 43 | if padding_token is None: 44 | padding_token = encoder.padding_token 45 | def _generator(): 46 | inner_gen = self.generator(mode, seed=seed, shuffle=shuffle, comm=comm) 47 | for text in inner_gen: 48 | tokens = encoder.encode(text) 49 | if start_token is not None: 50 | try: 51 | first_index = tokens.index(start_token)+1 52 | if first_index < len(tokens): 53 | tokens = tokens[first_index:] 54 | except: 55 | continue 56 | 57 | tokens = tokens[:sequence_length] 58 | 59 | if end_token is not None: 60 | try: 61 | last_index = len(tokens)-tokens[::-1].index(end_token) 62 | tokens = tokens[:last_index] 63 | except: 64 | continue 65 | 66 | if len(tokens) < sequence_length: 67 | tokens = tokens + [padding_token] * (sequence_length - len(tokens)) 68 | 69 | assert len(tokens) == sequence_length 70 | 71 | yield dict(tokens=tokens) 72 | 73 | tf_dataset = tf.data.Dataset.from_generator( 74 | _generator, 75 | 
output_types=dict(tokens=tf.int32), 76 | output_shapes=dict(tokens=(sequence_length,)), 77 | ) 78 | tf_dataset = tf_dataset.repeat(repeat_count) 79 | 80 | if comm is not None: 81 | num_shards = comm.Get_size() 82 | shard_idx = comm.Get_rank() 83 | if num_shards > 1: 84 | assert seed is not None 85 | tf_dataset = tf_dataset.shard(num_shards, shard_idx) 86 | 87 | return tf_dataset 88 | 89 | 90 | def get_dataset(name) -> Dataset: 91 | global _registry 92 | return _registry[name] 93 | 94 | CnnDm = Dataset( 95 | "cnndm", 96 | generator=cnndm_generator, 97 | ) 98 | 99 | Tldr = Dataset( 100 | "tldr", 101 | generator=tldr_generator, 102 | ) 103 | 104 | Books = Dataset( 105 | "books", 106 | generator=books_generator, 107 | ) 108 | 109 | def test_generator(mode, seed=0, shuffle=False, comm=None): 110 | while True: 111 | yield ''.join([random.choice('abcdefghijklmnopqrstuvwxyz.') for _ in range(40)]) 112 | 113 | Test = Dataset( 114 | "test", 115 | generator=test_generator 116 | ) 117 | 118 | 119 | """ 120 | import tensorflow as tf 121 | from lm_human_preferences.language.datasets import Books as ds 122 | from lm_human_preferences.language.encodings import Main as encoding 123 | 124 | e = encoding.get_encoder() 125 | x = ds.tf_dataset(16, mode='test', encoder=e) 126 | op = x.make_one_shot_iterator().get_next() 127 | s = tf.Session() 128 | 129 | while True: 130 | print(e.decode(s.run(op)['tokens'])) 131 | input() 132 | """ 133 | -------------------------------------------------------------------------------- /lm_human_preferences/language/encodings.py: -------------------------------------------------------------------------------- 1 | """Byte pair encoding utilities""" 2 | 3 | import json 4 | import os 5 | from functools import lru_cache 6 | 7 | import tensorflow as tf 8 | import regex as re 9 | 10 | @lru_cache() 11 | def bytes_to_unicode(): 12 | """ 13 | Returns list of utf-8 byte and a corresponding list of unicode strings. 14 | The reversible bpe codes work on unicode strings. 15 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 16 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 17 | This is a signficant percentage of your normal, say, 32K bpe vocab. 18 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 19 | And avoids mapping to whitespace/control characters the bpe code barfs on. 20 | """ 21 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 22 | cs = bs[:] 23 | n = 0 24 | for b in range(2 ** 8): 25 | if b not in bs: 26 | bs.append(b) 27 | cs.append(2 ** 8 + n) 28 | n += 1 29 | cs = [chr(n) for n in cs] 30 | return dict(zip(bs, cs)) 31 | 32 | 33 | def get_pairs(word): 34 | """Return set of symbol pairs in a word. 35 | 36 | Word is represented as tuple of symbols (symbols being variable-length strings). 
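    For example (illustrative addition, not in the original docstring):
    get_pairs(('h', 'e', 'll', 'o')) == {('h', 'e'), ('e', 'll'), ('ll', 'o')}.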
37 | """ 38 | pairs = set() 39 | prev_char = word[0] 40 | for char in word[1:]: 41 | pairs.add((prev_char, char)) 42 | prev_char = char 43 | return pairs 44 | 45 | 46 | class ReversibleEncoder: 47 | def __init__(self, encoder, bpe_merges, errors="replace", eot_token=None): 48 | self.encoder = encoder 49 | self.decoder = {v: k for k, v in self.encoder.items()} 50 | self.errors = errors # how to handle errors in decoding 51 | self.byte_encoder = bytes_to_unicode() 52 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 53 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 54 | self.eot_token = eot_token 55 | self.cache = {} 56 | self.padding_token = len(encoder) + 2 # +2 unnecessary, for historical reasons 57 | self.decoder[self.padding_token] = '' 58 | 59 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 60 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 61 | 62 | def bpe(self, token): 63 | if token in self.cache: 64 | return self.cache[token] 65 | word = tuple(token) 66 | pairs = get_pairs(word) 67 | 68 | if not pairs: 69 | return token 70 | 71 | while True: 72 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 73 | if bigram not in self.bpe_ranks: 74 | break 75 | first, second = bigram 76 | new_word = [] 77 | i = 0 78 | while i < len(word): 79 | try: 80 | j = word.index(first, i) 81 | new_word.extend(word[i:j]) 82 | i = j 83 | except: 84 | new_word.extend(word[i:]) 85 | break 86 | 87 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 88 | new_word.append(first + second) 89 | i += 2 90 | else: 91 | new_word.append(word[i]) 92 | i += 1 93 | new_word = tuple(new_word) 94 | word = new_word 95 | if len(word) == 1: 96 | break 97 | else: 98 | pairs = get_pairs(word) 99 | word = " ".join(word) 100 | self.cache[token] = word 101 | return word 102 | 103 | def encode(self, text): 104 | bpe_tokens = [] 105 | for token in re.findall(self.pat, text): 106 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 107 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) 108 | return bpe_tokens 109 | 110 | def decode(self, tokens, pretty=False): 111 | del pretty 112 | text = "".join([self.decoder[token] for token in tokens]) 113 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) 114 | return text 115 | 116 | 117 | def read_file(path): 118 | with tf.gfile.Open(path, "rb") as fh: 119 | return fh.read() 120 | 121 | 122 | class Encoding: 123 | def __init__( 124 | self, 125 | name, 126 | *, 127 | n_vocab=0, 128 | eot_token=None, 129 | encoder_path="encoder.json", 130 | bpe_path="vocab.bpe", 131 | base_path=None, 132 | ): 133 | self.name = name 134 | self.eot_token = eot_token 135 | self.n_vocab = n_vocab 136 | 137 | if base_path is None: 138 | base_path = os.path.join("gs://gpt-2/encodings", name) 139 | 140 | self.base_path = base_path 141 | if name != "test": 142 | self.encoder_path = os.path.join(self.base_path, encoder_path) 143 | self.bpe_path = os.path.join(self.base_path, bpe_path) 144 | 145 | def get_encoder(self): 146 | if self.name == "test": 147 | vocab = "abcdefghijklmnopqrstuvwxyz." 
148 | assert len(vocab) == self.n_vocab 149 | 150 | class TestEncoder(ReversibleEncoder): 151 | def __init__(self): 152 | super().__init__(encoder={w: i for i, w in enumerate(vocab)}, bpe_merges=list()) 153 | self.padding_token = len(vocab) 154 | def encode(self, text): 155 | return [self.encoder.get(x, len(vocab) - 1) for x in text] 156 | def decode(self, tokens, pretty=False): 157 | return ''.join([self.decoder.get(t, '') for t in tokens]) 158 | 159 | return TestEncoder() 160 | 161 | encoder_dict = json.loads(read_file(self.encoder_path).decode()) 162 | bpe_data = read_file(self.bpe_path).decode() 163 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]] 164 | assert len(encoder_dict) == self.n_vocab 165 | encoder = ReversibleEncoder(encoder=encoder_dict, bpe_merges=bpe_merges, eot_token=self.eot_token) 166 | assert encoder.padding_token >= self.n_vocab 167 | return encoder 168 | 169 | 170 | Main = Encoding("main", n_vocab=50257, eot_token=50256) 171 | 172 | Test = Encoding("test", n_vocab=27, eot_token=26) 173 | -------------------------------------------------------------------------------- /lm_human_preferences/language/model.py: -------------------------------------------------------------------------------- 1 | """Alec's transformer model.""" 2 | 3 | from functools import partial 4 | from typing import Optional 5 | from dataclasses import dataclass 6 | 7 | import tensorflow as tf 8 | import numpy as np 9 | from tensorflow.python.framework import function 10 | 11 | from lm_human_preferences.utils import core as utils 12 | from lm_human_preferences.utils import hyperparams 13 | 14 | @dataclass 15 | class HParams(hyperparams.HParams): 16 | # Encoding (set during loading process) 17 | n_vocab: int = 0 18 | 19 | # Model parameters 20 | n_ctx: int = 512 21 | n_embd: int = 768 22 | n_head: int = 12 23 | n_layer: int = 12 24 | 25 | embd_pdrop: float = 0.1 26 | attn_pdrop: float = 0.1 27 | resid_pdrop: float = 0.1 28 | head_pdrop: float = 0.1 29 | 30 | 31 | def parse_comma_separated_int_list(s): 32 | return [int(i) for i in s.split(",")] if s else [] 33 | 34 | 35 | def gelu(x): 36 | with tf.name_scope('gelu'): 37 | return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) 38 | 39 | 40 | def dropout(x, pdrop, *, do_dropout, stateless=True, seed=None, name): 41 | """Like tf.nn.dropout but stateless. 
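    When stateless=True, a two-element integer `seed` is required and the mask is drawn
    with tf.random.stateless_uniform, so the same seed reproduces the same dropout mask.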
42 | """ 43 | if stateless: 44 | assert seed is not None 45 | def _dropout(): 46 | with tf.name_scope(name): 47 | noise_shape = tf.shape(x) 48 | 49 | if stateless: 50 | r = tf.random.stateless_uniform(noise_shape, seed, dtype=x.dtype) 51 | # floor uniform [keep_prob, 1.0 + keep_prob) 52 | mask = tf.floor(1 - pdrop + r) 53 | return x * (mask * (1 / (1 - pdrop))) 54 | else: 55 | return tf.nn.dropout(x, rate=pdrop, noise_shape=noise_shape) 56 | if pdrop == 0 or not do_dropout: 57 | return x 58 | else: 59 | return _dropout() 60 | 61 | 62 | def norm(x, scope, *, axis=-1, epsilon=1e-5): 63 | """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" 64 | with tf.variable_scope(scope): 65 | n_state = x.shape[-1].value 66 | g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) 67 | s = tf.reduce_mean(tf.square(x), axis=axis, keepdims=True) 68 | b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) 69 | u = tf.reduce_mean(x, axis=axis, keepdims=True) 70 | s = s - tf.square(u) 71 | x = (x - u) * tf.rsqrt(s + epsilon) 72 | x = x*g + b 73 | return x 74 | 75 | 76 | def split_states(x, n): 77 | """Reshape the last dimension of x into [n, x.shape[-1]/n].""" 78 | *start, m = utils.shape_list(x) 79 | return tf.reshape(x, start + [n, m//n]) 80 | 81 | 82 | def merge_states(x): 83 | """Smash the last two dimensions of x into a single dimension.""" 84 | *start, a, b = utils.shape_list(x) 85 | return tf.reshape(x, start + [a*b]) 86 | 87 | 88 | def conv1x1(x, scope, nf, *, w_init_stdev=0.02): 89 | with tf.variable_scope(scope): 90 | *start, nx = utils.shape_list(x) 91 | 92 | # Don't cast params until just prior to use -- saves a lot of memory for large models 93 | with tf.control_dependencies([x]): 94 | w = tf.squeeze(tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev)), axis=0) 95 | b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0)) 96 | c = tf.matmul(tf.reshape(x, [-1, nx]), w) + b 97 | c = tf.reshape(c, start+[nf]) 98 | return c 99 | 100 | 101 | def attention_mask(nd, ns, *, dtype): 102 | """1's in the lower triangle, counting from the lower right corner. 103 | 104 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. 105 | """ 106 | i = tf.range(nd)[:,None] 107 | j = tf.range(ns) 108 | m = i >= j - ns + nd 109 | # to ignore first parts of context (useful for sampling with static shapes) 110 | # m = tf.math.logical_and(m, tf.math.logical_or(j >= ignore, i < ignore - ns + nd)) 111 | return tf.cast(m, dtype) 112 | 113 | 114 | def softmax(x, axis=-1): 115 | x = x - tf.reduce_max(x, axis=axis, keepdims=True) 116 | ex = tf.exp(x) 117 | return ex / tf.reduce_sum(ex, axis=axis, keepdims=True) 118 | 119 | 120 | def attn(x, scope, n_state, *, past, mask, do_dropout, scale=False, hparams, seed): 121 | assert x.shape.ndims == 3 # Should be [batch, sequence, features] 122 | if past is not None: 123 | assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v] 124 | 125 | def split_heads(x): 126 | # From [batch, sequence, features] to [batch, heads, sequence, features] 127 | return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3]) 128 | 129 | def merge_heads(x): 130 | # Reverse of split_heads 131 | return merge_states(tf.transpose(x, [0, 2, 1, 3])) 132 | 133 | def mask_attn_weights(w): 134 | # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. 
135 | bs, _, nd, ns = utils.shape_list(w) 136 | b = attention_mask(nd, ns, dtype=w.dtype) 137 | b = tf.reshape(b, [1, 1, nd, ns]) 138 | if mask is not None: 139 | b *= tf.reshape(tf.cast(mask, w.dtype), [bs, 1, 1, ns]) 140 | w = w*b - tf.cast(1e10, w.dtype)*(1-b) 141 | return w 142 | 143 | def multihead_attn(q, k, v, *, seed): 144 | orig_dtype = v.dtype 145 | q, k, v = map(partial(tf.cast, dtype=tf.float32), (q, k, v)) 146 | # q, k, v have shape [batch, heads, sequence, features] 147 | w = tf.matmul(q, k, transpose_b=True) 148 | 149 | if scale: 150 | n_state = v.shape[-1].value 151 | w = w * tf.rsqrt(tf.cast(n_state, w.dtype)) 152 | 153 | w = mask_attn_weights(w) 154 | w = softmax(w) 155 | w = dropout(w, hparams.attn_pdrop, 156 | do_dropout=do_dropout, name='attn_drop', stateless=True, seed=seed) 157 | a = tf.matmul(w, v) 158 | a = tf.cast(a, dtype=orig_dtype, name='a_cast') 159 | return a 160 | 161 | with tf.variable_scope(scope): 162 | attn_seed, resid_seed = split_seed(seed, 2) 163 | 164 | assert n_state % hparams.n_head == 0 165 | w_init_stdev = 1/np.sqrt(n_state) 166 | c = conv1x1(x, 'c_attn', n_state * 3, w_init_stdev=w_init_stdev) 167 | q, k, v = map(split_heads, tf.split(c, 3, axis=2)) 168 | present = tf.stack([k, v], axis=1) 169 | if past is not None: 170 | pk, pv = tf.unstack(past, axis=1) 171 | k = tf.concat([pk, k], axis=-2) 172 | v = tf.concat([pv, v], axis=-2) 173 | a = multihead_attn(q, k, v, seed=attn_seed) 174 | a = merge_heads(a) 175 | w_init_stdev = 1/np.sqrt(n_state*hparams.n_layer) 176 | a = conv1x1(a, 'c_proj', n_state, w_init_stdev=w_init_stdev) 177 | a = dropout(a, hparams.resid_pdrop, do_dropout=do_dropout, stateless=True, seed=resid_seed, name='attn_resid_drop') 178 | return a, present 179 | 180 | 181 | def mlp(x, scope, n_hidden, *, do_dropout, hparams, seed): 182 | with tf.variable_scope(scope): 183 | nx = x.shape[-1].value 184 | w_init_stdev = 1/np.sqrt(nx) 185 | h = gelu( 186 | conv1x1(x, 'c_fc', n_hidden, w_init_stdev=w_init_stdev)) 187 | w_init_stdev = 1/np.sqrt(n_hidden*hparams.n_layer) 188 | h2 = conv1x1(h, 'c_proj', nx, w_init_stdev=w_init_stdev) 189 | h2 = dropout(h2, hparams.resid_pdrop, do_dropout=do_dropout, stateless=True, seed=seed, name='mlp_drop') 190 | return h2 191 | 192 | 193 | def block(x, scope, *, past, mask, do_dropout, scale=False, hparams, seed): 194 | with tf.variable_scope(scope): 195 | attn_seed, mlp_seed = split_seed(seed, 2) 196 | 197 | nx = x.shape[-1].value 198 | a, present = attn( 199 | norm(x, 'ln_1'), 200 | 'attn', nx, past=past, mask=mask, do_dropout=do_dropout, scale=scale, hparams=hparams, seed=attn_seed) 201 | x = x + a 202 | 203 | m = mlp( 204 | norm(x, 'ln_2'), 205 | 'mlp', nx*4, do_dropout=do_dropout, hparams=hparams, seed=mlp_seed) 206 | h = x + m 207 | return h, present 208 | 209 | 210 | @function.Defun( 211 | python_grad_func=lambda x, dy: tf.convert_to_tensor(dy), 212 | shape_func=lambda op: [op.inputs[0].get_shape()]) 213 | def convert_gradient_to_tensor(x): 214 | """Force gradient to be a dense tensor. 215 | 216 | It's often faster to do dense embedding gradient on GPU than sparse on CPU. 217 | """ 218 | return x 219 | 220 | 221 | def embed(X, we): 222 | """Embedding lookup. 223 | 224 | X has shape [batch, sequence, info]. Currently info = 2 corresponding to [token_id, position]. 
225 | """ 226 | we = convert_gradient_to_tensor(we) 227 | e = tf.gather(we, X) 228 | return e 229 | 230 | 231 | #tensor contraction of the final axes of x with the first axes of y 232 | #need to write it ourselves because tensorflow's tensordot is slow 233 | def tensordot(x, y, num_axes): 234 | split_x_axes_at = x.shape.ndims - num_axes 235 | x_shape = tf.shape(x)[:split_x_axes_at] 236 | y_shape = tf.shape(y)[num_axes:] 237 | rx = tf.reshape(x, [tf.reduce_prod(x_shape), tf.reduce_prod(tf.shape(x)[split_x_axes_at:])]) 238 | ry = tf.reshape(y, [-1, tf.reduce_prod(y_shape)]) 239 | rresult = tf.matmul(rx, ry) 240 | result = tf.reshape(rresult, tf.concat([x_shape, y_shape], axis=0)) 241 | result.set_shape(x.shape[:split_x_axes_at].concatenate(y.shape[num_axes:])) 242 | return result 243 | 244 | 245 | #more convenient fc layer that avoids stupid shape stuff 246 | #consumes in_axes of x 247 | #produces y of shape outshape 248 | def fc_layer(x, outshape, *, in_axes=1, scale=None): 249 | inshape = tuple([int(d) for d in x.shape[-in_axes:]]) if in_axes>0 else () 250 | outshape = tuple(outshape) 251 | if scale is None: 252 | scale = 1 / np.sqrt(np.prod(inshape) + 1) 253 | w = tf.get_variable('w', inshape + outshape, initializer=tf.random_normal_initializer(stddev=scale)) 254 | b = tf.get_variable('b', outshape, initializer=tf.constant_initializer(0)) 255 | # Call the regularizer manually so that it works correctly with GradientTape 256 | regularizer = tf.contrib.layers.l2_regularizer(scale=1/np.prod(outshape)) #so that initial value of regularizer is 1 257 | reg_loss = regularizer(w) 258 | return tensordot(x, w, in_axes) + b, reg_loss 259 | 260 | 261 | def past_shape(*, hparams, batch_size=None, sequence=None): 262 | return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, utils.exact_div(hparams.n_embd, hparams.n_head)] 263 | 264 | 265 | def positions_for(*, batch, sequence, past_length, mask): 266 | if mask is None: 267 | return utils.expand_tile(past_length + tf.range(sequence), batch, axis=0) 268 | else: 269 | return tf.cumsum(tf.cast(mask, tf.int32), exclusive=True, axis=-1)[:, past_length:] 270 | 271 | 272 | def split_seed(seed, n=2): 273 | if n == 0: 274 | return [] 275 | return tf.split( 276 | tf.random.stateless_uniform(dtype=tf.int64, shape=[2*n], minval=-2**63, maxval=2**63-1, seed=seed), 277 | n, name='split_seeds') 278 | 279 | 280 | class Model: 281 | def __init__(self, hparams: HParams, scalar_heads=[], scope=None): 282 | self.hparams = hparams 283 | self.scalar_heads = scalar_heads 284 | with tf.variable_scope(scope, 'model') as scope: 285 | self.scope = scope 286 | self.built = False 287 | 288 | def __call__(self, *, X, Y=None, past=None, past_tokens=None, mask=None, 289 | padding_token: Optional[int]=None, do_dropout=False): 290 | X = tf.convert_to_tensor(X, dtype=tf.int32) 291 | if mask is not None: 292 | mask = tf.convert_to_tensor(mask, dtype=tf.bool) 293 | assert mask.dtype == tf.bool 294 | if padding_token is not None: 295 | assert mask is None, 'At most one of mask and padding_token should be set' 296 | mask = tf.not_equal(X, padding_token) 297 | X = tf.where(mask, X, tf.zeros_like(X)) 298 | if past is not None: 299 | assert past_tokens is not None, 'padding_token requires past_tokens' 300 | mask = tf.concat([tf.not_equal(past_tokens, padding_token), mask], axis=1) 301 | with tf.variable_scope(self.scope, reuse=self.built, auxiliary_name_scope=not self.built): 302 | self.built = True 303 | results = {} 304 | batch, sequence = utils.shape_list(X) 305 | 306 | seed = 
tf.random.uniform(dtype=tf.int64, shape=[2], minval=-2**63, maxval=2**63-1) 307 | wpe_seed, wte_seed, blocks_seed, heads_seed = split_seed(seed, 4) 308 | 309 | wpe = tf.get_variable('wpe', [self.hparams.n_ctx, self.hparams.n_embd], 310 | initializer=tf.random_normal_initializer(stddev=0.01)) 311 | wte = tf.get_variable('wte', [self.hparams.n_vocab, self.hparams.n_embd], 312 | initializer=tf.random_normal_initializer(stddev=0.02)) 313 | wpe = dropout(wpe, self.hparams.embd_pdrop, 314 | do_dropout=do_dropout, stateless=True, seed=wpe_seed, name='wpe_drop') 315 | wte = dropout(wte, self.hparams.embd_pdrop, 316 | do_dropout=do_dropout, stateless=True, seed=wte_seed, name='wte_drop') 317 | 318 | past_length = 0 if past is None else tf.shape(past)[-2] 319 | 320 | positions = positions_for(batch=batch, sequence=sequence, past_length=past_length, mask=mask) 321 | h = embed(X, wte) + embed(positions, wpe) 322 | # Transformer 323 | presents = [] 324 | pasts = tf.unstack(past, axis=1) if past is not None else [None] * self.hparams.n_layer 325 | assert len(pasts) == self.hparams.n_layer 326 | block_seeds = split_seed(blocks_seed, self.hparams.n_layer) 327 | for layer, (past, block_seed) in enumerate(zip(pasts, block_seeds)): 328 | h, present = block( 329 | h, 'h%d' % layer, past=past, mask=mask, do_dropout=do_dropout, scale=True, 330 | hparams=self.hparams, seed=block_seed) 331 | presents.append(present) 332 | results['present'] = tf.stack(presents, axis=1) 333 | h = norm(h, 'ln_f') 334 | if mask is not None: 335 | # For non-present tokens, use the output from the last present token instead. 336 | present_indices = utils.where(mask[:,past_length:], tf.tile(tf.range(sequence)[None,:], [batch, 1]), -1) 337 | use_indices = utils.cumulative_max(present_indices) 338 | # assert since GPUs don't 339 | with tf.control_dependencies([tf.assert_none_equal(use_indices, -1)]): 340 | h = utils.index_each(h, use_indices) 341 | results['h'] = h 342 | 343 | # Language model loss. Do tokens 0 379 | return params 380 | -------------------------------------------------------------------------------- /lm_human_preferences/language/sample.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from lm_human_preferences.language import model 4 | from lm_human_preferences.utils import core as utils 5 | 6 | 7 | def sample_sequence(*, step, model_hparams, length, batch_size=None, context=None, 8 | temperature=1, top_k=0, top_p=1.0, extra_outputs={}, cond=None): 9 | """ 10 | Sampling from an autoregressive sequence model. 11 | 12 | Inputs: 13 | step: A function which takes model hparams, a tokens Tensor, past, and 14 | returns a dictionary with 'logits' and 'presents', and any extra vars. 15 | context: Includes start tokens. 
16 | extra_outputs: Map from extra output key to dtype 17 | Returns: 18 | A dict with keys 'presents', 'logits', and any keys in extra_outputs 19 | """ 20 | 21 | with tf.name_scope('sample_seq'): 22 | batch_size, *_ = utils.shape_list(context) 23 | 24 | beta = 1 / tf.maximum(tf.cast(temperature, tf.float32), 1e-10) 25 | 26 | context_output = step(model_hparams, context) 27 | logits = tf.cast(context_output['logits'][:,-1], tf.float32) 28 | 29 | first_output_logits = tf.cast(beta, logits.dtype) * logits 30 | first_outputs = utils.sample_from_logits(first_output_logits) 31 | first_logprobs = utils.logprobs_from_logits(logits=first_output_logits, labels=first_outputs) 32 | 33 | def body(past, prev, output, logprobs, *extras): 34 | next_outputs = step(model_hparams, prev[:, tf.newaxis], past=past, 35 | past_tokens=output[:, :-1]) 36 | logits = tf.cast(next_outputs['logits'], tf.float32) * beta 37 | if top_k != 0: 38 | logits = tf.cond(tf.equal(top_k, 0), 39 | lambda: logits, 40 | lambda: utils.take_top_k_logits(logits, top_k)) 41 | if top_p != 1.0: 42 | logits = utils.take_top_p_logits(logits, top_p) 43 | next_sample = utils.sample_from_logits(logits, dtype=tf.int32) 44 | 45 | next_logprob = utils.logprobs_from_logits(logits=logits, labels=next_sample) 46 | return [ 47 | tf.concat([past, next_outputs['presents']], axis=-2), 48 | tf.squeeze(next_sample, axis=[1]), 49 | tf.concat([output, next_sample], axis=1), 50 | tf.concat([logprobs, next_logprob], axis=1), 51 | *[tf.concat([prev, next_outputs[k]], axis=1) for k, prev in zip(extra_outputs, extras)], 52 | ] 53 | 54 | try: 55 | shape_batch_size = int(batch_size) 56 | except TypeError: 57 | shape_batch_size = None 58 | if cond is None: 59 | def always_true(*args): 60 | return True 61 | cond = always_true 62 | presents, _, tokens, logprobs, *extras = tf.while_loop( 63 | body=body, 64 | cond=cond, 65 | loop_vars=[ 66 | context_output['presents'], # past 67 | first_outputs, # prev 68 | tf.concat([context, first_outputs[:, tf.newaxis]], axis=1), # output 69 | first_logprobs[:, tf.newaxis], #logprobs 70 | *[context_output[k][:, -1:] for k in extra_outputs] # extras 71 | ], 72 | shape_invariants=[ 73 | tf.TensorShape(model.past_shape(hparams=model_hparams, batch_size=shape_batch_size)), 74 | tf.TensorShape([shape_batch_size]), 75 | tf.TensorShape([shape_batch_size, None]), 76 | tf.TensorShape([shape_batch_size, None]), 77 | *[tf.TensorShape([shape_batch_size, None]) for _ in extra_outputs] 78 | ], 79 | maximum_iterations=length-1, 80 | back_prop=False, 81 | parallel_iterations=2, 82 | ) 83 | 84 | return dict(tokens=tokens, presents=presents, logprobs=logprobs, **dict(zip(extra_outputs, extras))) 85 | -------------------------------------------------------------------------------- /lm_human_preferences/language/test_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Transformer model tests.""" 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from lm_human_preferences.utils import core as utils 8 | from lm_human_preferences.language import model 9 | 10 | def test_incremental(): 11 | hparams = model.HParams() 12 | hparams.override_from_dict(dict( 13 | n_vocab=10, 14 | n_ctx=5, 15 | n_embd=9, 16 | n_head=3, 17 | n_layer=2, 18 | )) 19 | batch_size = 2 20 | steps = 5 21 | np.random.seed(7) 22 | tf.set_random_seed(7) 23 | 24 | # Transformer model 25 | m = model.Model(hparams=hparams) 26 | X = tf.placeholder(shape=[batch_size, None], dtype=tf.int32) 27 | logits = 
m(X=X)['lm_logits'] 28 | past_p = tf.placeholder(shape=model.past_shape(hparams=hparams, batch_size=batch_size), dtype=tf.float32) 29 | # Test reusing it in a different variable scope 30 | with tf.variable_scope('other_scope'): 31 | past_lm = m(X=X[:,-1:], past=past_p) 32 | past_logits = past_lm['lm_logits'] 33 | future = tf.concat([past_p, past_lm['present']], axis=-2) 34 | 35 | # Data 36 | ids = np.random.randint(hparams.n_vocab, size=[batch_size, steps]).astype(np.int32) 37 | past = np.zeros(model.past_shape(hparams=hparams, batch_size=batch_size, sequence=0), dtype=np.float32) 38 | 39 | # Evaluate 40 | with tf.Session() as sess: 41 | tf.global_variables_initializer().run() 42 | for step in range(steps): 43 | logits_v, past_logits_v, past = sess.run([logits, past_logits, future], 44 | feed_dict={X: ids[:,:step+1], past_p: past}) 45 | assert np.allclose(logits_v[:,-1:], past_logits_v, atol=1e-3, rtol=1e-3) 46 | 47 | 48 | def test_mask(): 49 | np.random.seed(7) 50 | tf.set_random_seed(7) 51 | 52 | # Make a transformer 53 | hparams = model.HParams() 54 | hparams.override_from_dict(dict( 55 | n_vocab=10, 56 | n_ctx=8, 57 | n_embd=3, 58 | n_head=3, 59 | n_layer=2, 60 | )) 61 | batch_size = 4# 64 62 | policy = model.Model(hparams=hparams) 63 | 64 | # Random pasts and tokens 65 | past_length = 4 66 | length = 3 67 | past = np.random.randn(*model.past_shape( 68 | hparams=hparams, batch_size=batch_size, sequence=past_length)).astype(np.float32) 69 | X = np.random.randint(hparams.n_vocab, size=[batch_size, length]) 70 | 71 | # Run model without gaps 72 | logits = policy(past=past, X=X)['lm_logits'] 73 | 74 | # Run the same thing, but with gaps randomly inserted 75 | gap_past_length = 7 76 | gap_length = 5 77 | def random_subsequence(*, n, size): 78 | # Always make the first token be present, since the model tries to fill gaps with the previous states 79 | sub = [ 80 | np.concatenate(([0], np.random.choice(np.arange(1,n), size=size-1, replace=False))) 81 | for _ in range(batch_size) 82 | ] 83 | return np.sort(sub, axis=-1) 84 | past_sub = random_subsequence(n=gap_past_length, size=past_length) 85 | X_sub = random_subsequence(n=gap_length, size=length) 86 | past_gap = np.random.randn(*model.past_shape( 87 | hparams=hparams, batch_size=batch_size, sequence=gap_past_length)).astype(np.float32) 88 | X_gap = np.random.randint(hparams.n_vocab, size=[batch_size, gap_length]) 89 | mask = np.zeros([batch_size, gap_past_length + gap_length], dtype=np.bool) 90 | for b in range(batch_size): 91 | for i in range(past_length): 92 | past_gap[b,:,:,:,past_sub[b,i]] = past[b,:,:,:,i] 93 | for i in range(length): 94 | X_gap[b,X_sub[b,i]] = X[b,i] 95 | mask[b, past_sub[b]] = mask[b, gap_past_length + X_sub[b]] = 1 96 | gap_logits = policy(past=past_gap, X=X_gap, mask=mask)['lm_logits'] 97 | sub_logits = utils.index_each(gap_logits, X_sub) 98 | 99 | # Compare 100 | with tf.Session() as sess: 101 | tf.global_variables_initializer().run() 102 | logits, sub_logits = sess.run([logits, sub_logits]) 103 | assert logits.shape == sub_logits.shape 104 | assert np.allclose(logits, sub_logits, atol=1e-5) 105 | 106 | 107 | def test_attention_mask(): 108 | with tf.Session() as sess: 109 | for nd in 1, 2, 3: 110 | for ns in range(nd, 4): 111 | ours = model.attention_mask(nd, ns, dtype=tf.int32) 112 | theirs = tf.matrix_band_part(tf.ones([nd, ns], dtype=tf.int32), tf.cast(-1, tf.int32), ns-nd) 113 | ours, theirs = sess.run([ours, theirs]) 114 | print(ours) 115 | print(theirs) 116 | assert np.all(ours == theirs) 117 | 118 | 119 | if 
__name__ == '__main__': 120 | test_mask() 121 | test_attention_mask() 122 | test_incremental() 123 | -------------------------------------------------------------------------------- /lm_human_preferences/language/test_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Test sample_sequence().""" 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from tensorflow.contrib.training import HParams 7 | 8 | from lm_human_preferences.language import sample 9 | 10 | n_vocab = 10 11 | batch_size = 2 12 | hparams = HParams( 13 | n_layer=0, 14 | n_head=1, 15 | n_embd=0, 16 | n_attn=0, 17 | ) 18 | 19 | # Returns a policy that deterministically chooses previous token + 1. 20 | def step(hparams, tokens, past=None, past_tokens=None): 21 | logits = tf.one_hot(tokens + 1, n_vocab, on_value=0., off_value=-np.inf, dtype=tf.float32) 22 | ret = { 23 | 'logits': logits, 24 | 'presents': tf.zeros(shape=[2, 0, 2, 1, 0, 0]), 25 | } 26 | return ret 27 | 28 | def test_sample_sequence(): 29 | output = sample.sample_sequence(step=step, model_hparams=hparams, length=4, batch_size=batch_size, 30 | context=tf.constant([[5, 0], [4, 3]])) 31 | expected = np.array([[5, 0, 1, 2, 3, 4], [4, 3, 4, 5, 6, 7]]) 32 | 33 | with tf.Session() as sess: 34 | np.testing.assert_array_equal(sess.run(output)['tokens'], expected) 35 | 36 | 37 | if __name__ == '__main__': 38 | test_sample_sequence() 39 | -------------------------------------------------------------------------------- /lm_human_preferences/language/trained_models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | 4 | import tensorflow as tf 5 | 6 | from lm_human_preferences.language import encodings, model 7 | 8 | 9 | class TrainedModel(): 10 | def __init__(self, name, *, savedir=None, scope=None): 11 | self.name = name 12 | self.scope = scope 13 | self.savedir = savedir if savedir else os.path.join('gs://gpt-2/models/', name) 14 | if name == 'test': 15 | self.encoding = encodings.Test 16 | else: 17 | self.encoding = encodings.Main 18 | self._hparams = None 19 | 20 | def checkpoint(self): 21 | if self.name == 'test': 22 | return None 23 | ckpt = tf.train.latest_checkpoint(self.savedir) 24 | if ckpt is not None: 25 | return ckpt 26 | return tf.train.latest_checkpoint(os.path.join(self.savedir, 'checkpoints')) 27 | 28 | def hparams(self): 29 | if self._hparams is None: 30 | if self.name == 'test': 31 | hparams = test_hparams() 32 | else: 33 | hparams = load_hparams( 34 | os.path.join(self.savedir, 'hparams.json') 35 | ) 36 | self._hparams = hparams 37 | return copy.deepcopy(self._hparams) 38 | 39 | def init_op(self, params, new_scope): 40 | assert params 41 | params = dict(**params) 42 | checkpoint = self.checkpoint() 43 | available = tf.train.list_variables(checkpoint) 44 | unchanged = {} 45 | 46 | for name, shape in available: 47 | our_name = name 48 | if self.scope: 49 | if name.startswith(self.scope): 50 | our_name = name[len(self.scope):].lstrip('/') 51 | else: 52 | continue 53 | # Annoying hack since some code uses 'scope/model' as the scope and other code uses just 'scope' 54 | our_name = '%s/%s' % (new_scope, our_name) 55 | if our_name not in params: 56 | # NOTE: this happens for global_step and optimizer variables 57 | # (e.g. 
beta1_power, beta2_power, blah/Adam, blah/Adam_1) 58 | # print(f'{name} is missing for scope {new_scope}') 59 | continue 60 | var = params[our_name] 61 | del params[our_name] 62 | assert var.shape == shape, 'Shape mismatch: %s.shape = %s != %s' % (var.op.name, var.shape, shape) 63 | unchanged[name] = var 64 | for name in params.keys(): 65 | print(f'Param {name} is missing from checkpoint {checkpoint}') 66 | tf.train.init_from_checkpoint(checkpoint, unchanged) 67 | 68 | def load_hparams(file): 69 | hparams = model.HParams() 70 | hparams.override_from_json_file(file) 71 | return hparams 72 | 73 | def test_hparams(): 74 | hparams = model.HParams() 75 | hparams.override_from_dict(dict( 76 | n_vocab=27, # Corresponds to random encoding length 77 | n_ctx=8, 78 | n_layer=2, 79 | n_embd=7, 80 | n_head=1, 81 | )) 82 | return hparams 83 | -------------------------------------------------------------------------------- /lm_human_preferences/lm_tasks.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | import tensorflow as tf 5 | 6 | from lm_human_preferences.language import datasets 7 | from lm_human_preferences.utils import core as utils 8 | from lm_human_preferences.utils import hyperparams 9 | 10 | 11 | @dataclass 12 | class PolicyHParams(hyperparams.HParams): 13 | temperature: float = 1.0 14 | initial_model: str = None 15 | 16 | @dataclass 17 | class TaskHParams(hyperparams.HParams): 18 | # Query params 19 | query_length: int = None 20 | query_dataset: str = None 21 | query_prefix: str = '' 22 | query_suffix: str = '' 23 | start_text: Optional[str] = '.' 24 | end_text: Optional[str] = None 25 | 26 | # Response params 27 | response_length: int = None 28 | 29 | # Truncate response after the first occurrence of this token at or after index after when sampling. 
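# Illustrative example (hypothetical values, added annotation): with truncate_token=13 and
# truncate_after=2, a sampled response [13, 5, 7, 13, 9, 6] keeps the 13 at index 0
# (it falls before index 2), truncates after the first qualifying 13 at index 3, and
# postprocess_fn_from_hparams below rewrites it to
# [13, 5, 7, 13, padding_token, padding_token].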
30 | truncate_token: Optional[int] = None 31 | truncate_after: int = 0 32 | penalty_reward_value: int = -1 33 | 34 | policy: PolicyHParams = field(default_factory=PolicyHParams) 35 | 36 | #returns a postprocessing function 37 | #it is applied to responses before they are scored 38 | #central example: replace all tokens after truncate_token with padding_token 39 | def postprocess_fn_from_hparams(hparams: TaskHParams, padding_token: int): 40 | def get_mask(responses, truncate_token, truncate_after): 41 | # We want to truncate at the first occurrence of truncate_token that appears at or after 42 | # position truncate_after in the responses 43 | mask = tf.cast(tf.equal(responses, truncate_token), tf.int32) 44 | mask = tf.concat([tf.zeros_like(mask)[:,:truncate_after], mask[:,truncate_after:]], axis=1) 45 | return tf.cast(tf.cumsum(mask, axis=1) - mask, tf.bool) 46 | if hparams.truncate_token is not None: 47 | def truncate(responses): 48 | mask = get_mask(responses, hparams.truncate_token, hparams.truncate_after) 49 | return tf.where(mask, padding_token * tf.ones_like(responses), responses) 50 | return truncate 51 | else: 52 | return lambda responses: responses 53 | 54 | #returns a filter function 55 | #responses not passing that function will receive a low (fixed) score 56 | #only query humans on responses that pass that function 57 | #central example: ensure that the sample contains truncate_token 58 | def filter_fn_from_hparams(hparams: TaskHParams): 59 | def filter(responses): 60 | if hparams.truncate_token is not None: 61 | matches_token = tf.equal(responses[:, hparams.truncate_after:], hparams.truncate_token) 62 | return tf.reduce_any(matches_token, axis=-1) 63 | else: 64 | return tf.ones(tf.shape(responses)[0], dtype=tf.bool) 65 | return filter 66 | 67 | 68 | def query_formatter(hparams: TaskHParams, encoder): 69 | """Turns a query into a context to feed to the language model 70 | 71 | NOTE: Both of these are lists of tokens 72 | """ 73 | def query_formatter(queries): 74 | batch_size = tf.shape(queries)[0] 75 | prefix_tokens = tf.constant(encoder.encode(hparams.query_prefix), dtype=tf.int32) 76 | tiled_prefix = utils.expand_tile(prefix_tokens, batch_size, axis=0) 77 | suffix_tokens = tf.constant(encoder.encode(hparams.query_suffix), dtype=tf.int32) 78 | tiled_suffix = utils.expand_tile(suffix_tokens, batch_size, axis=0) 79 | return tf.concat([tiled_prefix, queries, tiled_suffix], 1) 80 | return query_formatter 81 | 82 | 83 | def make_query_sampler(*, hparams: TaskHParams, encoder, batch_size: int, mode='train', comm=None): 84 | if hparams.start_text: 85 | start_token, = encoder.encode(hparams.start_text) 86 | else: 87 | start_token = None 88 | 89 | if hparams.end_text: 90 | end_token, = encoder.encode(hparams.end_text) 91 | else: 92 | end_token = None 93 | 94 | data = datasets.get_dataset(hparams.query_dataset).tf_dataset( 95 | sequence_length=hparams.query_length, mode=mode, comm=comm, encoder=encoder, 96 | start_token=start_token, end_token=end_token, 97 | ) 98 | data = data.map(lambda d: tf.cast(d['tokens'], tf.int32)) 99 | data = data.batch(batch_size, drop_remainder=True) 100 | 101 | context_iterator = data.make_one_shot_iterator() 102 | 103 | def sampler(scope=None): 104 | with tf.name_scope(scope, 'sample_corpus'): 105 | context_tokens = context_iterator.get_next() 106 | return dict(tokens=context_tokens) 107 | return sampler 108 | -------------------------------------------------------------------------------- /lm_human_preferences/policy.py: 
-------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from lm_human_preferences.language import model, sample 4 | from lm_human_preferences.utils import core as utils 5 | from lm_human_preferences.utils.core import Schema 6 | 7 | 8 | class Policy: 9 | def __init__( 10 | self, 11 | trained_model, *, 12 | scope=None, use_resource=False, 13 | embed_queries=lambda queries: queries, 14 | temperature=1.0, is_root=True, 15 | build_respond=True, 16 | ): 17 | self.trained_model = trained_model 18 | self.model_hparams = trained_model.hparams() 19 | self.is_root = is_root 20 | 21 | self.use_resource = use_resource 22 | self.encoder = self.trained_model.encoding.get_encoder() 23 | 24 | with tf.variable_scope(scope, 'transformer_policy', use_resource=self.use_resource) as s: 25 | self.scope = s 26 | self.model = model.Model( 27 | hparams=self.model_hparams, 28 | scalar_heads=['value']) 29 | 30 | self.built = False 31 | self.embed_queries = embed_queries 32 | self.temperature = temperature 33 | self.padding_token = self.encoder.padding_token 34 | 35 | if build_respond: 36 | self.respond = utils.graph_function( 37 | queries=Schema(tf.int32, (None, None)), 38 | length=Schema(tf.int32, ()), 39 | )(self.respond_op) 40 | self.analyze_responses = utils.graph_function( 41 | queries=Schema(tf.int32, (None, None)), 42 | responses=Schema(tf.int32, (None, None)), 43 | )(self.analyze_responses_op) 44 | 45 | def get_encoder(self): 46 | return self.encoder 47 | 48 | def step_core(self, model_hparams, tokens, past=None, past_tokens=None, do_dropout=False, name=None): 49 | with tf.name_scope(name, 'step'): 50 | with tf.variable_scope( 51 | self.scope, 52 | reuse=self.built, 53 | auxiliary_name_scope=not self.built, 54 | use_resource=self.use_resource): 55 | lm_output = self.model(X=tokens, past=past, past_tokens=past_tokens, 56 | do_dropout=do_dropout, padding_token=self.padding_token) 57 | 58 | # need to slice logits since we don't want to generate special tokens 59 | logits = lm_output['lm_logits'][:,:,:self.model_hparams.n_vocab] 60 | presents = lm_output['present'] 61 | value = lm_output['value'] 62 | if not self.built: 63 | self._set_initializers() 64 | self.built = True 65 | return { 66 | 'logits': logits, 67 | 'values': value, 68 | 'presents': presents, 69 | } 70 | 71 | def ensure_built(self): 72 | if not self.built: 73 | with tf.name_scope('dummy'): 74 | self.step_core(self.model_hparams, tokens=tf.zeros([0,0], dtype=tf.int32)) 75 | 76 | def get_params(self): 77 | self.ensure_built() 78 | params = utils.find_trainable_variables(self.scope.name) 79 | assert len(params) > 0 80 | return params 81 | 82 | def _set_initializers(self): 83 | """Change initializers to load a language model from a tensorflow checkpoint.""" 84 | # Skip if 85 | # 1. We're not rank 0. Values will be copied from there. 86 | # 2. We want random initialization. Normal initialization will do the work. 87 | if not self.is_root or self.trained_model.name == 'test': 88 | return 89 | 90 | with tf.init_scope(): 91 | scope = self.scope.name 92 | 93 | # Initialize! 
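# NOTE (illustrative, assuming the scope='policy' instance built in train_policy.py):
# `params` below maps fully-qualified variable names such as 'policy/model/wte' to the
# freshly created variables, and trained_model.init_op points their initializers at the
# matching checkpoint tensors (e.g. 'model/wte'), silently skipping checkpoint entries
# such as global_step and the Adam optimizer slots.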
94 | params = {v.op.name: v for v in utils.find_trainable_variables(scope)} 95 | self.trained_model.init_op(params, new_scope=scope) 96 | 97 | def respond_op(self, queries, length): 98 | contexts = self.embed_queries(queries) 99 | context_length = tf.shape(contexts)[1] 100 | result = sample.sample_sequence( 101 | step=self.step_core, 102 | context=contexts, 103 | length=length, 104 | model_hparams=self.model_hparams, 105 | temperature=self.temperature, 106 | extra_outputs={'values':tf.float32}, 107 | ) 108 | return dict( 109 | responses=result['tokens'][:, context_length:], 110 | logprobs=result['logprobs'], 111 | values=result['values'], 112 | ) 113 | 114 | def analyze_responses_op(self, queries, responses): 115 | contexts = self.embed_queries(queries) 116 | context_length = tf.shape(contexts)[1] 117 | tokens = tf.concat([contexts, responses], axis=1) 118 | result = self.step_core(self.model_hparams, tokens) 119 | logits = result['logits'][:, context_length-1:-1] 120 | 121 | logits /= self.temperature 122 | return dict( 123 | logprobs = utils.logprobs_from_logits(logits=logits, labels=responses), 124 | entropies = utils.entropy_from_logits(logits), 125 | values = result['values'][:, context_length-1:-1], 126 | ) 127 | 128 | -------------------------------------------------------------------------------- /lm_human_preferences/rewards.py: -------------------------------------------------------------------------------- 1 | """Synthetic scores.""" 2 | 3 | import os 4 | 5 | import tensorflow as tf 6 | from mpi4py import MPI 7 | 8 | from lm_human_preferences.language import trained_models, model 9 | from lm_human_preferences.utils import core as utils 10 | from lm_human_preferences.utils.core import Schema 11 | 12 | 13 | # TODO: combine this with TrainedRewardModel 14 | class RewardModelTrainer: 15 | def __init__( 16 | self, 17 | trained_model, *, 18 | scope='reward_model', use_resource=False, 19 | is_root=True, 20 | ): 21 | self.trained_model = trained_model 22 | self.hparams = trained_model.hparams() 23 | self.is_root = is_root 24 | 25 | self.use_resource = use_resource 26 | self.encoder = self.trained_model.encoding.get_encoder() 27 | 28 | self.scope = scope 29 | self.model = model.Model(hparams=self.hparams, scope=f'{scope}/model', scalar_heads=['reward']) 30 | 31 | self.built = False 32 | self.padding_token = self.encoder.padding_token 33 | 34 | self.get_rewards = utils.graph_function( 35 | queries=Schema(tf.int32, (None, None)), 36 | responses=Schema(tf.int32, (None, None)), 37 | )(self.get_rewards_op) 38 | 39 | 40 | def get_encoder(self): 41 | return self.encoder 42 | 43 | def _build(self, tokens, do_dropout=False, name=None): 44 | with tf.variable_scope(self.scope, reuse=self.built, auxiliary_name_scope=not self.built, use_resource=self.use_resource): 45 | lm_output = self.model(X=tokens, do_dropout=do_dropout, padding_token=self.padding_token) 46 | 47 | reward = lm_output['reward'][:, -1] 48 | with tf.variable_scope('reward_norm'): 49 | if not self.built: 50 | self.reward_gain = tf.get_variable('gain', shape=(), initializer=tf.constant_initializer(1)) 51 | self.reward_bias = tf.get_variable('bias', shape=(), initializer=tf.constant_initializer(0)) 52 | self._reward_gain_p = tf.placeholder(name='gain_p', dtype=tf.float32, shape=()) 53 | self._reward_bias_p = tf.placeholder(name='bias_p', dtype=tf.float32, shape=()) 54 | self._set_reward_norm = tf.group(self.reward_gain.assign(self._reward_gain_p), 55 | self.reward_bias.assign(self._reward_bias_p)) 56 | if reward is not None: 57 | reward 
= self.reward_gain * reward + self.reward_bias 58 | if not self.built: 59 | self._set_initializers() 60 | self.built = True 61 | return reward 62 | 63 | def ensure_built(self): 64 | if self.built: 65 | return 66 | with tf.name_scope('dummy'): 67 | self._build(tokens=tf.zeros([0,0], dtype=tf.int32)) 68 | 69 | def get_params(self): 70 | self.ensure_built() 71 | return self.model.get_params() + [self.reward_gain, self.reward_bias] 72 | 73 | def reset_reward_scale(self): 74 | sess = tf.get_default_session() 75 | sess.run(self._set_reward_norm, feed_dict={self._reward_gain_p: 1, self._reward_bias_p: 0}) 76 | 77 | def set_reward_norm(self, *, old_mean, old_std, new_mean, new_std): 78 | """Given old_mean+-old_std of reward_model, change gain and bias to get N(new_mean,new_std).""" 79 | sess = tf.get_default_session() 80 | old_gain, old_bias = sess.run((self.reward_gain, self.reward_bias)) 81 | assert old_gain == 1 and old_bias == 0,\ 82 | f'set_reward_norm expects gain = 1 and bias = 0, not {old_gain}, {old_bias}' 83 | # gain * N(old_mean,old_std) + bias = N(gain * old_mean, gain * old_std) + bias 84 | # = N(gain * old_mean + bias, gain * old_std) 85 | # gain * old_std = new_std, gain = new_std / old_std 86 | # gain * old_mean + bias = new_mean, bias = new_mean - gain * old_mean 87 | gain = new_std / old_std 88 | bias = new_mean - gain * old_mean 89 | sess.run(self._set_reward_norm, feed_dict={self._reward_gain_p: gain, self._reward_bias_p: bias}) 90 | 91 | def _set_initializers(self): 92 | """Change initializers to load a language model from a tensorflow checkpoint.""" 93 | # Skip if 94 | # 1. We're not rank 0. Values will be copied from there. 95 | # 2. We want random initialization. Normal initialization will do the work. 96 | if not self.is_root or self.trained_model.name == 'test': 97 | return 98 | 99 | with tf.init_scope(): 100 | # Initialize! 
101 | params = {v.op.name: v for v in utils.find_trainable_variables(self.scope)} 102 | assert params 103 | self.trained_model.init_op(params, new_scope=self.scope) 104 | 105 | def get_rewards_op(self, queries, responses): 106 | tokens = tf.concat([queries, responses], axis=1) 107 | return self._build(tokens) 108 | 109 | 110 | class TrainedRewardModel(): 111 | def __init__(self, train_dir, encoding, *, scope='reward_model', comm=MPI.COMM_WORLD): 112 | self.train_dir = train_dir 113 | self.comm = comm 114 | 115 | self.encoding = encoding 116 | encoder = encoding.get_encoder() 117 | if train_dir != 'test': 118 | self.hparams = trained_models.load_hparams(os.path.join(train_dir, 'hparams.json')) 119 | assert self.hparams.n_vocab == encoding.n_vocab, f'{self.hparams.n_vocab} != {encoding.n_vocab}' 120 | else: 121 | self.hparams = trained_models.test_hparams() 122 | 123 | self.padding_token = encoder.padding_token 124 | 125 | self.encoder = encoder 126 | 127 | self.scope = scope 128 | self.model = model.Model(hparams=self.hparams, scope=f'{scope}/model', scalar_heads=['reward']) 129 | 130 | def _build(self, X): 131 | results = self.model(X=X, padding_token=self.padding_token) 132 | reward = results['reward'][:, -1] 133 | with tf.variable_scope(f'{self.scope}/reward_norm'): 134 | self.reward_gain = tf.get_variable('gain', shape=(), initializer=tf.constant_initializer(1)) 135 | self.reward_bias = tf.get_variable('bias', shape=(), initializer=tf.constant_initializer(0)) 136 | reward = self.reward_gain * reward + self.reward_bias 137 | self._set_initializers() 138 | return reward 139 | 140 | def ensure_built(self): 141 | if self.model.built: 142 | return 143 | with tf.name_scope('dummy'): 144 | self._build(X=tf.zeros([0,0], dtype=tf.int32)) 145 | 146 | def _set_initializers(self): 147 | """Change initializers to load a model from a tensorflow checkpoint.""" 148 | if self.comm.Get_rank() > 0 or self.train_dir == 'test': 149 | return 150 | 151 | assert self.model.built 152 | checkpoint_scope = 'reward_model' 153 | 154 | with tf.init_scope(): 155 | # Initialize! 
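# NOTE (illustrative, added annotation): checkpoint variables live under the fixed
# 'reward_model' scope, so the loop below swaps that prefix for self.scope, e.g. the
# checkpoint entry 'reward_model/model/wte' is looked up as
# params[self.scope + '/model/wte'], while names ending in 'adam'/'adam_1' are skipped.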
156 | params = {v.op.name: v for v in self.get_params()} 157 | checkpoint = tf.train.latest_checkpoint(os.path.join(self.train_dir, 'checkpoints/')) 158 | available = tf.train.list_variables(checkpoint) 159 | unchanged = {} 160 | 161 | for name, shape in available: 162 | if not name.startswith(checkpoint_scope + '/'): 163 | # print('skipping', name) 164 | continue 165 | if name.endswith('adam') or name.endswith('adam_1'): 166 | # print('skipping', name) 167 | continue 168 | print('setting', name) 169 | var = params[self.scope + name[len(checkpoint_scope):]] 170 | assert var.shape == shape, 'Shape mismatch: %s.shape = %s != %s' % (var.op.name, var.shape, shape) 171 | unchanged[name] = var 172 | tf.train.init_from_checkpoint(checkpoint, unchanged) 173 | 174 | def get_params(self): 175 | return self.model.get_params() + [self.reward_gain, self.reward_bias] 176 | 177 | def score_fn(self, queries, responses): 178 | tokens = tf.concat([queries, responses], axis=1) 179 | return self._build(tokens) 180 | -------------------------------------------------------------------------------- /lm_human_preferences/test_train_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import tempfile 4 | from lm_human_preferences import train_policy 5 | 6 | def hparams_for_test(): 7 | hparams = train_policy.HParams() 8 | hparams.ppo.batch_size = 8 9 | hparams.noptepochs = 1 10 | hparams.task.policy.initial_model = 'test' 11 | hparams.task.query_length = 2 12 | hparams.task.response_length = 3 13 | hparams.task.query_dataset = 'test' 14 | hparams.rewards.trained_model = 'test' 15 | hparams.ppo.total_episodes = 8 16 | hparams.run.log_interval = 1 17 | 18 | return hparams 19 | 20 | 21 | def train_policy_test(override_params): 22 | hparams = hparams_for_test() 23 | hparams.override_from_dict(override_params) 24 | hparams.validate() 25 | train_policy.train(hparams=hparams) 26 | 27 | 28 | def test_truncation(): 29 | train_policy_test({ 30 | 'task.truncate_token': 13, 31 | 'task.truncate_after': 2, 32 | }) 33 | 34 | def test_defaults(): 35 | train_policy_test({}) 36 | 37 | def test_affixing(): 38 | train_policy_test({ 39 | 'task.query_prefix': 'a', 40 | 'task.query_suffix': 'b' 41 | }) 42 | 43 | def test_adaptive_kl(): 44 | train_policy_test({ 45 | 'rewards.trained_model': 'test', # not sure why needed 46 | 'rewards.adaptive_kl': 'on', 47 | 'rewards.adaptive_kl.target': 3.0, 48 | 'rewards.adaptive_kl.horizon': 100, 49 | }) 50 | 51 | def test_save(): 52 | train_policy_test({ 53 | 'run.save_dir': tempfile.mkdtemp() , 54 | 'run.save_interval': 1 55 | }) 56 | 57 | def test_reward_training(): 58 | train_policy_test({ 59 | 'rewards.trained_model': None, 60 | 'rewards.train_new_model': 'on', 61 | 'rewards.train_new_model.task.policy.initial_model': 'test', 62 | 'rewards.train_new_model.task.query_length': 2, 63 | 'rewards.train_new_model.task.response_length': 3, 64 | 'rewards.train_new_model.task.query_dataset': 'test', 65 | 'rewards.train_new_model.labels.source': 'test', 66 | 'rewards.train_new_model.labels.num_train': 16, 67 | 'rewards.train_new_model.batch_size': 8, 68 | 'rewards.train_new_model.labels.type': 'best_of_4', 69 | }) 70 | -------------------------------------------------------------------------------- /lm_human_preferences/test_train_reward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import tempfile 4 | from lm_human_preferences import train_reward 5 | 6 | def 
hparams_for_test(): 7 | hparams = train_reward.HParams() 8 | hparams.rollout_batch_size = 8 9 | hparams.task.query_length = 2 10 | hparams.task.response_length = 3 11 | hparams.noptepochs = 1 12 | hparams.task.policy.initial_model = 'test' 13 | hparams.task.query_dataset = 'test' 14 | hparams.task.start_text = None 15 | hparams.run.log_interval = 1 16 | 17 | hparams.labels.source = 'test' 18 | hparams.labels.num_train = 16 19 | hparams.labels.type = 'best_of_4' 20 | 21 | hparams.batch_size = 8 22 | 23 | return hparams 24 | 25 | 26 | def train_reward_test(override_params): 27 | hparams = hparams_for_test() 28 | hparams.override_from_dict(override_params) 29 | hparams.validate() 30 | train_reward.train(hparams=hparams) 31 | 32 | 33 | def test_basic(): 34 | train_reward_test({}) 35 | 36 | 37 | def test_scalar_compare(): 38 | train_reward_test({'labels.type': 'scalar_compare'}) 39 | 40 | 41 | def test_scalar_rating(): 42 | train_reward_test({'labels.type': 'scalar_rating'}) 43 | 44 | 45 | def test_normalize_before(): 46 | train_reward_test({ 47 | 'normalize_before': True, 48 | 'normalize_after': False, 49 | 'normalize_samples': 1024, 50 | 'debug_normalize': 1024, 51 | }) 52 | 53 | 54 | def test_normalize_both(): 55 | train_reward_test({ 56 | 'normalize_before': True, 57 | 'normalize_after': True, 58 | 'normalize_samples': 1024, 59 | 'debug_normalize': 1024, 60 | }) 61 | 62 | def test_save(): 63 | train_reward_test({ 64 | 'run.save_dir': tempfile.mkdtemp() , 65 | 'run.save_interval': 1 66 | }) 67 | -------------------------------------------------------------------------------- /lm_human_preferences/train_policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | import os 5 | import sys 6 | import time 7 | from dataclasses import dataclass, field 8 | from functools import partial 9 | from typing import Optional 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | from mpi4py import MPI 14 | from tensorflow.contrib import summary 15 | 16 | from lm_human_preferences import lm_tasks, train_reward 17 | from lm_human_preferences.language import trained_models 18 | from lm_human_preferences.policy import Policy 19 | from lm_human_preferences.rewards import TrainedRewardModel 20 | from lm_human_preferences.utils import core as utils 21 | from lm_human_preferences.utils import hyperparams 22 | from lm_human_preferences.utils.core import Schema 23 | 24 | 25 | @dataclass 26 | class AdaptiveKLParams(hyperparams.HParams): 27 | target: float = None 28 | horizon: int = 10000 # in episodes 29 | 30 | 31 | @dataclass 32 | class RewardHParams(hyperparams.HParams): 33 | kl_coef: float = 0.2 34 | adaptive_kl: Optional[AdaptiveKLParams] = None 35 | 36 | trained_model: Optional[str] = None 37 | 38 | train_new_model: Optional[train_reward.HParams] = None 39 | 40 | def validate(self, *, prefix=''): 41 | super().validate(prefix=prefix) 42 | assert self.trained_model is None or self.train_new_model is None, 'Cannot use trained_model and train new model' 43 | assert self.trained_model is not None or self.train_new_model is not None, 'Need either trained_model or to train a new model' 44 | 45 | 46 | @dataclass 47 | class PpoHParams(hyperparams.HParams): 48 | total_episodes: int = 2000000 49 | batch_size: int = 64 50 | nminibatches: int = 1 51 | noptepochs: int = 4 52 | lr: float = 5e-6 53 | vf_coef: float = .1 54 | cliprange: float = .2 55 | cliprange_value: float = .2 56 | gamma: float = 1 57 | lam: float = 0.95 58 | 
whiten_rewards: bool = True 59 | 60 | 61 | @dataclass 62 | class HParams(hyperparams.HParams): 63 | run: train_reward.RunHParams = field(default_factory=train_reward.RunHParams) 64 | 65 | task: lm_tasks.TaskHParams = field(default_factory=lm_tasks.TaskHParams) 66 | rewards: RewardHParams = field(default_factory=RewardHParams) 67 | ppo: PpoHParams = field(default_factory=PpoHParams) 68 | 69 | def validate(self, *, prefix=''): 70 | super().validate(prefix=prefix) 71 | # NOTE: must additionally divide by # ranks 72 | minibatch_size = utils.exact_div(self.ppo.batch_size, self.ppo.nminibatches) 73 | if self.ppo.whiten_rewards: 74 | assert minibatch_size >= 8, \ 75 | f"Minibatch size {minibatch_size} is insufficient for whitening in PPOTrainer.loss" 76 | 77 | 78 | def nupdates(hparams): 79 | return utils.ceil_div(hparams.ppo.total_episodes, hparams.ppo.batch_size) 80 | 81 | 82 | def policy_frac(hparams): 83 | """How far we are through policy training.""" 84 | return tf.cast(tf.train.get_global_step(), tf.float32) / nupdates(hparams) 85 | 86 | 87 | def tf_times(): 88 | """Returns (time since start, time since last) as a tensorflow op.""" 89 | # Keep track of start and last times 90 | with tf.init_scope(): 91 | init = tf.timestamp() 92 | 93 | def make(name): 94 | return tf.Variable(init, name=name, trainable=False, use_resource=True) 95 | 96 | start = make('start_time') 97 | last = make('last_time') 98 | 99 | # Get new time and update last 100 | now = tf.timestamp() 101 | prev = last.read_value() 102 | with tf.control_dependencies([prev]): 103 | with tf.control_dependencies([last.assign(now)]): 104 | return tf.cast(now - start.read_value(), tf.float32), tf.cast(now - prev, tf.float32) 105 | 106 | 107 | class FixedKLController: 108 | def __init__(self, kl_coef): 109 | self.value = kl_coef 110 | 111 | def update(self, current, n_steps): 112 | pass 113 | 114 | 115 | class AdaptiveKLController: 116 | def __init__(self, init_kl_coef, hparams): 117 | self.value = init_kl_coef 118 | self.hparams = hparams 119 | 120 | def update(self, current, n_steps): 121 | target = self.hparams.target 122 | proportional_error = np.clip(current / target - 1, -0.2, 0.2) 123 | mult = 1 + proportional_error * n_steps / self.hparams.horizon 124 | self.value *= mult 125 | 126 | 127 | 128 | class PPOTrainer(): 129 | def __init__(self, *, policy, ref_policy, query_sampler, score_fn, hparams, comm): 130 | self.comm = comm 131 | self.policy = policy 132 | self.ref_policy = ref_policy 133 | self.score_fn = score_fn 134 | self.hparams = hparams 135 | 136 | if hparams.rewards.adaptive_kl is None: 137 | self.kl_ctl = FixedKLController(hparams.rewards.kl_coef) 138 | else: 139 | self.kl_ctl = AdaptiveKLController(hparams.rewards.kl_coef, hparams=hparams.rewards.adaptive_kl) 140 | 141 | response_length = hparams.task.response_length 142 | query_length = hparams.task.query_length 143 | 144 | @utils.graph_function() 145 | def sample_queries(): 146 | return query_sampler()['tokens'] 147 | self.sample_queries = sample_queries 148 | 149 | def compute_rewards(scores, logprobs, ref_logprobs): 150 | kl = logprobs - ref_logprobs 151 | non_score_reward = -self.kl_ctl.value * kl 152 | rewards = non_score_reward.copy() 153 | rewards[:, -1] += scores 154 | return rewards, non_score_reward, self.kl_ctl.value 155 | self.compute_rewards = compute_rewards 156 | 157 | # per rank sizes 158 | per_rank_rollout_batch_size = utils.exact_div(hparams.ppo.batch_size, comm.Get_size()) 159 | per_rank_minibatch_size = utils.exact_div(per_rank_rollout_batch_size, 
hparams.ppo.nminibatches) 160 | 161 | @utils.graph_function( 162 | rollouts=dict( 163 | queries=Schema(tf.int32, (per_rank_minibatch_size, query_length)), 164 | responses=Schema(tf.int32, (per_rank_minibatch_size, response_length)), 165 | values=Schema(tf.float32, (per_rank_minibatch_size, response_length)), 166 | logprobs=Schema(tf.float32, (per_rank_minibatch_size, response_length)), 167 | rewards=Schema(tf.float32, (per_rank_minibatch_size, response_length)), 168 | )) 169 | def train_minibatch(rollouts): 170 | """One step of PPO training.""" 171 | 172 | left = 1 - policy_frac(hparams) 173 | lrnow = hparams.ppo.lr * left 174 | 175 | ppo_loss, stats = self.loss(rollouts) 176 | ppo_train_op = utils.minimize( 177 | loss=ppo_loss, lr=lrnow, params=policy.get_params(), name='ppo_opt', comm=self.comm) 178 | return ppo_train_op, stats 179 | 180 | def train(rollouts): 181 | stat_list = [] 182 | 183 | # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch 184 | for ppo_epoch_idx in range(hparams.ppo.noptepochs): 185 | order = np.random.permutation(per_rank_rollout_batch_size) 186 | for mb_start in range(0, per_rank_rollout_batch_size, per_rank_minibatch_size): 187 | mb_data = {k: v[order[mb_start:mb_start+per_rank_minibatch_size]] 188 | for k, v in rollouts.items()} 189 | 190 | step = tf.train.get_global_step().eval() 191 | 192 | _, stats = train_minibatch(mb_data) 193 | stat_list.append(stats) 194 | 195 | # Collect the stats. (They will be averaged later.) 196 | return {k: [s[k] for s in stat_list] for k in stat_list[0].keys()} 197 | self.train = train 198 | 199 | # NOTE: must line up with stats created in self.loss (TODO: better solution?) 200 | scalar_batch = Schema(tf.float32, (None,)) 201 | ppo_stat_schemas = utils.flatten_dict(dict( 202 | loss=dict(policy=scalar_batch, value=scalar_batch, total=scalar_batch), 203 | policy=dict(entropy=scalar_batch, approxkl=scalar_batch, clipfrac=scalar_batch), 204 | returns=dict(mean=scalar_batch, var=scalar_batch), 205 | val=dict(vpred=scalar_batch, error=scalar_batch, clipfrac=scalar_batch, mean=scalar_batch, var=scalar_batch), 206 | ), sep='/') 207 | stat_data_schemas = dict( 208 | logprobs=Schema(tf.float32, (None, hparams.task.response_length)), 209 | ref_logprobs=Schema(tf.float32, (None, hparams.task.response_length)), 210 | scores=scalar_batch, 211 | non_score_reward=Schema(tf.float32, (None, hparams.task.response_length)), 212 | score_stats=score_fn.stat_schemas, 213 | train_stats=ppo_stat_schemas, 214 | ) 215 | @utils.graph_function( 216 | **stat_data_schemas, kl_coef=Schema(tf.float32, ())) 217 | def record_step_stats(*, kl_coef, **data): 218 | ppo_summary_writer = utils.get_summary_writer(self.hparams.run.save_dir, subdir='ppo', comm=self.comm) 219 | 220 | kl = data['logprobs'] - data['ref_logprobs'] 221 | mean_kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1)) 222 | mean_entropy = tf.reduce_mean(tf.reduce_sum(-data['logprobs'], axis=1)) 223 | mean_non_score_reward = tf.reduce_mean(tf.reduce_sum(data['non_score_reward'], axis=1)) 224 | stats = { 225 | 'objective/kl': mean_kl, 226 | 'objective/kl_coef': kl_coef, 227 | 'objective/entropy': mean_entropy, 228 | } 229 | for k, v in data['train_stats'].items(): 230 | stats[f'ppo/{k}'] = tf.reduce_mean(v, axis=0) 231 | for k, v in data['score_stats'].items(): 232 | mean = tf.reduce_mean(v, axis=0) 233 | stats[f'objective/{k}'] = mean 234 | stats[f'objective/{k}_total'] = mean + mean_non_score_reward 235 | 236 | stats = utils.FlatStats.from_dict(stats).map_flat( 237 | 
partial(utils.mpi_allreduce_mean, comm=self.comm)).as_dict() 238 | 239 | # Add more statistics 240 | step = tf.train.get_global_step().read_value() 241 | stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var'] 242 | steps = step + 1 243 | stats.update({ 244 | 'elapsed/updates': steps, 245 | 'elapsed/steps/serial': steps * hparams.task.response_length, 246 | 'elapsed/steps/total': steps * hparams.ppo.batch_size * hparams.task.response_length, 247 | 'elapsed/episodes': steps * hparams.ppo.batch_size, 248 | }) 249 | 250 | # Time statistics 251 | total, delta = tf_times() 252 | stats.update({ 253 | 'elapsed/fps': tf.cast(hparams.ppo.batch_size * hparams.task.response_length / delta, tf.int32), 254 | 'elapsed/time': total, 255 | }) 256 | if ppo_summary_writer: 257 | record_op = utils.record_stats( 258 | stats=stats, summary_writer=ppo_summary_writer, step=step, log_interval=hparams.run.log_interval, name='ppo_stats', comm=self.comm) 259 | else: 260 | record_op = tf.no_op() 261 | return record_op, stats 262 | self.record_step_stats = record_step_stats 263 | 264 | def print_samples(self, queries, responses, scores, logprobs, ref_logprobs): 265 | if self.comm.Get_rank() != 0: 266 | return 267 | if tf.train.get_global_step().eval() % self.hparams.run.log_interval != 0: 268 | return 269 | 270 | encoder = self.policy.encoder 271 | 272 | # Log samples 273 | for i in range(min(3, len(queries))): 274 | sample_kl = np.sum(logprobs[i] - ref_logprobs[i]) 275 | print(encoder.decode(queries[i][:self.hparams.task.query_length]).replace("\n", "⏎")) 276 | print(encoder.decode(responses[i]).replace("\n", "⏎")) 277 | print(f" score = {scores[i]:+.2f}") 278 | print(f" kl = {sample_kl:+.2f}") 279 | print(f" total = {scores[i] - self.hparams.rewards.kl_coef * sample_kl:+.2f}") 280 | 281 | def step(self): 282 | step_started_at = time.time() 283 | 284 | queries = self.sample_queries() 285 | rollouts = self.policy.respond(queries, length=self.hparams.task.response_length) 286 | 287 | responses = rollouts['responses'] 288 | logprobs = rollouts['logprobs'] 289 | rollouts['queries'] = queries 290 | ref_logprobs = self.ref_policy.analyze_responses(queries, responses)['logprobs'] 291 | scores, postprocessed_responses, score_stats = self.score_fn(queries, responses) 292 | 293 | rewards, non_score_reward, kl_coef = self.compute_rewards( 294 | scores=scores, 295 | logprobs=logprobs, 296 | ref_logprobs=ref_logprobs) 297 | rollouts['rewards'] = rewards 298 | 299 | train_stats = self.train(rollouts=rollouts) 300 | 301 | _, stats = self.record_step_stats( 302 | scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs, non_score_reward=non_score_reward, 303 | train_stats=train_stats, score_stats=score_stats, kl_coef=kl_coef) 304 | 305 | self.kl_ctl.update(stats['objective/kl'], self.hparams.ppo.batch_size) 306 | 307 | self.print_samples(queries=queries, responses=postprocessed_responses, 308 | scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs) 309 | 310 | # Record profiles of the step times 311 | step = tf.get_default_session().run(tf.train.get_global_step()) 312 | step_time = time.time() - step_started_at 313 | eps_per_second = float(self.hparams.ppo.batch_size) / step_time 314 | if self.comm.Get_rank() == 0: 315 | print(f"[ppo_step {step}] step_time={step_time:.2f}s, " 316 | f"eps/s={eps_per_second:.2f}") 317 | 318 | 319 | def loss(self, rollouts): 320 | values = rollouts['values'] 321 | old_logprob = rollouts['logprobs'] 322 | rewards = rollouts['rewards'] 323 | with 
tf.name_scope('ppo_loss'): 324 | if self.hparams.ppo.whiten_rewards: 325 | rewards = utils.whiten(rewards, shift_mean=False) 326 | 327 | lastgaelam = 0 328 | advantages_reversed = [] 329 | gen_length = self.hparams.task.response_length 330 | for t in reversed(range(gen_length)): 331 | nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 332 | delta = rewards[:, t] + self.hparams.ppo.gamma * nextvalues - values[:, t] 333 | lastgaelam = delta + self.hparams.ppo.gamma * self.hparams.ppo.lam * lastgaelam 334 | advantages_reversed.append(lastgaelam) 335 | advantages = tf.stack(advantages_reversed[::-1], axis=1) 336 | returns = advantages + values 337 | 338 | advantages = utils.whiten(advantages) 339 | advantages = tf.stop_gradient(advantages) # Shouldn't do anything, but better not to think about it 340 | 341 | outputs = self.policy.analyze_responses_op(rollouts['queries'], rollouts['responses']) 342 | 343 | vpred = outputs['values'] 344 | vpredclipped = tf.clip_by_value(vpred, values - self.hparams.ppo.cliprange_value, values + self.hparams.ppo.cliprange_value) 345 | vf_losses1 = tf.square(vpred - returns) 346 | vf_losses2 = tf.square(vpredclipped - returns) 347 | vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2)) 348 | vf_clipfrac = tf.reduce_mean(tf.cast(tf.greater(vf_losses2, vf_losses1), tf.float32)) 349 | 350 | logprob = outputs['logprobs'] 351 | ratio = tf.exp(logprob - old_logprob) 352 | pg_losses = -advantages * ratio 353 | pg_losses2 = -advantages * tf.clip_by_value(ratio, 1.0 - self.hparams.ppo.cliprange, 1.0 + self.hparams.ppo.cliprange) 354 | pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2)) 355 | pg_clipfrac = tf.reduce_mean(tf.cast(tf.greater(pg_losses2, pg_losses), tf.float32)) 356 | 357 | loss = pg_loss + self.hparams.ppo.vf_coef * vf_loss 358 | 359 | entropy = tf.reduce_mean(outputs['entropies']) 360 | approxkl = .5 * tf.reduce_mean(tf.square(logprob - old_logprob)) 361 | 362 | return_mean, return_var = tf.nn.moments(returns, axes=list(range(returns.shape.ndims))) 363 | value_mean, value_var = tf.nn.moments(values, axes=list(range(values.shape.ndims))) 364 | 365 | stats = dict( 366 | loss=dict(policy=pg_loss, value=vf_loss, total=loss), 367 | policy=dict(entropy=entropy, approxkl=approxkl, clipfrac=pg_clipfrac), 368 | returns=dict(mean=return_mean, var=return_var), 369 | val=dict(vpred=tf.reduce_mean(vpred), error=tf.reduce_mean((vpred - returns) ** 2), 370 | clipfrac=vf_clipfrac, mean=value_mean, var=value_var) 371 | ) 372 | return loss, utils.flatten_dict(stats, sep='/') 373 | 374 | 375 | def make_score_fn(hparams, score_model): 376 | padding_token = score_model.padding_token 377 | 378 | postprocess_fn = lm_tasks.postprocess_fn_from_hparams(hparams, padding_token) 379 | #decorate requires a named function, postprocess_fn can be anonymous 380 | @utils.graph_function(responses=Schema(tf.int32, (None, None))) 381 | def postprocess(responses): 382 | return postprocess_fn(responses) 383 | 384 | filter_fn = lm_tasks.filter_fn_from_hparams(hparams) 385 | @utils.graph_function( 386 | responses=Schema(tf.int32, (None, None)), 387 | rewards=Schema(tf.float32, (None,))) 388 | def penalize(responses, rewards): 389 | valid = filter_fn(responses) 390 | return tf.where(valid, rewards, hparams.penalty_reward_value * tf.ones_like(rewards)) 391 | 392 | @utils.graph_function( 393 | queries=Schema(tf.int32, (None, None)), 394 | responses=Schema(tf.int32, (None, None)) 395 | ) 396 | def unpenalized_score_fn(queries, responses): 397 | return 
score_model.score_fn(queries, responses) 398 | 399 | def score_fn(queries, responses): 400 | responses = postprocess(responses) 401 | score = penalize(responses, unpenalized_score_fn(queries, responses)) 402 | return score, responses, dict(score=score) 403 | score_fn.stat_schemas = dict(score=Schema(tf.float32, (None,))) 404 | return score_fn 405 | 406 | 407 | 408 | def train(hparams: HParams): 409 | save_dir = hparams.run.save_dir 410 | if hparams.rewards.train_new_model: 411 | assert hparams.task == hparams.rewards.train_new_model.task, f'{hparams.task} != {hparams.rewards.train_new_model.task}' 412 | hparams.rewards.train_new_model.run.save_dir = save_dir 413 | train_reward.train(hparams.rewards.train_new_model) 414 | if 'pytest' in sys.modules: 415 | hparams.rewards.trained_model = 'test' 416 | elif save_dir: 417 | hparams.rewards.trained_model = None if save_dir is None else os.path.join(save_dir, 'reward_model') 418 | 419 | comm = MPI.COMM_WORLD 420 | 421 | with tf.Graph().as_default(): 422 | hyperparams.dump(hparams) 423 | 424 | m = trained_models.TrainedModel(hparams.task.policy.initial_model) 425 | encoder = m.encoding.get_encoder() 426 | hyperparams.dump(m.hparams(), name='model_hparams') 427 | 428 | if save_dir: 429 | if not save_dir.startswith('https:'): 430 | os.makedirs(os.path.join(save_dir, 'policy'), exist_ok=True) 431 | with tf.gfile.Open(os.path.join(save_dir, 'train_policy_hparams.json'), 'w') as f: 432 | json.dump(hparams.to_nested_dict(), f, indent=2) 433 | with tf.gfile.Open(os.path.join(save_dir, 'policy', 'hparams.json'), 'w') as f: 434 | json.dump(m.hparams().to_nested_dict(), f, indent=2) 435 | with tf.gfile.Open(os.path.join(save_dir, 'policy', 'encoding'), 'w') as f: 436 | json.dump(m.encoding.name, f, indent=2) 437 | utils.set_mpi_seed(hparams.run.seed) 438 | 439 | score_model = TrainedRewardModel(hparams.rewards.trained_model, m.encoding, comm=comm) 440 | 441 | ref_policy = Policy( 442 | m, scope='ref_policy', 443 | is_root=comm.Get_rank() == 0, 444 | embed_queries=lm_tasks.query_formatter(hparams.task, encoder), 445 | temperature=hparams.task.policy.temperature, 446 | build_respond=False) 447 | 448 | policy = Policy( 449 | m, scope='policy', 450 | is_root=comm.Get_rank() == 0, 451 | embed_queries=lm_tasks.query_formatter(hparams.task, encoder), 452 | temperature=hparams.task.policy.temperature) 453 | 454 | query_sampler = lm_tasks.make_query_sampler( 455 | hparams=hparams.task, encoder=encoder, comm=comm, 456 | batch_size=utils.exact_div(hparams.ppo.batch_size, comm.Get_size()), 457 | ) 458 | 459 | per_rank_minibatch_size = utils.exact_div(hparams.ppo.batch_size, hparams.ppo.nminibatches * comm.Get_size()) 460 | if hparams.ppo.whiten_rewards: 461 | assert per_rank_minibatch_size >= 8, \ 462 | f"Per-rank minibatch size {per_rank_minibatch_size} is insufficient for whitening" 463 | 464 | global_step = tf.train.get_or_create_global_step() 465 | increment_global_step = tf.group(global_step.assign_add(1)) 466 | 467 | with utils.variables_on_gpu(): 468 | 469 | ppo_trainer = PPOTrainer( 470 | policy=policy, ref_policy=ref_policy, query_sampler=query_sampler, 471 | score_fn=make_score_fn(hparams.task, score_model=score_model), 472 | hparams=hparams, comm=comm) 473 | 474 | if comm.Get_rank() == 0 and save_dir: 475 | print(f"Will save to {save_dir}") 476 | saver = tf.train.Saver(max_to_keep=20, save_relative_paths=True) 477 | checkpoint_dir = os.path.join(save_dir, 'policy/checkpoints/model.ckpt') 478 | else: 479 | saver = None 480 | checkpoint_dir = None 481 | 482 | 
@utils.graph_function() 483 | def sync_models(): 484 | score_model.ensure_built() 485 | return utils.variable_synchronizer(comm, vars=score_model.get_params() + ref_policy.get_params() + policy.get_params()) 486 | 487 | init_ops = tf.group( 488 | tf.global_variables_initializer(), 489 | tf.local_variables_initializer(), 490 | summary.summary_writer_initializer_op()) 491 | 492 | with utils.mpi_session() as sess: 493 | init_ops.run() 494 | 495 | sync_models() 496 | 497 | tf.get_default_graph().finalize() 498 | 499 | try: 500 | while global_step.eval() < nupdates(hparams): 501 | ppo_trainer.step() 502 | increment_global_step.run() 503 | 504 | if saver and global_step.eval() % hparams.run.save_interval == 0: 505 | saver.save(sess, checkpoint_dir, global_step=global_step) 506 | finally: 507 | if saver: 508 | saver.save(sess, checkpoint_dir, global_step=global_step) 509 | -------------------------------------------------------------------------------- /lm_human_preferences/train_reward.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | import os 5 | from dataclasses import dataclass, field 6 | from functools import partial 7 | from typing import Optional 8 | 9 | import numpy as np 10 | import tensorflow as tf 11 | from mpi4py import MPI 12 | from tensorflow.contrib import summary 13 | 14 | from lm_human_preferences import label_types, lm_tasks, rewards 15 | from lm_human_preferences.language import trained_models 16 | from lm_human_preferences.policy import Policy 17 | from lm_human_preferences.utils import core as utils 18 | from lm_human_preferences.utils import gcs, hyperparams 19 | from lm_human_preferences.utils.core import Schema 20 | 21 | 22 | @dataclass 23 | class LabelHParams(hyperparams.HParams): 24 | type: str = None 25 | num_train: int = None 26 | source: str = None 27 | 28 | 29 | @dataclass 30 | class RunHParams(hyperparams.HParams): 31 | seed: Optional[int] = None 32 | log_interval: int = 10 33 | save_interval: int = 50 34 | save_dir: Optional[str] = None 35 | 36 | @dataclass 37 | class HParams(hyperparams.HParams): 38 | run: RunHParams = field(default_factory=RunHParams) 39 | 40 | task: lm_tasks.TaskHParams = field(default_factory=lm_tasks.TaskHParams) 41 | labels: LabelHParams = field(default_factory=LabelHParams) 42 | 43 | batch_size: int = 40 # total across ranks 44 | lr: float = 5e-5 45 | 46 | rollout_batch_size: int = 64 47 | normalize_samples: int = 0 # Samples used to estimate reward mean and std 48 | debug_normalize: int = 0 # Samples used to check that normalization worked 49 | # Whether, before training, to normalize the rewards on the policy to the scales on the training buffer. 50 | # (For comparisons, just use mean 0, var 1.) 51 | normalize_before: bool = False 52 | # Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1 53 | # (so the KL coefficient always has the same meaning). 
54 | normalize_after: bool = False 55 | 56 | def validate(self, *, prefix=''): 57 | super().validate(prefix=prefix) 58 | utils.exact_div(self.labels.num_train, self.batch_size) 59 | 60 | def round_down_to_multiple(n, divisor): 61 | return n - n % divisor 62 | 63 | 64 | def download_labels(source, label_type, question_schemas, total_labels, comm): 65 | schemas = {**question_schemas, **label_type.label_schemas()} 66 | 67 | """ 68 | if self.is_root: 69 | with tf.device('cpu:0'): 70 | self._enqueue_phs = { 71 | name: tf.placeholder(name=name, dtype=schema.dtype, shape=(None,) + schema.shape) 72 | for name, schema in self.schemas.items() 73 | } 74 | self._enqueue_answers = self.answer_queue.enqueue_many(self._enqueue_phs) 75 | else: 76 | self._enqueue_phs = None 77 | self._enqueue_answers = None 78 | """ 79 | 80 | # TODO: download on just one rank? then do: labels = utils.mpi_bcast_tensor_dict(labels, comm=comm) 81 | if source != 'test': 82 | with open(gcs.download_file_cached(source, comm=comm)) as f: 83 | results = json.load(f) 84 | print('Num labels found in source:', len(results)) 85 | else: 86 | results = [ 87 | { 88 | name: np.zeros(schema.shape, dtype=schema.dtype.as_numpy_dtype) 89 | for name, schema in schemas.items() 90 | } 91 | for _ in range(50) 92 | ] 93 | 94 | assert len(results) >= total_labels 95 | results = results[:total_labels] 96 | return {k: [a[k] for a in results] for k in schemas.keys()} 97 | 98 | 99 | class RewardModelTrainer(): 100 | def __init__(self, *, reward_model, policy, query_sampler, hparams, comm): 101 | self.reward_model = reward_model 102 | 103 | self.policy = policy 104 | self.hparams = hparams 105 | self.num_ranks = comm.Get_size() 106 | self.rank = comm.Get_rank() 107 | self.comm = comm 108 | 109 | self.label_type = label_types.get(hparams.labels.type) 110 | self.question_schemas = self.label_type.question_schemas( 111 | query_length=hparams.task.query_length, 112 | response_length=hparams.task.response_length, 113 | ) 114 | 115 | data_schemas = { 116 | **self.question_schemas, 117 | **self.label_type.label_schemas(), 118 | } 119 | 120 | with tf.device(None), tf.device('/cpu:0'): 121 | with tf.variable_scope('label_buffer', use_resource=True, initializer=tf.zeros_initializer): 122 | self.train_buffer = utils.SampleBuffer(capacity=hparams.labels.num_train, schemas=data_schemas) 123 | 124 | with tf.name_scope('train_reward'): 125 | summary_writer = utils.get_summary_writer(self.hparams.run.save_dir, subdir='reward_model', comm=comm) 126 | 127 | @utils.graph_function( 128 | indices=Schema(tf.int32, (None,)), 129 | lr=Schema(tf.float32, ())) 130 | def train_batch(indices, lr): 131 | with tf.name_scope('minibatch'): 132 | minibatch = self.train_buffer.read(indices) 133 | stats = self.label_type.loss(reward_model=self.reward_model.get_rewards_op, labels=minibatch) 134 | 135 | train_op = utils.minimize( 136 | loss=stats['loss'], lr=lr, params=self.reward_model.get_params(), name='opt', comm=self.comm) 137 | 138 | with tf.control_dependencies([train_op]): 139 | step_var = tf.get_variable(name='train_step', dtype=tf.int64, shape=(), trainable=False, use_resource=True) 140 | step = step_var.assign_add(1) - 1 141 | 142 | stats = utils.FlatStats.from_dict(stats).map_flat(partial(utils.mpi_allreduce_mean, comm=comm)).as_dict() 143 | 144 | train_stat_op = utils.record_stats(stats=stats, summary_writer=summary_writer, step=step, log_interval=hparams.run.log_interval, comm=comm) 145 | 146 | return train_stat_op 147 | self.train_batch = train_batch 148 | 149 | if 
self.hparams.normalize_before or self.hparams.normalize_after: 150 | @utils.graph_function() 151 | def target_mean_std(): 152 | """Returns the means and variances to target for each reward model""" 153 | # Should be the same on all ranks because the train_buf should be the same 154 | scales = self.label_type.target_scales(self.train_buffer.data()) 155 | if scales is None: 156 | return tf.zeros([]), tf.ones([]) 157 | else: 158 | mean, var = tf.nn.moments(scales, axes=[0]) 159 | return mean, tf.sqrt(var) 160 | self.target_mean_std = target_mean_std 161 | 162 | def stats(query_responses): 163 | rewards = np.concatenate([self.reward_model.get_rewards(qs, rs) for qs, rs in query_responses], axis=0) 164 | assert len(rewards.shape) == 1, f'{rewards.shape}' 165 | sums = np.asarray([rewards.sum(axis=0), np.square(rewards).sum(axis=0)]) 166 | means, sqr_means = self.comm.allreduce(sums, op=MPI.SUM) / (self.num_ranks * rewards.shape[0]) 167 | stds = np.sqrt(sqr_means - means ** 2) 168 | return means, stds 169 | self.stats = stats 170 | 171 | def log_stats_after_normalize(stats): 172 | if comm.Get_rank() != 0: 173 | return 174 | means, stds = stats 175 | print(f'after normalize: {means} +- {stds}') 176 | self.log_stats_after_normalize = log_stats_after_normalize 177 | 178 | def reset_reward_scales(): 179 | self.reward_model.reset_reward_scale() 180 | self.reset_reward_scales = reset_reward_scales 181 | 182 | def set_reward_norms(mean, std, new_mean, new_std): 183 | print(f'targets: {new_mean} +- {new_std}') 184 | print(f'before normalize: {mean} +- {std}') 185 | assert np.isfinite((mean, std, new_mean, new_std)).all() 186 | self.reward_model.set_reward_norm(old_mean=mean, old_std=std, new_mean=new_mean, new_std=new_std) 187 | self.set_reward_norms = set_reward_norms 188 | 189 | if self.hparams.normalize_before or self.hparams.normalize_after: 190 | @utils.graph_function() 191 | def sample_policy_batch(): 192 | queries = query_sampler('ref_queries')['tokens'] 193 | responses = policy.respond_op( 194 | queries=queries, length=hparams.task.response_length)['responses'] 195 | return queries, responses 196 | 197 | def sample_policy_responses(n_samples): 198 | n_batches = utils.ceil_div(n_samples, hparams.rollout_batch_size) 199 | return [sample_policy_batch() for _ in range(n_batches)] 200 | self.sample_policy_responses = sample_policy_responses 201 | 202 | @utils.graph_function(labels=utils.add_batch_dim(data_schemas)) 203 | def add_to_buffer(labels): 204 | return self.train_buffer.add(**labels) 205 | self.add_to_buffer = add_to_buffer 206 | 207 | def normalize(self, sample_fn, target_means, target_stds): 208 | if not self.hparams.normalize_samples: 209 | return 210 | 211 | self.reset_reward_scales() 212 | query_responses = sample_fn(self.hparams.normalize_samples) 213 | means, stds = self.stats(query_responses) 214 | 215 | self.set_reward_norms(means, stds, target_means, target_stds) 216 | if self.hparams.debug_normalize: 217 | query_responses = sample_fn(self.hparams.debug_normalize) 218 | stats = self.stats(query_responses) 219 | self.log_stats_after_normalize(stats) 220 | 221 | def train(self): 222 | labels = download_labels( 223 | self.hparams.labels.source, 224 | label_type=self.label_type, 225 | question_schemas=self.question_schemas, 226 | total_labels=self.hparams.labels.num_train, 227 | comm=self.comm 228 | ) 229 | 230 | self.add_to_buffer(labels) 231 | 232 | if self.hparams.normalize_before: 233 | target_mean, target_std = self.target_mean_std() 234 | 
self.normalize(self.sample_policy_responses, target_mean, target_std) 235 | 236 | # Collect training data for reward model training. train_indices will include the indices 237 | # trained on across all ranks, and its size must be a multiple of minibatch_size. 238 | per_rank_batch_size = utils.exact_div(self.hparams.batch_size, self.num_ranks) 239 | 240 | # Make sure each rank gets the same shuffle so we train on each point exactly once 241 | train_indices = self.comm.bcast(np.random.permutation(self.hparams.labels.num_train)) 242 | 243 | # Train on train_indices 244 | print(self.rank, "training on", self.hparams.labels.num_train, "in batches of", per_rank_batch_size) 245 | for start_index in range(0, self.hparams.labels.num_train, self.hparams.batch_size): 246 | end_index = start_index + self.hparams.batch_size 247 | all_ranks_indices = train_indices[start_index:end_index] 248 | our_indices = all_ranks_indices[self.rank::self.num_ranks] 249 | lr = (1 - start_index / self.hparams.labels.num_train) * self.hparams.lr 250 | self.train_batch(our_indices, lr) 251 | 252 | if self.hparams.normalize_after: 253 | target_mean, target_std = np.zeros([]), np.ones([]) 254 | self.normalize(self.sample_policy_responses, target_mean, target_std) 255 | 256 | 257 | 258 | def train(hparams: HParams): 259 | with tf.Graph().as_default(): 260 | hyperparams.dump(hparams) 261 | utils.set_mpi_seed(hparams.run.seed) 262 | 263 | m = trained_models.TrainedModel(hparams.task.policy.initial_model) 264 | encoder = m.encoding.get_encoder() 265 | hyperparams.dump(m.hparams(), name='model_hparams') 266 | 267 | comm = MPI.COMM_WORLD 268 | ref_policy = Policy( 269 | m, scope='ref_policy', 270 | is_root=comm.Get_rank() == 0, 271 | embed_queries=lm_tasks.query_formatter(hparams.task, encoder), 272 | temperature=hparams.task.policy.temperature, 273 | build_respond=False) 274 | 275 | reward_model = rewards.RewardModelTrainer(m, is_root=comm.Get_rank() == 0) 276 | 277 | query_sampler = lm_tasks.make_query_sampler( 278 | hparams=hparams.task, encoder=encoder, comm=comm, 279 | batch_size=utils.exact_div(hparams.rollout_batch_size, comm.Get_size()) 280 | ) 281 | 282 | tf.train.create_global_step() 283 | 284 | reward_trainer = RewardModelTrainer( 285 | reward_model=reward_model, 286 | policy=ref_policy, 287 | query_sampler=query_sampler, 288 | hparams=hparams, 289 | comm=comm, 290 | ) 291 | 292 | save_dir = hparams.run.save_dir 293 | if comm.Get_rank() == 0 and save_dir: 294 | print(f"Will save to {save_dir}") 295 | saver = tf.train.Saver(max_to_keep=20, save_relative_paths=True) 296 | checkpoint_dir = os.path.join(save_dir, 'reward_model/checkpoints/model.ckpt') 297 | 298 | if not save_dir.startswith('gs://'): 299 | os.makedirs(os.path.join(save_dir, 'reward_model'), exist_ok=True) 300 | with tf.gfile.Open(os.path.join(save_dir, 'train_reward_hparams.json'), 'w') as f: 301 | json.dump(hparams.to_nested_dict(), f, indent=2) 302 | with tf.gfile.Open(os.path.join(save_dir, 'reward_model', 'hparams.json'), 'w') as f: 303 | json.dump(reward_model.hparams.to_nested_dict(), f, indent=2) 304 | with tf.gfile.Open(os.path.join(save_dir, 'reward_model', 'encoding'), 'w') as f: 305 | json.dump(reward_model.trained_model.encoding.name, f, indent=2) 306 | else: 307 | saver = None 308 | checkpoint_dir = None 309 | 310 | with utils.variables_on_gpu(): 311 | init_ops = tf.group( 312 | tf.global_variables_initializer(), 313 | tf.local_variables_initializer(), 314 | summary.summary_writer_initializer_op()) 315 | 316 | @utils.graph_function() 317 | def 
sync_models(): 318 | return utils.variable_synchronizer(comm, vars=ref_policy.get_params() + reward_model.get_params()) 319 | 320 | tf.get_default_graph().finalize() 321 | 322 | with utils.mpi_session() as sess: 323 | init_ops.run() 324 | sync_models() 325 | 326 | reward_trainer.train() 327 | 328 | if saver: 329 | saver.save(sess, checkpoint_dir) 330 | -------------------------------------------------------------------------------- /lm_human_preferences/utils/combos.py: -------------------------------------------------------------------------------- 1 | def combos(*xs): 2 | if xs: 3 | return [x + combo for x in xs[0] for combo in combos(*xs[1:])] 4 | else: 5 | return [()] 6 | 7 | def each(*xs): 8 | return [y for x in xs for y in x] 9 | 10 | def bind(var, val, descriptor=''): 11 | extra = {} 12 | if descriptor: 13 | extra['descriptor'] = descriptor 14 | return [((var, val, extra),)] 15 | 16 | def label(descriptor): 17 | return bind(None, None, descriptor) 18 | 19 | def labels(*descriptors): 20 | return each(*[label(d) for d in descriptors]) 21 | 22 | def options(var, opts_with_descs): 23 | return each(*[bind(var, val, descriptor) for val, descriptor in opts_with_descs]) 24 | 25 | def _shortstr(v): 26 | if isinstance(v, float): 27 | s = f"{v:.03}" 28 | if '.' in s: 29 | s = s.lstrip('0').replace('.','x') 30 | else: 31 | s = str(v) 32 | return s 33 | 34 | def options_shortdesc(var, desc, opts): 35 | return each(*[bind(var, val, desc + _shortstr(val)) for val in opts]) 36 | 37 | def options_vardesc(var, opts): 38 | return options_shortdesc(var, var, opts) 39 | 40 | def repeat(n): 41 | return each(*[label(i) for i in range(n)]) 42 | 43 | # list monad bind; passes descriptors to body 44 | def foreach(inputs, body): 45 | return [inp + y for inp in inputs for y in body(*[extra['descriptor'] for var, val, extra in inp])] 46 | 47 | def bind_nested(prefix, binds): 48 | return [ 49 | tuple([ (var if var is None else prefix + '.' + var, val, extra) for (var, val, extra) in x ]) 50 | for x in binds 51 | ] 52 | -------------------------------------------------------------------------------- /lm_human_preferences/utils/core.py: -------------------------------------------------------------------------------- 1 | """Utilities.""" 2 | 3 | import collections 4 | import contextlib 5 | import inspect 6 | import os 7 | import platform 8 | import shutil 9 | import subprocess 10 | from dataclasses import dataclass 11 | from functools import lru_cache, partial, wraps 12 | from typing import Any, Dict, Tuple, Optional 13 | 14 | import numpy as np 15 | import tensorflow as tf 16 | from mpi4py import MPI 17 | from tensorflow.contrib import summary 18 | 19 | try: 20 | import horovod.tensorflow as hvd 21 | hvd.init() 22 | except: 23 | hvd = None 24 | 25 | 26 | nest = tf.contrib.framework.nest 27 | 28 | 29 | def nvidia_gpu_count(): 30 | """ 31 | Count the GPUs on this machine. 32 | """ 33 | if shutil.which('nvidia-smi') is None: 34 | return 0 35 | try: 36 | output = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv']) 37 | except subprocess.CalledProcessError: 38 | # Probably no GPUs / no driver running. 39 | return 0 40 | return max(0, len(output.split(b'\n')) - 2) 41 | 42 | 43 | def get_local_rank_size(comm): 44 | """ 45 | Returns the rank of each process on its machine 46 | The processes on a given machine will be assigned ranks 47 | 0, 1, 2, ..., N-1, 48 | where N is the number of processes on this machine. 
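For example (purely illustrative numbers): 8 MPI processes spread evenly over 2 machines would give each process a local rank in 0..3 and a local size of 4.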
49 | Useful if you want to assign one gpu per machine 50 | """ 51 | this_node = platform.node() 52 | ranks_nodes = comm.allgather((comm.Get_rank(), this_node)) 53 | node2rankssofar = collections.defaultdict(int) 54 | local_rank = None 55 | for (rank, node) in ranks_nodes: 56 | if rank == comm.Get_rank(): 57 | local_rank = node2rankssofar[node] 58 | node2rankssofar[node] += 1 59 | assert local_rank is not None 60 | return local_rank, node2rankssofar[this_node] 61 | 62 | 63 | @lru_cache() 64 | def gpu_devices(): 65 | if 'CUDA_VISIBLE_DEVICES' in os.environ: 66 | raise ValueError('CUDA_VISIBLE_DEVICES should not be set (it will cause nccl slowdowns). Use VISIBLE_DEVICES instead!') 67 | devices_str = os.environ.get('VISIBLE_DEVICES') 68 | if devices_str is not None: 69 | return list(map(int, filter(len, devices_str.split(',')))) 70 | else: 71 | return list(range(nvidia_gpu_count())) 72 | 73 | @lru_cache() 74 | def gpu_count(): 75 | return len(gpu_devices()) or None 76 | 77 | 78 | @lru_cache() 79 | def _our_gpu(): 80 | """Figure out which GPU we should be using in an MPI context.""" 81 | gpus = gpu_devices() 82 | if not gpus: 83 | return None 84 | rank = MPI.COMM_WORLD.Get_rank() 85 | local_rank, local_size = get_local_rank_size(MPI.COMM_WORLD) 86 | if gpu_count() not in (0, local_size): 87 | raise ValueError('Expected one GPU per rank, got gpus %s, local size %d' % (gpus, local_size)) 88 | gpu = gpus[local_rank] 89 | print('rank %d: gpus = %s, our gpu = %d' % (rank, gpus, gpu)) 90 | return gpu 91 | 92 | 93 | def mpi_session_config(): 94 | """Make a tf.ConfigProto to use only the GPU assigned to this MPI session.""" 95 | config = tf.ConfigProto() 96 | gpu = _our_gpu() 97 | if gpu is not None: 98 | config.gpu_options.visible_device_list = str(gpu) 99 | config.gpu_options.allow_growth = True 100 | return config 101 | 102 | 103 | def mpi_session(): 104 | """Create a session using only the GPU assigned to this MPI process.""" 105 | return tf.Session(config=mpi_session_config()) 106 | 107 | 108 | def set_mpi_seed(seed: Optional[int]): 109 | if seed is not None: 110 | rank = MPI.COMM_WORLD.Get_rank() 111 | seed = seed + rank * 100003 # Prime (kept for backwards compatibility even though it does nothing) 112 | np.random.seed(seed) 113 | tf.set_random_seed(seed) 114 | 115 | 116 | def exact_div(a, b): 117 | q = a // b 118 | if tf.contrib.framework.is_tensor(q): 119 | with tf.control_dependencies([tf.debugging.Assert(tf.equal(a, q * b), [a, b])]): 120 | return tf.identity(q) 121 | else: 122 | if a != q * b: 123 | raise ValueError('Inexact division: %s / %s = %s' % (a, b, a / b)) 124 | return q 125 | 126 | 127 | def ceil_div(a, b): 128 | return (a - 1) // b + 1 129 | 130 | 131 | def expand_tile(value, size, *, axis, name=None): 132 | """Add a new axis of given size.""" 133 | with tf.name_scope(name, 'expand_tile', [value, size, axis]) as scope: 134 | value = tf.convert_to_tensor(value, name='value') 135 | size = tf.convert_to_tensor(size, name='size') 136 | ndims = value.shape.rank 137 | if axis < 0: 138 | axis += ndims + 1 139 | return tf.tile(tf.expand_dims(value, axis=axis), [1]*axis + [size] + [1]*(ndims - axis), name=scope) 140 | 141 | 142 | def index_each(a, ix): 143 | """Do a batched indexing operation: index row i of a by ix[i] 144 | 145 | In the simple case (a is >=2D and ix is 1D), returns [row[i] for row, i in zip(a, ix)]. 146 | 147 | If ix has more dimensions, multiple lookups will be done at each batch index. 
148 | For instance, if ix is 2D, returns [[row[i] for i in ix_row] for row, ix_row in zip(a, ix)]. 149 | 150 | Always indexes into dimension 1 of a. 151 | """ 152 | a = tf.convert_to_tensor(a, name='a') 153 | ix = tf.convert_to_tensor(ix, name='ix', dtype=tf.int32) 154 | with tf.name_scope('index_each', values=[a, ix]) as scope: 155 | a.shape[:1].assert_is_compatible_with(ix.shape[:1]) 156 | i0 = tf.range(tf.shape(a)[0], dtype=ix.dtype) 157 | if ix.shape.rank > 1: 158 | i0 = tf.tile(tf.reshape(i0, (-1,) + (1,)*(ix.shape.rank - 1)), tf.concat([[1], tf.shape(ix)[1:]], axis=0)) 159 | return tf.gather_nd(a, tf.stack([i0, ix], axis=-1), name=scope) 160 | 161 | def cumulative_max(x): 162 | """Takes the (inclusive) cumulative maximum along the last axis of x. (Not efficient.)""" 163 | x = tf.convert_to_tensor(x) 164 | with tf.name_scope('cumulative_max', values=[x]) as scope: 165 | repeated = tf.tile( 166 | tf.expand_dims(x, axis=-1), 167 | tf.concat([tf.ones(x.shape.rank, dtype=tf.int32), tf.shape(x)[-1:]], axis=0)) 168 | trues = tf.ones_like(repeated, dtype=tf.bool) 169 | upper_triangle = tf.matrix_band_part(trues, 0, -1) 170 | neg_inf = tf.ones_like(repeated) * tf.dtypes.saturate_cast(-np.inf, dtype=x.dtype) 171 | prefixes = tf.where(upper_triangle, repeated, neg_inf) 172 | return tf.math.reduce_max(prefixes, axis=-2, name=scope) 173 | 174 | 175 | def flatten_dict(nested, sep='.'): 176 | def rec(nest, prefix, into): 177 | for k, v in nest.items(): 178 | if sep in k: 179 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") 180 | if isinstance(v, collections.Mapping): 181 | rec(v, prefix + k + sep, into) 182 | else: 183 | into[prefix + k] = v 184 | flat = {} 185 | rec(nested, '', flat) 186 | return flat 187 | 188 | @dataclass 189 | class Schema: 190 | dtype: Any 191 | shape: Tuple[Optional[int],...] 192 | 193 | 194 | def add_batch_dim(schemas, batch_size=None): 195 | def add_dim(schema): 196 | return Schema(dtype=schema.dtype, shape=(batch_size,)+schema.shape) 197 | return nest.map_structure(add_dim, schemas) 198 | 199 | 200 | class SampleBuffer: 201 | """A circular buffer for storing and sampling data. 202 | 203 | Data can be added to the buffer with `add`, and old data will be dropped. If you need to 204 | control where the buffer is stored, wrap the constructor call in a `with tf.device` block: 205 | 206 | with tf.device('cpu:0'): 207 | buffer = SampleBuffer(...) 
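A rough usage sketch (the schema name `x` is illustrative; `add` and `sample` are defined below, and the returned ops must be run in a session as usual):

        buffer = SampleBuffer(capacity=100, schemas=dict(x=Schema(tf.int32, ())))
        add_op = buffer.add(x=tf.range(17))   # oldest entries are overwritten once capacity is exceeded
        sampled = buffer.sample(4)['x']       # 4 entries drawn with replacement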
208 | """ 209 | 210 | def __init__(self, *, capacity: int, schemas: Dict[str,Schema], name=None) -> None: 211 | with tf.variable_scope(name, 'buffer', use_resource=True, initializer=tf.zeros_initializer): 212 | self._capacity = tf.constant(capacity, dtype=tf.int32, name='capacity') 213 | self._total = tf.get_variable( 214 | 'total', dtype=tf.int32, shape=(), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES], 215 | ) 216 | self._vars = { 217 | n: tf.get_variable( 218 | n, dtype=s.dtype, shape=(capacity,) + s.shape, trainable=False, 219 | collections=[tf.GraphKeys.LOCAL_VARIABLES], 220 | ) 221 | for n,s in schemas.items() 222 | } 223 | 224 | def add(self, **data): 225 | """Add new data to the end of the buffer, dropping old data if we exceed capacity.""" 226 | # Check input shapes 227 | if data.keys() != self._vars.keys(): 228 | raise ValueError('data.keys() = %s != %s' % (sorted(data.keys()), sorted(self._vars.keys()))) 229 | first = next(iter(data.values())) 230 | pre = first.shape[:1] 231 | for k, d in data.items(): 232 | try: 233 | d.shape.assert_is_compatible_with(pre.concatenate(self._vars[k].shape[1:])) 234 | except ValueError as e: 235 | raise ValueError('%s, key %s' % (e, k)) 236 | # Enqueue 237 | n = tf.shape(first)[0] 238 | capacity = self._capacity 239 | i0 = (self._total.assign_add(n) - n) % capacity 240 | i0n = i0 + n 241 | i1 = tf.minimum(i0n, capacity) 242 | i2 = i1 % capacity 243 | i3 = i0n % capacity 244 | slices = slice(i0, i1), slice(i2, i3) 245 | sizes = tf.stack([i1 - i0, i3 - i2]) 246 | assigns = [self._vars[k][s].assign(part) 247 | for k,d in data.items() 248 | for s, part in zip(slices, tf.split(d, sizes))] 249 | return tf.group(assigns) 250 | 251 | def total(self): 252 | """Total number of entries ever added, including those already discarded.""" 253 | return self._total.read_value() 254 | 255 | def size(self): 256 | """Current number of entries.""" 257 | return tf.minimum(self.total(), self._capacity) 258 | 259 | def read(self, indices): 260 | """indices: A 1-D Tensor of indices to read from. Each index must be less than 261 | capacity.""" 262 | return {k: v.sparse_read(indices) for k,v in self._vars.items()} 263 | 264 | def data(self): 265 | return {k: v[:self.size()] for k,v in self._vars.items()} 266 | 267 | def sample(self, n, seed=None): 268 | """Sample n entries with replacement.""" 269 | size = self.size() 270 | indices = tf.random_uniform([n], maxval=size, dtype=tf.int32, seed=seed) 271 | return self.read(indices) 272 | 273 | def write(self, indices, updates): 274 | """ 275 | indices: A 1-D Tensor of indices to write to. Each index must be less than `capacity`. 276 | update: A dictionary of new values, where each entry is a tensor with the same length as `indices`. 
277 | """ 278 | ops = [] 279 | for k, v in updates.items(): 280 | ops.append(self._vars[k].scatter_update(tf.IndexedSlices(v, tf.cast(indices, dtype=tf.int32)))) 281 | return tf.group(*ops) 282 | 283 | def write_add(self, indices, deltas): 284 | ops = [] 285 | for k, d in deltas.items(): 286 | ops.append(self._vars[k].scatter_add(tf.IndexedSlices(d, tf.cast(indices, dtype=tf.int32)))) 287 | return tf.group(*ops) 288 | 289 | 290 | def entropy_from_logits(logits): 291 | pd = tf.nn.softmax(logits, axis=-1) 292 | return tf.math.reduce_logsumexp(logits, axis=-1) - tf.reduce_sum(pd*logits, axis=-1) 293 | 294 | 295 | def logprobs_from_logits(*, logits, labels): 296 | return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits) 297 | 298 | 299 | def sample_from_logits(logits, dtype=tf.int32): 300 | with tf.name_scope('sample_from_logits', values=[logits]) as scope: 301 | shape = tf.shape(logits) 302 | flat_logits = tf.reshape(logits, [-1, shape[-1]]) 303 | flat_samples = tf.random.categorical(flat_logits, num_samples=1, dtype=dtype) 304 | return tf.reshape(flat_samples, shape[:-1], name=scope) 305 | 306 | 307 | def take_top_k_logits(logits, k): 308 | values, _ = tf.nn.top_k(logits, k=k) 309 | min_values = values[:, :, -1, tf.newaxis] 310 | return tf.where( 311 | logits < min_values, 312 | tf.ones_like(logits) * -1e10, 313 | logits, 314 | ) 315 | 316 | 317 | def take_top_p_logits(logits, p): 318 | """Nucleus sampling""" 319 | batch, sequence, _ = logits.shape.as_list() 320 | sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1) 321 | cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1) 322 | indices = tf.stack([ 323 | tf.range(0, batch)[:, tf.newaxis], 324 | tf.range(0, sequence)[tf.newaxis, :], 325 | # number of indices to include 326 | tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0), 327 | ], axis=-1) 328 | min_values = tf.gather_nd(sorted_logits, indices) 329 | return tf.where( 330 | logits < min_values, 331 | tf.ones_like(logits) * -1e10, 332 | logits, 333 | ) 334 | 335 | 336 | def whiten(values, shift_mean=True): 337 | mean, var = tf.nn.moments(values, axes=list(range(values.shape.rank))) 338 | whitened = (values - mean) * tf.rsqrt(var + 1e-8) 339 | if not shift_mean: 340 | whitened += mean 341 | return whitened 342 | 343 | 344 | 345 | def where(cond, true, false, name=None): 346 | """Similar to tf.where, but broadcasts scalar values.""" 347 | with tf.name_scope(name, 'where', [cond, true, false]) as name: 348 | cond = tf.convert_to_tensor(cond, name='cond', dtype=tf.bool) 349 | true = tf.convert_to_tensor(true, name='true', 350 | dtype=false.dtype if isinstance(false, tf.Tensor) else None) 351 | false = tf.convert_to_tensor(false, name='false', dtype=true.dtype) 352 | if true.shape.rank == false.shape.rank == 0: 353 | shape = tf.shape(cond) 354 | true = tf.fill(shape, true) 355 | false = tf.fill(shape, false) 356 | elif true.shape.rank == 0: 357 | true = tf.fill(tf.shape(false), true) 358 | elif false.shape.rank == 0: 359 | false = tf.fill(tf.shape(true), false) 360 | return tf.where(cond, true, false, name=name) 361 | 362 | 363 | def map_flat(f, values): 364 | """Apply the function f to flattened, concatenated values, then split and reshape back to original shapes.""" 365 | values = tuple(values) 366 | for v in values: 367 | assert not isinstance(v, tf.IndexedSlices) 368 | values = [tf.convert_to_tensor(v) for v in values] 369 | flat = tf.concat([tf.reshape(v, [-1]) for v in values], axis=0) 370 | 
flat = f(flat) 371 | parts = tf.split(flat, [tf.size(v) for v in values]) 372 | return [tf.reshape(p, tf.shape(v)) for p, v in zip(parts, values)] 373 | 374 | 375 | def map_flat_chunked(f, values, *, limit=1<<29): 376 | """ 377 | Apply the function f to chunked, flattened, concatenated values, then split and reshape back to original shapes. 378 | """ 379 | values = tuple(values) 380 | for v in values: 381 | assert not isinstance(v, tf.IndexedSlices) 382 | values = [tf.convert_to_tensor(v) for v in values] 383 | chunks = chunk_tensors(values, limit=limit) 384 | mapped_values = [v for chunk in chunks for v in map_flat(f, chunk)] 385 | return mapped_values 386 | 387 | 388 | def map_flat_bits(f, values): 389 | """Apply the function f to bit-concatenated values, then convert back to original shapes and dtypes.""" 390 | values = [tf.convert_to_tensor(v) for v in values] 391 | def maybe_bitcast(v, dtype): 392 | cast = tf.cast if tf.bool in (v.dtype, dtype) else tf.bitcast 393 | return cast(v, dtype) 394 | bits = [maybe_bitcast(v, tf.uint8) for v in values] 395 | flat = tf.concat([tf.reshape(b, [-1]) for b in bits], axis=0) 396 | flat = f(flat) 397 | parts = tf.split(flat, [tf.size(b) for b in bits]) 398 | return [maybe_bitcast(tf.reshape(p, tf.shape(b)), v.dtype) 399 | for p, v, b in zip(parts, values, bits)] 400 | 401 | def mpi_bcast_tensor_dict(d, comm): 402 | sorted_keys = sorted(d.keys()) 403 | values = map_flat_bits(partial(mpi_bcast, comm), [d[k] for k in sorted_keys]) 404 | return {k: v for k, v in zip(sorted_keys, values)} 405 | 406 | def mpi_bcast(comm, value, root=0): 407 | """Broadcast value from root to other processes via a TensorFlow py_func.""" 408 | value = tf.convert_to_tensor(value) 409 | if comm.Get_size() == 1: 410 | return value 411 | comm = comm.Dup() # Allow parallelism at graph execution time 412 | if comm.Get_rank() == root: 413 | out = tf.py_func(partial(comm.bcast, root=root), [value], value.dtype) 414 | else: 415 | out = tf.py_func(partial(comm.bcast, None, root=root), [], value.dtype) 416 | out.set_shape(value.shape) 417 | return out 418 | 419 | 420 | def chunk_tensors(tensors, *, limit=1 << 28): 421 | """Chunk the list of tensors into groups of size at most `limit` bytes. 422 | 423 | The tensors must have a static shape. 
424 | """ 425 | total = 0 426 | batches = [] 427 | for v in tensors: 428 | size = v.dtype.size * v.shape.num_elements() 429 | if not batches or total + size > limit: 430 | total = 0 431 | batches.append([]) 432 | total += size 433 | batches[-1].append(v) 434 | return batches 435 | 436 | 437 | def variable_synchronizer(comm, vars, *, limit=1<<28): 438 | """Synchronize `vars` from the root to other processs""" 439 | if comm.Get_size() == 1: 440 | return tf.no_op() 441 | 442 | # Split vars into chunks so that no chunk is over limit bytes 443 | batches = chunk_tensors(sorted(vars, key=lambda v: v.name), limit=limit) 444 | 445 | # Synchronize each batch, using a separate communicator to ensure safety 446 | prev = tf.no_op() 447 | for batch in batches: 448 | with tf.control_dependencies([prev]): 449 | assigns = [] 450 | values = map_flat_bits(partial(mpi_bcast, comm), batch) 451 | for var, value in zip(batch, values): 452 | assigns.append(var.assign(value)) 453 | prev = tf.group(*assigns) 454 | return prev 455 | 456 | 457 | def mpi_read_file(comm, path): 458 | """Read a file on rank 0 and broadcast the contents to all machines.""" 459 | if comm.Get_rank() == 0: 460 | with tf.gfile.Open(path, 'rb') as fh: 461 | data = fh.read() 462 | comm.bcast(data) 463 | else: 464 | data = comm.bcast(None) 465 | return data 466 | 467 | 468 | def mpi_allreduce_sum(values, *, comm): 469 | if comm.Get_size() == 1: 470 | return values 471 | orig_dtype = values.dtype 472 | if hvd is None: 473 | orig_shape = values.shape 474 | def _allreduce(vals): 475 | buf = np.zeros(vals.shape, np.float32) 476 | comm.Allreduce(vals, buf, op=MPI.SUM) 477 | return buf 478 | values = tf.py_func(_allreduce, [values], tf.float32) 479 | values.set_shape(orig_shape) 480 | else: 481 | values = hvd.mpi_ops._allreduce(values) 482 | return tf.cast(values, dtype=orig_dtype) 483 | 484 | 485 | def mpi_allreduce_mean(values, *, comm): 486 | scale = 1 / comm.Get_size() 487 | values = mpi_allreduce_sum(values, comm=comm) 488 | return values if scale == 1 else scale * values 489 | 490 | 491 | class FlatStats: 492 | """A bunch of statistics stored as a single flat tensor.""" 493 | 494 | def __init__(self, keys, flat): 495 | keys = tuple(keys) 496 | flat = tf.convert_to_tensor(flat, dtype=tf.float32, name='flat') 497 | assert [len(keys)] == flat.shape.as_list() 498 | self.keys = keys 499 | self.flat = flat 500 | 501 | @staticmethod 502 | def from_dict(stats): 503 | for k, v in stats.items(): 504 | if v.dtype != tf.float32: 505 | raise ValueError('Statistic %s has dtype %r, expected %r' % (k, v.dtype, tf.float32)) 506 | keys = tuple(sorted(stats.keys())) 507 | flat = tf.stack([stats[k] for k in keys]) 508 | return FlatStats(keys, flat) 509 | 510 | def concat(self, more): 511 | dups = set(self.keys) & set(more.keys) 512 | if dups: 513 | raise ValueError('Duplicate statistics: %s' % ', '.join(dups)) 514 | return FlatStats(self.keys + more.keys, tf.concat([self.flat, more.flat], axis=0)) 515 | 516 | def as_dict(self): 517 | flat = tf.unstack(self.flat, num=len(self.keys)) 518 | return dict(safe_zip(self.keys, flat)) 519 | 520 | def with_values(self, flat): 521 | return FlatStats(self.keys, flat) 522 | 523 | def map_flat(self, f): 524 | return FlatStats(self.keys, f(self.flat)) 525 | 526 | 527 | def find_trainable_variables(key): 528 | return [v for v in tf.trainable_variables() if v.op.name.startswith(key + '/')] 529 | 530 | 531 | def variables_on_gpu(): 532 | """Prevent variables from accidentally being placed on the CPU. 
533 | 534 | This dodges an obscure bug in tf.train.init_from_checkpoint. 535 | """ 536 | if _our_gpu() is None: 537 | return contextlib.suppress() 538 | def device(op): 539 | return '/gpu:0' if op.type == 'VarHandleOp' else '' 540 | return tf.device(device) 541 | 542 | 543 | 544 | def graph_function(**schemas: Schema): 545 | def decorate(make_op): 546 | def make_ph(path, schema): 547 | return tf.placeholder(name=f'arg_{make_op.__name__}_{path}', shape=schema.shape, dtype=schema.dtype) 548 | phs = nest.map_structure_with_paths(make_ph, schemas) 549 | op = make_op(**phs) 550 | sig = inspect.signature(make_op) 551 | @wraps(make_op) 552 | def run(*args, **kwargs): 553 | bound: inspect.BoundArguments = sig.bind(*args, **kwargs) 554 | bound.apply_defaults() 555 | 556 | arg_dict = bound.arguments 557 | for name, param in sig.parameters.items(): 558 | if param.kind == inspect.Parameter.VAR_KEYWORD: 559 | kwargs = arg_dict[name] 560 | arg_dict.update(kwargs) 561 | del arg_dict[name] 562 | flat_phs = nest.flatten(phs) 563 | flat_arguments = nest.flatten_up_to(phs, bound.arguments) 564 | feed = {ph: arg for ph, arg in zip(flat_phs, flat_arguments)} 565 | run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True) 566 | 567 | return tf.get_default_session().run(op, feed_dict=feed, options=run_options, run_metadata=None) 568 | return run 569 | return decorate 570 | 571 | 572 | 573 | def pearson_r(x: tf.Tensor, y: tf.Tensor): 574 | assert x.shape.rank == 1 575 | assert y.shape.rank == 1 576 | x_mean, x_var = tf.nn.moments(x, axes=[0]) 577 | y_mean, y_var = tf.nn.moments(y, axes=[0]) 578 | cov = tf.reduce_mean((x - x_mean)*(y - y_mean), axis=0) 579 | return cov / tf.sqrt(x_var * y_var) 580 | 581 | def shape_list(x): 582 | """Deal with dynamic shape in tensorflow cleanly.""" 583 | static = x.shape.as_list() 584 | dynamic = tf.shape(x) 585 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 586 | 587 | def safe_zip(*args): 588 | """Zip, but require all sequences to be the same length.""" 589 | args = tuple(map(tuple, args)) 590 | for a in args[1:]: 591 | if len(args[0]) != len(a): 592 | raise ValueError(f'Lengths do not match: {[len(a) for a in args]}') 593 | return zip(*args) 594 | 595 | 596 | def get_summary_writer(save_dir, subdir='', comm=MPI.COMM_WORLD): 597 | if comm.Get_rank() != 0: 598 | return None 599 | if save_dir is None: 600 | return None 601 | with tf.init_scope(): 602 | return summary.create_file_writer(os.path.join(save_dir, 'tb', subdir)) 603 | 604 | 605 | def record_stats(*, stats, summary_writer, step, log_interval, name=None, comm=MPI.COMM_WORLD): 606 | def log_stats(step, *stat_values): 607 | if comm.Get_rank() != 0 or step % log_interval != 0: 608 | return 609 | 610 | for k, v in safe_zip(stats.keys(), stat_values): 611 | print('k = ', k, ', v = ', v) 612 | 613 | summary_ops = [tf.py_func(log_stats, [step] + list(stats.values()), [])] 614 | if summary_writer: 615 | with summary_writer.as_default(), summary.always_record_summaries(): 616 | for key, value in stats.items(): 617 | summary_ops.append(summary.scalar(key, value, step=step)) 618 | return tf.group(*summary_ops, name=name) 619 | 620 | 621 | def minimize(*, loss, params, lr, name=None, comm=MPI.COMM_WORLD): 622 | with tf.name_scope(name, 'minimize'): 623 | with tf.name_scope('grads'): 624 | grads = tf.gradients(loss, params) 625 | grads, params = zip(*[(g, v) for g, v in zip(grads, params) if g is not None]) 626 | grads = map_flat_chunked(partial(mpi_allreduce_mean, comm=comm), grads) 627 | optimizer 
= tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-5, name='adam') 628 | opt_op = optimizer.apply_gradients(zip(grads, params), name=name) 629 | return opt_op 630 | -------------------------------------------------------------------------------- /lm_human_preferences/utils/gcs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import subprocess 4 | import time 5 | import traceback 6 | import warnings 7 | from functools import wraps 8 | from urllib.parse import urlparse, unquote 9 | 10 | import requests 11 | from google.api_core.exceptions import InternalServerError, ServiceUnavailable 12 | from google.cloud import storage 13 | 14 | warnings.filterwarnings("ignore", "Your application has authenticated using end user credentials") 15 | 16 | 17 | def exponential_backoff( 18 | retry_on=lambda e: True, *, init_delay_s=1, max_delay_s=600, max_tries=30, factor=2.0, 19 | jitter=0.2, log_errors=True): 20 | """ 21 | Returns a decorator which retries the wrapped function as long as retry_on returns True for the exception. 22 | :param init_delay_s: How long to wait to do the first retry (in seconds). 23 | :param max_delay_s: At what duration to cap the retry interval at (in seconds). 24 | :param max_tries: How many total attempts to perform. 25 | :param factor: How much to multiply the delay interval by after each attempt (until it reaches max_delay_s). 26 | :param jitter: How much to jitter by (between 0 and 1) -- each delay will be multiplied by a random value between (1-jitter) and (1+jitter). 27 | :param log_errors: Whether to print tracebacks on every retry. 28 | :param retry_on: A predicate which takes an exception and indicates whether to retry after that exception. 29 | """ 30 | def decorate(f): 31 | @wraps(f) 32 | def f_retry(*args, **kwargs): 33 | delay_s = float(init_delay_s) 34 | for i in range(max_tries): 35 | try: 36 | return f(*args, **kwargs) 37 | except Exception as e: 38 | if not retry_on(e) or i == max_tries-1: 39 | raise 40 | if log_errors: 41 | print(f"Retrying after try {i+1}/{max_tries} failed:") 42 | traceback.print_exc() 43 | jittered_delay = random.uniform(delay_s*(1-jitter), delay_s*(1+jitter)) 44 | time.sleep(jittered_delay) 45 | delay_s = min(delay_s * factor, max_delay_s) 46 | return f_retry 47 | return decorate 48 | 49 | 50 | def _gcs_should_retry_on(e): 51 | # Retry on all 503 errors and 500, as recommended by https://cloud.google.com/apis/design/errors#error_retries 52 | return isinstance(e, (InternalServerError, ServiceUnavailable, requests.exceptions.ConnectionError)) 53 | 54 | 55 | def parse_url(url): 56 | """Given a gs:// path, returns bucket name and blob path.""" 57 | result = urlparse(url) 58 | if result.scheme == 'gs': 59 | return result.netloc, unquote(result.path.lstrip('/')) 60 | elif result.scheme == 'https': 61 | assert result.netloc == 'storage.googleapis.com' 62 | bucket, rest = result.path.lstrip('/').split('/', 1) 63 | return bucket, unquote(rest) 64 | else: 65 | raise Exception(f'Could not parse {url} as gcs url') 66 | 67 | 68 | @exponential_backoff(_gcs_should_retry_on) 69 | def get_blob(url, client=None): 70 | if client is None: 71 | client = storage.Client() 72 | bucket_name, path = parse_url(url) 73 | bucket = client.get_bucket(bucket_name) 74 | return bucket.get_blob(path) 75 | 76 | 77 | @exponential_backoff(_gcs_should_retry_on) 78 | def download_contents(url, client=None): 79 | """Given a gs:// path, returns contents of the corresponding blob.""" 80 | blob = get_blob(url, 
client) 81 | if not blob: return None 82 | return blob.download_as_string() 83 | 84 | 85 | @exponential_backoff(_gcs_should_retry_on) 86 | def upload_contents(url, contents, client=None): 87 | """Given a gs:// path, uploads contents to the corresponding blob.""" 88 | if client is None: 89 | client = storage.Client() 90 | bucket_name, path = parse_url(url) 91 | bucket = client.get_bucket(bucket_name) 92 | blob = storage.Blob(path, bucket) 93 | blob.upload_from_string(contents) 94 | 95 | 96 | def download_directory_cached(url, comm=None): 97 | """ Given a GCS path url, caches the contents locally. 98 | WARNING: only use this function if contents under the path won't change! 99 | """ 100 | cache_dir = '/tmp/gcs-cache' 101 | bucket_name, path = parse_url(url) 102 | is_master = not comm or comm.Get_rank() == 0 103 | local_path = os.path.join(cache_dir, bucket_name, path) 104 | 105 | sentinel = os.path.join(local_path, 'SYNCED') 106 | if is_master: 107 | if not os.path.exists(local_path): 108 | os.makedirs(os.path.dirname(local_path), exist_ok=True) 109 | cmd = 'gsutil', '-m', 'cp', '-r', url, os.path.dirname(local_path) + '/' 110 | print(' '.join(cmd)) 111 | subprocess.check_call(cmd) 112 | open(sentinel, 'a').close() 113 | else: 114 | while not os.path.exists(sentinel): 115 | time.sleep(1) 116 | return local_path 117 | 118 | 119 | def download_file_cached(url, comm=None): 120 | """ Given a GCS path url, caches the contents locally. 121 | WARNING: only use this function if contents under the path won't change! 122 | """ 123 | cache_dir = '/tmp/gcs-cache' 124 | bucket_name, path = parse_url(url) 125 | is_master = not comm or comm.Get_rank() == 0 126 | local_path = os.path.join(cache_dir, bucket_name, path) 127 | 128 | sentinel = local_path + '.SYNCED' 129 | if is_master: 130 | if not os.path.exists(local_path): 131 | os.makedirs(os.path.dirname(local_path), exist_ok=True) 132 | cmd = 'gsutil', '-m', 'cp', url, local_path 133 | print(' '.join(cmd)) 134 | subprocess.check_call(cmd) 135 | open(sentinel, 'a').close() 136 | else: 137 | while not os.path.exists(sentinel): 138 | time.sleep(1) 139 | return local_path 140 | -------------------------------------------------------------------------------- /lm_human_preferences/utils/hyperparams.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import typing 4 | from dataclasses import fields, is_dataclass 5 | from functools import lru_cache 6 | 7 | from typeguard import check_type 8 | 9 | from lm_human_preferences.utils import gcs 10 | 11 | 12 | class HParams: 13 | """Used as a base class for hyperparameter structs. They also need to be annotated with @dataclass.""" 14 | 15 | def override_from_json_file(self, filename): 16 | if filename.startswith('gs://'): 17 | hparams_str = gcs.download_contents(filename) 18 | else: 19 | hparams_str = open(filename).read() 20 | self.parse_json(hparams_str) 21 | 22 | def override_from_str(self, hparam_str): 23 | """Overrides values from a string like 'x.y=1,name=foobar'. 24 | 25 | Like tensorflow.contrib.training.HParams, this method does not allow specifying string values containing commas. 
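For instance, 'x.y=1,name=foobar' would set self.x.y to 1 (assuming an int-typed field) and self.name to 'foobar', with each value parsed according to the corresponding field's type annotation (see override_from_str_dict below).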
26 | """ 27 | kvp_strs = hparam_str.split(',') 28 | flat_dict = {} 29 | for kvp_str in kvp_strs: 30 | k, sep, v = kvp_str.partition('=') 31 | if not sep: 32 | raise ValueError(f"Malformed hyperparameter value: '{kvp_str}'") 33 | flat_dict[k] = v 34 | 35 | self.override_from_str_dict(flat_dict) 36 | 37 | def override_from_str_dict(self, flat_dict, separator='.'): 38 | """Overrides values from a dict like {'x.y': "1", 'name': "foobar"}. 39 | 40 | Treats keys with dots as paths into nested HParams. 41 | Parses values according to the types in the HParams classes. 42 | """ 43 | typemap = _type_map(type(self), separator=separator) 44 | 45 | parsed = {} 46 | for flat_k, s in flat_dict.items(): 47 | if flat_k not in typemap: 48 | raise AttributeError(f"no field {flat_k} in {typemap}") 49 | parsed[flat_k] = _parse_typed_value(typemap[flat_k], s) 50 | 51 | self.override_from_dict(parsed, separator=separator) 52 | 53 | def parse_json(self, s: str): 54 | self.override_from_nested_dict(json.loads(s)) 55 | 56 | def override_from_dict(self, flat_dict, separator='.'): 57 | """Overrides values from a dict like {'x.y': 1, 'name': "foobar"}. 58 | 59 | Treats keys with dots as paths into nested HParams. 60 | Values should be parsed already. 61 | """ 62 | # Parse 'on' and 'off' values. 63 | typemap = _type_map(type(self), separator=separator) 64 | 65 | flat_dict_parsed = {} 66 | for flat_k, v in flat_dict.items(): 67 | cls = _type_to_class(typemap[flat_k]) 68 | if is_hparam_type(cls) and v == 'on': 69 | parsed_v = cls() 70 | elif is_hparam_type(cls) and v == 'off': 71 | parsed_v = None 72 | else: 73 | parsed_v = v 74 | flat_dict_parsed[flat_k] = parsed_v 75 | 76 | # Expand implicit nested 'on' values. For instance, {'x.y': 'on'} should mean {'x': 'on', 'x.y': 'on'}. 77 | flat_dict_expanded = {} 78 | for flat_k, v in flat_dict_parsed.items(): 79 | flat_dict_expanded[flat_k] = v 80 | cls = _type_to_class(typemap[flat_k]) 81 | if is_hparam_type(cls) and v is not None: 82 | parts = flat_k.split(separator) 83 | prefix = parts[0] 84 | for i in range(1, len(parts)): 85 | if prefix not in flat_dict_expanded: 86 | flat_dict_expanded[prefix] = _type_to_class(typemap[prefix])() 87 | prefix += separator + parts[i] 88 | 89 | # Set all the values. The sort ensures that outer classes get initialized before their fields. 
90 | for flat_k in sorted(flat_dict_expanded.keys()): 91 | v = flat_dict_expanded[flat_k] 92 | *ks, f = flat_k.split(separator) 93 | hp = self 94 | for i, k in enumerate(ks): 95 | try: 96 | hp = getattr(hp, k) 97 | except AttributeError: 98 | raise AttributeError(f"{hp} {'(' + separator.join(ks[:i]) + ') ' if i else ''}has no field '{k}'") 99 | try: 100 | setattr(hp, f, v) 101 | except AttributeError: 102 | raise AttributeError(f"{hp} ({separator.join(ks)}) has no field '{f}'") 103 | 104 | def override_from_nested_dict(self, nested_dict): 105 | for k, v in nested_dict.items(): 106 | if isinstance(v, dict): 107 | if getattr(self, k) is None: 108 | cls = _type_to_class(_get_field(self, k).type) 109 | setattr(self, k, cls()) 110 | getattr(self, k).override_from_nested_dict(v) 111 | else: 112 | setattr(self, k, v) 113 | 114 | def to_nested_dict(self): 115 | d = {} 116 | for f in fields(self): 117 | fieldval = getattr(self, f.name) 118 | if isinstance(fieldval, HParams): 119 | fieldval = fieldval.to_nested_dict() 120 | d[f.name] = fieldval 121 | return d 122 | 123 | def validate(self, *, prefix=''): 124 | assert is_dataclass(self), f"You forgot to annotate {type(self)} with @dataclass" 125 | for f in fields(self): 126 | fieldval = getattr(self, f.name) 127 | check_type(prefix + f.name, fieldval, f.type) 128 | if isinstance(fieldval, HParams): 129 | fieldval.validate(prefix=prefix + f.name + '.') 130 | 131 | 132 | def is_hparam_type(ty): 133 | if isinstance(ty, type) and issubclass(ty, HParams): 134 | assert is_dataclass(ty) 135 | return True 136 | else: 137 | return False 138 | 139 | 140 | def _is_union_type(ty): 141 | return getattr(ty, '__origin__', None) is typing.Union 142 | 143 | 144 | def dump(hparams, *, name='hparams', out=sys.stdout): 145 | out.write('%s:\n' % name) 146 | def dump_nested(hp, indent): 147 | for f in sorted(fields(hp), key=lambda f: f.name): 148 | v = getattr(hp, f.name) 149 | if isinstance(v, HParams): 150 | out.write('%s%s:\n' % (indent, f.name)) 151 | dump_nested(v, indent=indent+' ') 152 | else: 153 | out.write('%s%s: %s\n' % (indent, f.name, v)) 154 | dump_nested(hparams, indent=' ') 155 | 156 | 157 | def _can_distinguish_unambiguously(type_set): 158 | """Whether it's always possible to tell which type in type_set a certain value is supposed to be""" 159 | if len(type_set) == 1: 160 | return True 161 | if type(None) in type_set: 162 | return True 163 | if str in type_set: 164 | return False 165 | if int in type_set and float in type_set: 166 | return False 167 | if any(_is_union_type(ty) for ty in type_set): 168 | # Nested unions *might* be unambiguous, but don't support for now 169 | return False 170 | return True 171 | 172 | 173 | def _parse_typed_value(ty, s): 174 | if ty is str: 175 | return s 176 | elif ty in (int, float): 177 | return ty(s) 178 | elif ty is bool: 179 | if s in ('t', 'true', 'True'): 180 | return True 181 | elif s in ('f', 'false', 'False'): 182 | return False 183 | else: 184 | raise ValueError(f"Invalid bool '{s}'") 185 | elif ty is type(None): 186 | if s in ('None', 'none', ''): 187 | return None 188 | else: 189 | raise ValueError(f"Invalid None value '{s}'") 190 | elif is_hparam_type(ty): 191 | if s in ('on', 'off'): 192 | # The class will be constructed later 193 | return s 194 | else: 195 | raise ValueError(f"Invalid hparam class value '{s}'") 196 | elif _is_union_type(ty): 197 | if not _can_distinguish_unambiguously(ty.__args__): 198 | raise TypeError(f"Can't always unambiguously parse a value of union '{ty}'") 199 | for ty_option in 
ty.__args__: 200 | try: 201 | return _parse_typed_value(ty_option, s) 202 | except ValueError: 203 | continue 204 | raise ValueError(f"Couldn't parse '{s}' as any of the types in '{ty}'") 205 | else: 206 | raise ValueError(f"Unsupported hparam type '{ty}'") 207 | 208 | 209 | def _get_field(data, fieldname): 210 | matching_fields = [f for f in fields(data) if f.name == fieldname] 211 | if len(matching_fields) != 1: 212 | raise AttributeError(f"couldn't find field '{fieldname}' in {data}") 213 | return matching_fields[0] 214 | 215 | 216 | def _update_disjoint(dst: dict, src: dict): 217 | for k, v in src.items(): 218 | assert k not in dst 219 | dst[k] = v 220 | 221 | 222 | @lru_cache() 223 | def _type_map(ty, separator): 224 | typemap = {} 225 | for f in fields(ty): 226 | typemap[f.name] = f.type 227 | if is_hparam_type(f.type): 228 | nested = _type_map(f.type, separator=separator) 229 | elif _is_union_type(f.type): 230 | nested = {} 231 | for ty_option in f.type.__args__: 232 | if is_hparam_type(ty_option): 233 | _update_disjoint(nested, _type_map(ty_option, separator=separator)) 234 | else: 235 | nested = {} 236 | _update_disjoint(typemap, {f'{f.name}{separator}{k}': t for k, t in nested.items()}) 237 | return typemap 238 | 239 | 240 | def _type_to_class(ty): 241 | """Extract a constructible class from a type. For instance, `typing.Optional[int]` gives `int`""" 242 | if _is_union_type(ty): 243 | # Only typing.Optional supported: must be of form typing.Union[ty, None] 244 | assert len(ty.__args__) == 2 245 | assert ty.__args__[1] is type(None) 246 | return ty.__args__[0] 247 | else: 248 | return ty 249 | 250 | -------------------------------------------------------------------------------- /lm_human_preferences/utils/launch.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import os 3 | import subprocess 4 | from functools import partial 5 | 6 | import cloudpickle 7 | import fire 8 | 9 | def launch(name, f, *, namespace='safety', mode='local', mpi=1) -> None: 10 | if mode == 'local': 11 | with open('/tmp/pickle_fn', 'wb') as file: 12 | cloudpickle.dump(f, file) 13 | 14 | subprocess.check_call(['mpiexec', '-n', str(mpi), 'python', '-c', 'import sys; import pickle; pickle.loads(open("/tmp/pickle_fn", "rb").read())()']) 15 | return 16 | raise Exception('Other modes unimplemented!') 17 | 18 | def parallel(jobs, mode): 19 | if mode == 'local': 20 | assert len(jobs) == 1, "Cannot run jobs in parallel locally" 21 | for job in jobs: 22 | job() 23 | else: 24 | with concurrent.futures.ThreadPoolExecutor() as executor: 25 | futures = [executor.submit(job) for job in jobs] 26 | for f in futures: 27 | f.result() 28 | 29 | def launch_trials(name, fn, trials, hparam_class, extra_hparams=None, dry_run=False, mpi=1, mode='local', save_dir=None): 30 | jobs = [] 31 | for trial in trials: 32 | descriptors = [] 33 | kwargs = {} 34 | for k, v, s in trial: 35 | if k is not None: 36 | if k in kwargs: 37 | print(f'WARNING: overriding key {k} from {kwargs[k]} to {v}') 38 | kwargs[k] = v 39 | if s.get('descriptor'): 40 | descriptors.append(str(s['descriptor'])) 41 | hparams = hparam_class() 42 | hparams.override_from_dict(kwargs) 43 | if extra_hparams: 44 | hparams.override_from_str_dict(extra_hparams) 45 | job_name = (name + '/' + '-'.join(descriptors)).rstrip('/') 46 | hparams.validate() 47 | if dry_run: 48 | print(f"{job_name}: {kwargs}") 49 | else: 50 | if save_dir: 51 | hparams.run.save_dir = os.path.join(save_dir, job_name) 52 | trial_fn = 
partial(fn, hparams) 53 | jobs.append(partial(launch, job_name, trial_fn, mpi=mpi, mode=mode)) 54 | 55 | parallel(jobs, mode=mode) 56 | 57 | def main(commands_dict): 58 | """Similar to fire.Fire, but with support for multiple commands without having a class.""" 59 | class _Commands: 60 | def __init__(self): 61 | for name, cmd in commands_dict.items(): 62 | setattr(self, name, cmd) 63 | fire.Fire(_Commands) 64 | -------------------------------------------------------------------------------- /lm_human_preferences/utils/test_core_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """utils tests""" 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | from lm_human_preferences.utils import core as utils 8 | 9 | 10 | def test_exact_div(): 11 | assert utils.exact_div(12, 4) == 3 12 | assert utils.exact_div(12, 3) == 4 13 | try: 14 | utils.exact_div(7, 3) 15 | assert False 16 | except ValueError: 17 | pass 18 | 19 | 20 | def test_ceil_div(): 21 | for b in range(1, 10 + 1): 22 | for a in range(-10, 10 + 1): 23 | assert utils.ceil_div(a, b) == int(np.ceil(a / b)) 24 | 25 | 26 | def test_expand_tile(): 27 | np.random.seed(7) 28 | size = 11 29 | with tf.Session(): 30 | for shape in (), (7,), (3, 5): 31 | data = np.asarray(np.random.randn(*shape), dtype=np.float32) 32 | x = tf.constant(data) 33 | for axis in range(-len(shape) - 1, len(shape) + 1): 34 | y = utils.expand_tile(x, size, axis=axis).eval() 35 | assert np.all(np.expand_dims(data, axis=axis) == y) 36 | 37 | 38 | def test_sample_buffer(): 39 | capacity = 100 40 | batch = 17 41 | lots = 100 42 | with tf.Graph().as_default(), tf.Session() as sess: 43 | buffer = utils.SampleBuffer(capacity=capacity, schemas=dict(x=utils.Schema(tf.int32, ()))) 44 | tf.variables_initializer(tf.global_variables() + tf.local_variables()).run() 45 | i_p = tf.placeholder(dtype=tf.int32, shape=()) 46 | add = buffer.add(x=batch * i_p + tf.range(batch)) 47 | sample = buffer.sample(lots, seed=7)['x'] 48 | all_data_1 = buffer.data() 49 | all_data_2 = buffer.read(tf.range(buffer.size())) 50 | for i in range(20): 51 | add.run(feed_dict={i_p: i}) 52 | samples = sample.eval() 53 | hi = batch * (i + 1) 54 | lo = max(0, hi - capacity) 55 | assert lo <= samples.min() <= lo + 3 56 | assert hi - 5 <= samples.max() < hi 57 | np.testing.assert_equal(sess.run(all_data_1), sess.run(all_data_2)) 58 | 59 | 60 | def test_where(): 61 | with tf.Session(): 62 | assert np.all(utils.where([False, True], 7, 8).eval() == [8, 7]) 63 | assert np.all(utils.where([False, True, True], [1, 2, 3], 8).eval() == [8, 2, 3]) 64 | assert np.all(utils.where([False, False, True], 8, [1, 2, 3]).eval() == [1, 2, 8]) 65 | assert np.all(utils.where([False, True], [[1, 2], [3, 4]], -1).eval() == [[-1, -1], [3, 4]]) 66 | assert np.all(utils.where([False, True], -1, [[1, 2], [3, 4]]).eval() == [[1, 2], [-1, -1]]) 67 | 68 | 69 | def test_map_flat(): 70 | with tf.Session() as sess: 71 | inputs = [2], [3, 5], [[7, 11], [13, 17]] 72 | inputs = map(np.asarray, inputs) 73 | outputs = sess.run(utils.map_flat(tf.square, inputs)) 74 | for i, o in zip(inputs, outputs): 75 | assert np.all(i * i == o) 76 | 77 | 78 | def test_map_flat_bits(): 79 | with tf.Session() as sess: 80 | inputs = [2], [3, 5], [[7, 11], [13, 17]], [True, False, True] 81 | dtypes = np.uint8, np.uint16, np.int32, np.int64, np.bool 82 | inputs = [np.asarray(i, dtype=d) for i, d in zip(inputs, dtypes)] 83 | outputs = sess.run(utils.map_flat_bits(lambda x: x + 1, inputs)) 84 | 85 | 
def tweak(n): 86 | return n + sum(2 ** (8 * i) for i in range(n.dtype.itemsize)) 87 | 88 | for i, o in zip(inputs, outputs): 89 | assert np.all(tweak(i) == o) 90 | 91 | 92 | def test_cumulative_max(): 93 | np.random.seed(7) 94 | with tf.Session().as_default(): 95 | for x in [ 96 | np.random.randn(10), 97 | np.random.randn(11, 7), 98 | np.random.randint(-10, 10, size=10), 99 | np.random.randint(-10, 10, size=(12, 8)), 100 | np.random.randint(-10, 10, size=(3, 3, 4)), 101 | ]: 102 | assert np.all(utils.cumulative_max(x).eval() == np.maximum.accumulate(x, axis=-1)) 103 | 104 | 105 | def test_index_each(): 106 | np.random.seed(7) 107 | x = np.random.randn(7, 11) 108 | i = np.random.randint(x.shape[1], size=x.shape[0]) 109 | y = utils.index_each(x, i) 110 | 111 | x2 = np.random.randn(3, 2, 4) 112 | i2 = np.random.randint(x2.shape[1], size=x2.shape[0]) 113 | y2 = utils.index_each(x2, i2) 114 | 115 | x3 = np.random.randn(5, 9) 116 | i3 = np.random.randint(x3.shape[1], size=(x3.shape[0], 2)) 117 | y3 = utils.index_each(x3, i3) 118 | with tf.Session(): 119 | assert np.all(y.eval() == x[np.arange(7), i]) 120 | assert np.all(y2.eval() == x2[np.arange(3), i2]) 121 | y3val = y3.eval() 122 | assert np.all(y3val[:,0] == x3[np.arange(5), i3[:,0]]) 123 | assert np.all(y3val[:,1] == x3[np.arange(5), i3[:,1]]) 124 | 125 | 126 | def test_index_each_many(): 127 | np.random.seed(7) 128 | x = np.random.randn(7, 11) 129 | i = np.random.randint(x.shape[1], size=[x.shape[0],3]) 130 | y = utils.index_each(x, i) 131 | with tf.Session(): 132 | assert np.all(y.eval() == x[np.arange(7)[:,None], i]) 133 | 134 | 135 | @utils.graph_function(x=utils.Schema(tf.int32, ()), y=utils.Schema(tf.int32, ())) 136 | def tf_sub(x, y=1): 137 | return tf.math.subtract(x, y) 138 | 139 | @utils.graph_function(x=utils.Schema(tf.int32, ()), y=dict(z1=utils.Schema(tf.int32, ()), z2=utils.Schema(tf.int32, ()))) 140 | def tf_sub_2(x, y): 141 | return tf.math.subtract(x, y['z1']) - y['z2'] 142 | 143 | def test_graph_function(): 144 | with tf.Session().as_default(): 145 | assert tf_sub(3) == 2 146 | assert tf_sub(x=3) == 2 147 | assert tf_sub(5, 2) == 3 148 | assert tf_sub(y=2, x=5) == 3 149 | assert tf_sub_2(5, dict(z1=1, z2=2)) == 2 150 | 151 | def test_top_k(): 152 | with tf.Session().as_default(): 153 | logits = tf.constant([[[1,1.01,1.001,0,0,0,2]]], dtype=tf.float32) 154 | np.testing.assert_allclose( 155 | utils.take_top_k_logits(logits, 1).eval(), 156 | [[[-1e10,-1e10,-1e10,-1e10,-1e10,-1e10,2]]] 157 | ) 158 | np.testing.assert_allclose( 159 | utils.take_top_k_logits(logits, 2).eval(), 160 | [[[-1e10,1.01,-1e10,-1e10,-1e10,-1e10,2]]] 161 | ) 162 | np.testing.assert_allclose( 163 | utils.take_top_k_logits(logits, 3).eval(), 164 | [[[-1e10,1.01,1.001,-1e10,-1e10,-1e10,2]]] 165 | ) 166 | np.testing.assert_allclose( 167 | utils.take_top_k_logits(logits, 4).eval(), 168 | [[[1,1.01,1.001,-1e10,-1e10,-1e10,2]]] 169 | ) 170 | np.testing.assert_allclose( 171 | utils.take_top_k_logits(logits, 5).eval(), 172 | [[[1,1.01,1.001,0,0,0,2]]] 173 | ) 174 | 175 | 176 | def test_top_p(): 177 | with tf.Session().as_default(): 178 | logits = tf.constant([[[1,1.01,1.001,0,0,0,2]]], dtype=tf.float32) 179 | np.testing.assert_allclose( 180 | utils.take_top_p_logits(logits, 1).eval(), 181 | logits.eval() 182 | ) 183 | np.testing.assert_allclose( 184 | utils.take_top_p_logits(logits, 0).eval(), 185 | [[[-1e10,-1e10,-1e10,-1e10,-1e10,-1e10,2]]] 186 | ) 187 | np.testing.assert_allclose( 188 | utils.take_top_p_logits(logits, 0.7).eval(), 189 | 
[[[-1e10,1.01,1.001,-1e10,-1e10,-1e10,2]]] 190 | ) 191 | np.testing.assert_allclose( 192 | utils.take_top_p_logits(logits, 0.6).eval(), 193 | [[[-1e10,1.01,-1e10,-1e10,-1e10,-1e10,2]]] 194 | ) 195 | np.testing.assert_allclose( 196 | utils.take_top_p_logits(logits, 0.5).eval(), 197 | [[[-1e10,-1e10,-1e10,-1e10,-1e10,-1e10,2]]] 198 | ) 199 | 200 | def test_safe_zip(): 201 | assert list(utils.safe_zip([1, 2], [3, 4])) == [(1, 3), (2, 4)] 202 | try: 203 | utils.safe_zip([1, 2], [3, 4, 5]) 204 | assert False 205 | except ValueError: 206 | pass 207 | 208 | 209 | if __name__ == '__main__': 210 | test_sample_buffer() 211 | test_cumulative_max() 212 | test_where() 213 | test_index_each() 214 | test_graph_function() 215 | test_top_k() 216 | test_top_p() 217 | test_safe_zip() 218 | -------------------------------------------------------------------------------- /lm_human_preferences/utils/test_hyperparams.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from dataclasses import dataclass, field 3 | from typing import Optional 4 | 5 | import pytest 6 | 7 | from lm_human_preferences.utils import hyperparams 8 | 9 | 10 | @dataclass 11 | class Simple(hyperparams.HParams): 12 | mandatory_nodefault: int = None 13 | mandatory_withdefault: str = "foo" 14 | optional_nodefault: Optional[int] = None 15 | fun: bool = True 16 | 17 | def test_simple_works(): 18 | hp = Simple() 19 | hp.override_from_str("mandatory_nodefault=3,optional_nodefault=None,fun=false") 20 | hp.validate() 21 | assert hp.mandatory_nodefault == 3 22 | assert hp.mandatory_withdefault == "foo" 23 | assert hp.optional_nodefault is None 24 | assert not hp.fun 25 | 26 | def test_simple_failures(): 27 | hp = Simple() 28 | with pytest.raises(TypeError): 29 | hp.validate() # mandatory_nodefault unset 30 | with pytest.raises(ValueError): 31 | hp.override_from_str("mandatory_nodefault=abc") 32 | with pytest.raises(AttributeError): 33 | hp.override_from_str("nonexistent_field=7.0") 34 | with pytest.raises(ValueError): 35 | hp.override_from_str("fun=?") 36 | 37 | @dataclass 38 | class Nested(hyperparams.HParams): 39 | first: bool = False 40 | simple_1: Simple = field(default_factory=Simple) 41 | simple_2: Optional[Simple] = None 42 | 43 | def test_nested(): 44 | hp = Nested() 45 | hp.override_from_str("simple_1.mandatory_nodefault=8,simple_2=on,simple_2.mandatory_withdefault=HELLO") 46 | with pytest.raises(TypeError): 47 | hp.validate() # simple_2.mandatory_nodefault unset 48 | hp.override_from_dict({'simple_2/mandatory_nodefault': 7, 'simple_1/optional_nodefault': 55}, separator='/') 49 | hp.validate() 50 | assert hp.simple_1.mandatory_nodefault == 8 51 | assert hp.simple_1.mandatory_withdefault == "foo" 52 | assert hp.simple_1.optional_nodefault == 55 53 | assert hp.simple_2.mandatory_nodefault == 7 54 | assert hp.simple_2.mandatory_withdefault == "HELLO" 55 | assert hp.simple_2.optional_nodefault is None 56 | 57 | hp.override_from_str("simple_2=off") 58 | hp.validate() 59 | assert hp.simple_2 is None 60 | 61 | with pytest.raises((TypeError, AttributeError)): 62 | hp.override_from_str("simple_2.fun=True") 63 | with pytest.raises(ValueError): 64 | hp.override_from_str("simple_2=BADVAL") 65 | 66 | def test_nested_dict(): 67 | hp = Nested() 68 | hp.override_from_nested_dict( 69 | {'simple_1': {'mandatory_nodefault': 8}, 'simple_2': {'mandatory_withdefault': "HELLO"}}) 70 | with pytest.raises(TypeError): 71 | hp.validate() # simple_2.mandatory_nodefault unset 72 | 
hp.override_from_nested_dict( 73 | {'simple_2': {'mandatory_nodefault': 7}, 'simple_1': {'optional_nodefault': 55}, 'first': True}) 74 | hp.validate() 75 | assert hp.to_nested_dict() == { 76 | 'first': True, 77 | 'simple_1': { 78 | 'mandatory_nodefault': 8, 79 | 'mandatory_withdefault': "foo", 80 | 'optional_nodefault': 55, 81 | 'fun': True, 82 | }, 83 | 'simple_2': { 84 | 'mandatory_nodefault': 7, 85 | 'mandatory_withdefault': "HELLO", 86 | 'optional_nodefault': None, 87 | 'fun': True, 88 | }, 89 | } 90 | 91 | def test_nested_order(): 92 | hp = Nested() 93 | # Either order should work 94 | hp.override_from_str_dict(OrderedDict([('simple_2.fun', 'True'), ('simple_2', 'on')])) 95 | hp.override_from_str_dict(OrderedDict([('simple_2', 'on'), ('simple_2.fun', 'True')])) 96 | 97 | @dataclass 98 | class Deeply(hyperparams.HParams): 99 | nested: Nested = None 100 | 101 | def test_deeply_nested(): 102 | hp = Deeply() 103 | hp.override_from_str("nested.simple_2=on") 104 | assert hp.nested is not None 105 | assert hp.nested.simple_2 is not None 106 | 107 | hp = Deeply() 108 | hp.override_from_dict({'nested.simple_2': 'on'}) 109 | assert hp.nested is not None 110 | assert hp.nested.simple_2 is not None 111 | 112 | def test_set_order(): 113 | hp = Deeply() 114 | hp.override_from_dict(OrderedDict([('nested.first', True), ('nested.simple_1', 'on')])) 115 | assert hp.nested.first is True 116 | 117 | hp = Deeply() 118 | hp.override_from_dict(OrderedDict([('nested.simple_1', 'on'), ('nested.first', True)])) 119 | assert hp.nested.first is True 120 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | from functools import partial 5 | 6 | from mpi4py import MPI 7 | import tensorflow as tf 8 | 9 | from lm_human_preferences.utils import launch, hyperparams 10 | from lm_human_preferences.utils import core as utils 11 | from lm_human_preferences.policy import Policy 12 | from lm_human_preferences.language import trained_models 13 | from lm_human_preferences import lm_tasks 14 | from lm_human_preferences import train_policy 15 | 16 | def sample_policy(save_dir=None, savescope='policy', temperature=1.0, seed=None, batch_size=4, nsamples=0): 17 | hparams = train_policy.HParams() 18 | hparams.override_from_json_file(os.path.join(save_dir, 'train_policy_hparams.json')) 19 | print('hparams', hparams) 20 | task = hparams.task 21 | 22 | comm = MPI.COMM_WORLD 23 | nsamples_per_rank = utils.exact_div(nsamples, comm.Get_size()) 24 | with tf.Graph().as_default(): 25 | m = trained_models.TrainedModel(name='sample', savedir=os.path.join(save_dir, 'policy'), scope='policy') 26 | encoder = m.encoding.get_encoder() 27 | hyperparams.dump(m.hparams(), name='model_hparams') 28 | 29 | utils.set_mpi_seed(seed) 30 | 31 | policy = Policy( 32 | m, scope='policy', 33 | is_root=True, # just init on every rank, simplifies code 34 | embed_queries=lm_tasks.query_formatter(task, encoder), 35 | temperature=temperature, 36 | ) 37 | 38 | query_sampler = lm_tasks.make_query_sampler( 39 | hparams=task, encoder=encoder, comm=comm, 40 | batch_size=batch_size, mode='test' 41 | ) 42 | 43 | init_ops = tf.group( 44 | tf.global_variables_initializer(), 45 | tf.local_variables_initializer(), 46 | ) 47 | 48 | with utils.mpi_session() as sess: 49 | init_ops.run() 50 | @utils.graph_function() 51 | def sample_queries(): 52 | return query_sampler()['tokens'] 53 | 54 | 
tf.get_default_graph().finalize() 55 | 56 | generated = 0 57 | while nsamples_per_rank == 0 or generated < nsamples_per_rank: 58 | queries = sample_queries() 59 | rollouts = policy.respond(queries, length=task.response_length) 60 | assert len(queries.tolist()) == batch_size 61 | assert len(rollouts['responses'].tolist()) == batch_size 62 | for q, r in zip(queries.tolist(), rollouts['responses'].tolist()): 63 | print('=' * 80) 64 | print(encoder.decode(q).replace("\n", "⏎")) 65 | print(encoder.decode(r).replace("\n", "⏎")) 66 | generated += batch_size 67 | 68 | def launch_sample(mode='local', mpi=8, **kwargs): 69 | launch.launch('sample', partial(sample_policy, **kwargs), mode=mode, mpi=mpi) 70 | 71 | if __name__ == '__main__': 72 | launch.main(dict( 73 | sample=launch_sample, 74 | )) 75 | 76 | """ 77 | ./sample.py sample --save_dir gs://jeffwu-rcall/results/safety/lmhf-sent-69c5170-1909161359/ --mpi 8 78 | """ 79 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pkg_resources 4 | 5 | from setuptools import setup, find_packages 6 | 7 | os.environ['CC'] = 'g++' 8 | 9 | setup(name='lm_human_preferences', 10 | version='0.0.1', 11 | packages=find_packages(include=['lm_human_preferences']), 12 | include_package_data=True, 13 | ) 14 | --------------------------------------------------------------------------------
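Note: the `hyperparams.HParams` helpers above are exercised only through the unit tests in `test_hyperparams.py`. As a quick orientation, here is a small self-contained sketch of how that same override/validate/dump API might be used from launch code. The `Run` and `Experiment` classes, their field names, and the override values are hypothetical examples chosen for illustration; they are not classes defined anywhere in this repository.

# Sketch only: hypothetical HParams subclasses, mirroring the patterns in test_hyperparams.py.
from dataclasses import dataclass, field
from typing import Optional

from lm_human_preferences.utils import hyperparams


@dataclass
class Run(hyperparams.HParams):
    seed: int = 0
    save_dir: Optional[str] = None


@dataclass
class Experiment(hyperparams.HParams):
    lr: float = 1e-4
    run: Run = field(default_factory=Run)
    eval_run: Optional[Run] = None  # optional nested HParams, switched on with 'eval_run=on'


exp = Experiment()
# Comma-separated key=value overrides, dots for nesting; 'on' constructs an Optional sub-HParams.
exp.override_from_str("lr=3e-5,run.seed=7,eval_run=on,eval_run.save_dir=/tmp/eval")
exp.validate()              # type-checks every field, recursing into nested HParams
hyperparams.dump(exp)       # pretty-prints the hyperparameter tree to stdout
assert exp.to_nested_dict()["run"]["seed"] == 7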