├── .gitignore
├── LICENSE
├── Pipfile
├── Pipfile.lock
├── README.md
├── launch.py
├── lm_human_preferences
│   ├── datasets
│   │   ├── books.py
│   │   ├── cnndm.py
│   │   └── tldr.py
│   ├── label_types.py
│   ├── language
│   │   ├── datasets.py
│   │   ├── encodings.py
│   │   ├── model.py
│   │   ├── sample.py
│   │   ├── test_model.py
│   │   ├── test_sample.py
│   │   └── trained_models.py
│   ├── lm_tasks.py
│   ├── policy.py
│   ├── rewards.py
│   ├── test_train_policy.py
│   ├── test_train_reward.py
│   ├── train_policy.py
│   ├── train_reward.py
│   └── utils
│       ├── combos.py
│       ├── core.py
│       ├── gcs.py
│       ├── hyperparams.py
│       ├── launch.py
│       ├── test_core_utils.py
│       └── test_hyperparams.py
├── sample.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .mypy_cache
3 | *.egg-info/
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Pipfile:
--------------------------------------------------------------------------------
1 | [[source]]
2 | name = "pypi"
3 | url = "https://pypi.org/simple"
4 | verify_ssl = true
5 |
6 | [dev-packages]
7 |
8 | [packages]
9 | cloudpickle = "==1.2.1"
10 | dataclasses = "==0.6.0"
11 | fire = "==0.1.3"
12 | ftfy = "==5.4.1"
13 | google-api-python-client = "==1.7.8"
14 | google-cloud-storage = "==1.13.0"
15 | mpi4py = "==3.0.2"
16 | mypy = "==0.580"
17 | numpy = "==1.16.2"
18 | pytest-instafail = "==0.3.0"
19 | pytest-timeout = "==1.2.0"
20 | pytest = "==3.5.0"
21 | pytz = "==2019.1"
22 | regex = "==2017.4.5"
23 | requests = "==2.18.0"
24 | tqdm = "==4.31.1"
25 | typeguard = ">=2.2.2"
26 | lm-human-preferences = {editable = true,path = "."}
27 |
28 | [requires]
29 | python_version = "3.7"
30 |
--------------------------------------------------------------------------------
/Pipfile.lock:
--------------------------------------------------------------------------------
1 | {
2 | "_meta": {
3 | "hash": {
4 | "sha256": "aca3fc5344bba2aa6f9d399ce2323f3f0b72dd912e7f105e4139526101e2607a"
5 | },
6 | "pipfile-spec": 6,
7 | "requires": {
8 | "python_version": "3.7"
9 | },
10 | "sources": [
11 | {
12 | "name": "pypi",
13 | "url": "https://pypi.org/simple",
14 | "verify_ssl": true
15 | }
16 | ]
17 | },
18 | "default": {
19 | "attrs": {
20 | "hashes": [
21 | "sha256:69c0dbf2ed392de1cb5ec704444b08a5ef81680a61cb899dc08127123af36a79",
22 | "sha256:f0b870f674851ecbfbbbd364d6b5cbdff9dcedbc7f3f5e18a6891057f21fe399"
23 | ],
24 | "version": "==19.1.0"
25 | },
26 | "cachetools": {
27 | "hashes": [
28 | "sha256:428266a1c0d36dc5aca63a2d7c5942e88c2c898d72139fca0e97fdd2380517ae",
29 | "sha256:8ea2d3ce97850f31e4a08b0e2b5e6c34997d7216a9d2c98e0f3978630d4da69a"
30 | ],
31 | "version": "==3.1.1"
32 | },
33 | "certifi": {
34 | "hashes": [
35 | "sha256:e4f3620cfea4f83eedc95b24abd9cd56f3c4b146dd0177e83a21b4eb49e21e50",
36 | "sha256:fd7c7c74727ddcf00e9acd26bba8da604ffec95bf1c2144e67aff7a8b50e6cef"
37 | ],
38 | "version": "==2019.9.11"
39 | },
40 | "chardet": {
41 | "hashes": [
42 | "sha256:84ab92ed1c4d4f16916e05906b6b75a6c0fb5db821cc65e70cbd64a3e2a5eaae",
43 | "sha256:fc323ffcaeaed0e0a02bf4d117757b98aed530d9ed4531e3e15460124c106691"
44 | ],
45 | "version": "==3.0.4"
46 | },
47 | "cloudpickle": {
48 | "hashes": [
49 | "sha256:603244e0f552b72a267d47a7d9b347b27a3430f58a0536037a290e7e0e212ecf",
50 | "sha256:b8ba7e322f2394b9bbbdc1c976e6442c2c02acc784cb9e553cee9186166a6890"
51 | ],
52 | "index": "pypi",
53 | "version": "==1.2.1"
54 | },
55 | "dataclasses": {
56 | "hashes": [
57 | "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f",
58 | "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"
59 | ],
60 | "index": "pypi",
61 | "version": "==0.6.0"
62 | },
63 | "fire": {
64 | "hashes": [
65 | "sha256:c299d16064ff81cbb649b65988300d4a28b71ecfb789d1fb74d99ea98ae4d2eb"
66 | ],
67 | "index": "pypi",
68 | "version": "==0.1.3"
69 | },
70 | "ftfy": {
71 | "hashes": [
72 | "sha256:619e68f9844cadd03e0d835e9b6790b2399357100c57fddae14d93a8de81e114"
73 | ],
74 | "index": "pypi",
75 | "version": "==5.4.1"
76 | },
77 | "google-api-core": {
78 | "hashes": [
79 | "sha256:2c23fbc81c76b941ffb71301bb975ed66a610e9b03f918feacd1ed59cf43a6ec",
80 | "sha256:b2b91107bcc3b981633c89602b46451f6474973089febab3ee51c49cb7ae6a1f"
81 | ],
82 | "version": "==1.14.2"
83 | },
84 | "google-api-python-client": {
85 | "hashes": [
86 | "sha256:06907006ed5ce831018f03af3852d739c0b2489cdacfda6971bcc2075c762858",
87 | "sha256:937eabdc3940977f712fa648a096a5142766b6d0a0f58bc603e2ac0687397ef0"
88 | ],
89 | "index": "pypi",
90 | "version": "==1.7.8"
91 | },
92 | "google-auth": {
93 | "hashes": [
94 | "sha256:0f7c6a64927d34c1a474da92cfc59e552a5d3b940d3266606c6a28b72888b9e4",
95 | "sha256:20705f6803fd2c4d1cc2dcb0df09d4dfcb9a7d51fd59e94a3a28231fd93119ed"
96 | ],
97 | "version": "==1.6.3"
98 | },
99 | "google-auth-httplib2": {
100 | "hashes": [
101 | "sha256:098fade613c25b4527b2c08fa42d11f3c2037dda8995d86de0745228e965d445",
102 | "sha256:f1c437842155680cf9918df9bc51c1182fda41feef88c34004bd1978c8157e08"
103 | ],
104 | "version": "==0.0.3"
105 | },
106 | "google-cloud-core": {
107 | "hashes": [
108 | "sha256:0090df83dbc5cb2405fa90844366d13176d1c0b48181c1807ab15f53be403f73",
109 | "sha256:89e8140a288acec20c5e56159461d3afa4073570c9758c05d4e6cb7f2f8cc440"
110 | ],
111 | "version": "==0.28.1"
112 | },
113 | "google-cloud-storage": {
114 | "hashes": [
115 | "sha256:936c859c47f8e94fd0005e98235a10d5e75828d2c6c3a8caacae18344a572a0a",
116 | "sha256:fc32b9be41a45016ba2387e3ad23e70ccba399d626ef596409316f7cee477956"
117 | ],
118 | "index": "pypi",
119 | "version": "==1.13.0"
120 | },
121 | "google-resumable-media": {
122 | "hashes": [
123 | "sha256:5fd2e641f477e50be925a55bcfdf0b0cb97c2b92aacd7b15c1d339f70d55c1c7",
124 | "sha256:cdeb8fbb3551a665db921023603af2f0d6ac59ad8b48259cb510b8799505775f"
125 | ],
126 | "version": "==0.4.1"
127 | },
128 | "googleapis-common-protos": {
129 | "hashes": [
130 | "sha256:e61b8ed5e36b976b487c6e7b15f31bb10c7a0ca7bd5c0e837f4afab64b53a0c6"
131 | ],
132 | "version": "==1.6.0"
133 | },
134 | "httplib2": {
135 | "hashes": [
136 | "sha256:6901c8c0ffcf721f9ce270ad86da37bc2b4d32b8802d4a9cec38274898a64044",
137 | "sha256:cf6f9d5876d796539ec922a2c9b9a7cad9bfd90f04badcdc3bcfa537168052c3"
138 | ],
139 | "version": "==0.13.1"
140 | },
141 | "idna": {
142 | "hashes": [
143 | "sha256:3cb5ce08046c4e3a560fc02f138d0ac63e00f8ce5901a56b32ec8b7994082aab",
144 | "sha256:cc19709fd6d0cbfed39ea875d29ba6d4e22c0cebc510a76d6302a28385e8bb70"
145 | ],
146 | "version": "==2.5"
147 | },
148 | "lm-human-preferences": {
149 | "editable": true,
150 | "path": "."
151 | },
152 | "more-itertools": {
153 | "hashes": [
154 | "sha256:409cd48d4db7052af495b09dec721011634af3753ae1ef92d2b32f73a745f832",
155 | "sha256:92b8c4b06dac4f0611c0729b2f2ede52b2e1bac1ab48f089c7ddc12e26bb60c4"
156 | ],
157 | "version": "==7.2.0"
158 | },
159 | "mpi4py": {
160 | "hashes": [
161 | "sha256:014076ffa558bc8d1d82c820c94848ae5f9fe1aab3c9e0a18d80e0c339a4bbe4",
162 | "sha256:020dbf8c8d2b95b6098c6a66352907afed1c449d811fd085247d5ee244890bb1",
163 | "sha256:06514c4205e1de84d04c780ab6aa8751121203dd246a45b120817c4444bed341",
164 | "sha256:0bcd7acb12c7e830267f9d3df13da0576ccf1603fb1c9f940e600ceefbe69200",
165 | "sha256:1c83daae9a99908109200b29c9cfd93e7c0dc9cad50bef15f0ea85642c288746",
166 | "sha256:39807cca8195b0c1e43dc9a3e1d80ef4b7cdc66a9f19a184ce7c28d8b42b7f4a",
167 | "sha256:45b5674d0d630c31bbb94abd9563202ecd83e72a2c54ee719b9813d3a5938767",
168 | "sha256:4f2f6f5cdece7a95b53bfc884ff9201e270ca386f8c53b54ff2bec799e5b8e0c",
169 | "sha256:5c1b377022a43e515812f6064d7b1ec01fd61027592aa16e5ad5e14f27f8db3a",
170 | "sha256:baa8a41f5bddbf581f521fc68db1a297fe24a0256c36bf7dd22fcb3e2cc93ea1",
171 | "sha256:c105ac976e1605a6883db06a37b0dfac497b210de6d8569dc6d23af33597f145",
172 | "sha256:e452b96ff879700dcbcef19d145190d56621419e4fbc73e43998b2e692dc6eeb",
173 | "sha256:f8d629d1e3e3b7b89cb99d0e3bc5505e76cc42089829807950d5c56606ed48e0"
174 | ],
175 | "index": "pypi",
176 | "version": "==3.0.2"
177 | },
178 | "mypy": {
179 | "hashes": [
180 | "sha256:3bd95a1369810f7693366911d85be9f0a0bd994f6cb7162b7a994e5ded90e3d9",
181 | "sha256:7247f9948d7cdaae9408a4ee1662a01853c24e668117b4419acf025b05fbe3ce"
182 | ],
183 | "index": "pypi",
184 | "version": "==0.580"
185 | },
186 | "numpy": {
187 | "hashes": [
188 | "sha256:1980f8d84548d74921685f68096911585fee393975f53797614b34d4f409b6da",
189 | "sha256:22752cd809272671b273bb86df0f505f505a12368a3a5fc0aa811c7ece4dfd5c",
190 | "sha256:23cc40313036cffd5d1873ef3ce2e949bdee0646c5d6f375bf7ee4f368db2511",
191 | "sha256:2b0b118ff547fecabc247a2668f48f48b3b1f7d63676ebc5be7352a5fd9e85a5",
192 | "sha256:3a0bd1edf64f6a911427b608a894111f9fcdb25284f724016f34a84c9a3a6ea9",
193 | "sha256:3f25f6c7b0d000017e5ac55977a3999b0b1a74491eacb3c1aa716f0e01f6dcd1",
194 | "sha256:4061c79ac2230594a7419151028e808239450e676c39e58302ad296232e3c2e8",
195 | "sha256:560ceaa24f971ab37dede7ba030fc5d8fa173305d94365f814d9523ffd5d5916",
196 | "sha256:62be044cd58da2a947b7e7b2252a10b42920df9520fc3d39f5c4c70d5460b8ba",
197 | "sha256:6c692e3879dde0b67a9dc78f9bfb6f61c666b4562fd8619632d7043fb5b691b0",
198 | "sha256:6f65e37b5a331df950ef6ff03bd4136b3c0bbcf44d4b8e99135d68a537711b5a",
199 | "sha256:7a78cc4ddb253a55971115f8320a7ce28fd23a065fc33166d601f51760eecfa9",
200 | "sha256:80a41edf64a3626e729a62df7dd278474fc1726836552b67a8c6396fd7e86760",
201 | "sha256:893f4d75255f25a7b8516feb5766c6b63c54780323b9bd4bc51cdd7efc943c73",
202 | "sha256:972ea92f9c1b54cc1c1a3d8508e326c0114aaf0f34996772a30f3f52b73b942f",
203 | "sha256:9f1d4865436f794accdabadc57a8395bd3faa755449b4f65b88b7df65ae05f89",
204 | "sha256:9f4cd7832b35e736b739be03b55875706c8c3e5fe334a06210f1a61e5c2c8ca5",
205 | "sha256:adab43bf657488300d3aeeb8030d7f024fcc86e3a9b8848741ea2ea903e56610",
206 | "sha256:bd2834d496ba9b1bdda3a6cf3de4dc0d4a0e7be306335940402ec95132ad063d",
207 | "sha256:d20c0360940f30003a23c0adae2fe50a0a04f3e48dc05c298493b51fd6280197",
208 | "sha256:d3b3ed87061d2314ff3659bb73896e622252da52558f2380f12c421fbdee3d89",
209 | "sha256:dc235bf29a406dfda5790d01b998a1c01d7d37f449128c0b1b7d1c89a84fae8b",
210 | "sha256:fb3c83554f39f48f3fa3123b9c24aecf681b1c289f9334f8215c1d3c8e2f6e5b"
211 | ],
212 | "index": "pypi",
213 | "version": "==1.16.2"
214 | },
215 | "pluggy": {
216 | "hashes": [
217 | "sha256:7f8ae7f5bdf75671a718d2daf0a64b7885f74510bcd98b1a0bb420eb9a9d0cff",
218 | "sha256:d345c8fe681115900d6da8d048ba67c25df42973bda370783cd58826442dcd7c",
219 | "sha256:e160a7fcf25762bb60efc7e171d4497ff1d8d2d75a3d0df7a21b76821ecbf5c5"
220 | ],
221 | "version": "==0.6.0"
222 | },
223 | "protobuf": {
224 | "hashes": [
225 | "sha256:00a1b0b352dc7c809749526d1688a64b62ea400c5b05416f93cfb1b11a036295",
226 | "sha256:01acbca2d2c8c3f7f235f1842440adbe01bbc379fa1cbdd80753801432b3fae9",
227 | "sha256:0a795bca65987b62d6b8a2d934aa317fd1a4d06a6dd4df36312f5b0ade44a8d9",
228 | "sha256:0ec035114213b6d6e7713987a759d762dd94e9f82284515b3b7331f34bfaec7f",
229 | "sha256:31b18e1434b4907cb0113e7a372cd4d92c047ce7ba0fa7ea66a404d6388ed2c1",
230 | "sha256:32a3abf79b0bef073c70656e86d5bd68a28a1fbb138429912c4fc07b9d426b07",
231 | "sha256:55f85b7808766e5e3f526818f5e2aeb5ba2edcc45bcccede46a3ccc19b569cb0",
232 | "sha256:64ab9bc971989cbdd648c102a96253fdf0202b0c38f15bd34759a8707bdd5f64",
233 | "sha256:64cf847e843a465b6c1ba90fb6c7f7844d54dbe9eb731e86a60981d03f5b2e6e",
234 | "sha256:917c8662b585470e8fd42f052661fc66d59fccaae450a60044307dcbf82a3335",
235 | "sha256:afed9003d7f2be2c3df20f64220c30faec441073731511728a2cb4cab4cd46a6",
236 | "sha256:bf8e05d638b585d1752c5a84247134a0350d3a8b73d3632489a014a9f6f1e758",
237 | "sha256:d831b047bd69becaf64019a47179eb22118a50dd008340655266a906c69c6417",
238 | "sha256:de2760583ed28749ff885789c1cbc6c9c06d6de92fc825740ab99deb2f25ea4d",
239 | "sha256:eabc4cf1bc19689af8022ba52fd668564a8d96e0d08f3b4732d26a64255216a4",
240 | "sha256:fcff6086c86fb1628d94ea455c7b9de898afc50378042927a59df8065a79a549"
241 | ],
242 | "version": "==3.9.1"
243 | },
244 | "py": {
245 | "hashes": [
246 | "sha256:64f65755aee5b381cea27766a3a147c3f15b9b6b9ac88676de66ba2ae36793fa",
247 | "sha256:dc639b046a6e2cff5bbe40194ad65936d6ba360b52b3c3fe1d08a82dd50b5e53"
248 | ],
249 | "version": "==1.8.0"
250 | },
251 | "pyasn1": {
252 | "hashes": [
253 | "sha256:62cdade8b5530f0b185e09855dd422bc05c0bbff6b72ff61381c09dac7befd8c",
254 | "sha256:a9495356ca1d66ed197a0f72b41eb1823cf7ea8b5bd07191673e8147aecf8604"
255 | ],
256 | "version": "==0.4.7"
257 | },
258 | "pyasn1-modules": {
259 | "hashes": [
260 | "sha256:43c17a83c155229839cc5c6b868e8d0c6041dba149789b6d6e28801c64821722",
261 | "sha256:e30199a9d221f1b26c885ff3d87fd08694dbbe18ed0e8e405a2a7126d30ce4c0"
262 | ],
263 | "version": "==0.2.6"
264 | },
265 | "pytest": {
266 | "hashes": [
267 | "sha256:6266f87ab64692112e5477eba395cfedda53b1933ccd29478e671e73b420c19c",
268 | "sha256:fae491d1874f199537fd5872b5e1f0e74a009b979df9d53d1553fd03da1703e1"
269 | ],
270 | "index": "pypi",
271 | "version": "==3.5.0"
272 | },
273 | "pytest-instafail": {
274 | "hashes": [
275 | "sha256:b4d5fc3ca81e530a8d0e15a7771dc14b06fc9a0930c4b3909a7f4527040572c3"
276 | ],
277 | "index": "pypi",
278 | "version": "==0.3.0"
279 | },
280 | "pytest-timeout": {
281 | "hashes": [
282 | "sha256:c29e3168f10897728059bd6b8ca20b28733d7fe6b8f6c09bb9d89f6146f27cb8",
283 | "sha256:c65a80c87074c17b6dfbe91cd856f260f84fbdad5df9bd79b1cfc26fe5c163f1"
284 | ],
285 | "index": "pypi",
286 | "version": "==1.2.0"
287 | },
288 | "pytz": {
289 | "hashes": [
290 | "sha256:303879e36b721603cc54604edcac9d20401bdbe31e1e4fdee5b9f98d5d31dfda",
291 | "sha256:d747dd3d23d77ef44c6a3526e274af6efeb0a6f1afd5a69ba4d5be4098c8e141"
292 | ],
293 | "index": "pypi",
294 | "version": "==2019.1"
295 | },
296 | "regex": {
297 | "hashes": [
298 | "sha256:19c4b0f68dd97b7116e590f47d60d97ab9e76966acc321b1d20dd87c2b64dff2",
299 | "sha256:1af6b820bec5ca82af87447af5a6dcc23b3ddc96b0184fd71666be0c24fb2a4f",
300 | "sha256:232dbc28a2562d92d713c3c1eb2b9276f3ebcbdb6d3e96ff68d0417a71926784",
301 | "sha256:3d26ce7e605a501509b68c343fc9d9e09f76c2e9e261df8183027bdc750c97ce",
302 | "sha256:52b590a41b9677314d02d9055edc33992db758b3d5167aa1365229a6a0c26a6d",
303 | "sha256:565f9aac9cd43b2351f7fcbc0d6056f8aebf4f6d049a17982085019ab9acdf28",
304 | "sha256:656984899644d3fe2e40533724f513a21127f77162a15dd5244af3c965152c63",
305 | "sha256:689c9d17c3ba02f52e8481a5c584c8c11ba27d6cc5f939efdd838ae0d0d1af41",
306 | "sha256:8a9d9db8ef1621ae51ea12acb5e503204b4586e05c6cfd418aecb9466a71bd87",
307 | "sha256:ad2beea450d551b11b47512ce920127d7c8645e528cc56dc9502c5973e8732f3",
308 | "sha256:b39867f577bc59b2fec9209facc513c761978e4ac63f4b73b9750a2c1501729e",
309 | "sha256:b6a7725a069be8f9dd09e1e500e5b57556b301942e21c8c712627f73ec048286",
310 | "sha256:b9e9b97696e75e826adac1920b13e7bac3a6a2128c085783abd208d73a278d70",
311 | "sha256:bf4896ed1ca2017153fc6b341bc8a0da8ca5480f85eebd7bfe58bbafceb4e728",
312 | "sha256:c3c2fe1e0d90f4c93be5b588480f05defd44f64c65767a657de69c4db4429a39",
313 | "sha256:d811874ed669165fe1059a54f860db5c6ab5f48100bf4945d915fd2f877b2531",
314 | "sha256:db616380b04e29e5709bc3ec0674e827dfed3d18e7d686c09537ab01506127c9",
315 | "sha256:efa66273b49dbd7a9f6a4d02d1a7d5bf353d568a89f7cd8927812daa9f83bb84",
316 | "sha256:f8feab5b517cdc65a61a50549e7dcfa0f61ab872a0034da1f6b8d61775178b6a"
317 | ],
318 | "index": "pypi",
319 | "version": "==2017.4.5"
320 | },
321 | "requests": {
322 | "hashes": [
323 | "sha256:5e88d64aa56ac0fda54e77fb9762ebc65879e171b746d5479a33c4082519d6c6",
324 | "sha256:cd0189f962787284bff715fddaad478eb4d9c15aa167bd64e52ea0f661e7ea5c"
325 | ],
326 | "index": "pypi",
327 | "version": "==2.18.0"
328 | },
329 | "rsa": {
330 | "hashes": [
331 | "sha256:14ba45700ff1ec9eeb206a2ce76b32814958a98e372006c8fb76ba820211be66",
332 | "sha256:1a836406405730121ae9823e19c6e806c62bbad73f890574fff50efa4122c487"
333 | ],
334 | "version": "==4.0"
335 | },
336 | "six": {
337 | "hashes": [
338 | "sha256:3350809f0555b11f552448330d0b52d5f24c91a322ea4a15ef22629740f3761c",
339 | "sha256:d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73"
340 | ],
341 | "version": "==1.12.0"
342 | },
343 | "tqdm": {
344 | "hashes": [
345 | "sha256:d385c95361699e5cf7622485d9b9eae2d4864b21cd5a2374a9c381ffed701021",
346 | "sha256:e22977e3ebe961f72362f6ddfb9197cc531c9737aaf5f607ef09740c849ecd05"
347 | ],
348 | "index": "pypi",
349 | "version": "==4.31.1"
350 | },
351 | "typed-ast": {
352 | "hashes": [
353 | "sha256:0cf0c406af2a6472a02254fe1ced40cb81a7c1215b7ceba88a3bb9c3a864f851",
354 | "sha256:1b784cd3c6778cd7b99afb41ddcaa1eb5b35a399210db7fcf24ed082670e0070",
355 | "sha256:2d7a322c1df6cccff2381c0475c1ebf82d3e9a331e48ed4ea89bbc72a8dedca6",
356 | "sha256:4304399ff89452871348f6fb7a7112454cd508fbe3eb49b5ed711cce9b99fe9e",
357 | "sha256:4658aebc30c0af80e63b579e917c04b592bdf10ef40da381b2fd179075b5d1b6",
358 | "sha256:471a7f12e55ad22f7a4bb2c3e62e39e3ab78008b24c61c48c9042e63b7359bb9",
359 | "sha256:57cb23412dac214383c6b6f0f7b0aec2d0c001a936af20f0b53542bbe4ba08a7",
360 | "sha256:5eb14e6b3aa5ff5d7e964b978a718227b5576b3965f1dd71dd055f71054233a5",
361 | "sha256:8219b6147af4d609096b6db2c797281e19fd3f7232ef35932bc74a812ff417a0",
362 | "sha256:8a7e9635cf0aaca04b2a4d4b3501c0dbc5c49a140b2e55b00e218d41ed2a69c8",
363 | "sha256:935157ada4aa115d61c59e759e43c5862b04d19ffe6fe5c9d735716587535cb7",
364 | "sha256:9525f4cbe3eb7b9e19a87c765ca9bbc1147ce18f75059e15138eb7fc59ce02e3",
365 | "sha256:99c140583eef6b50f3de4af44718a4fc63108671b29c468b5ff83ed383facf6d",
366 | "sha256:9e358ce6d4c43a90c15b99b76261adc852998680628c780f26fd64bc21adb9fa",
367 | "sha256:aaf63a024b54d2788cff3400de79009ee8a23594b581d4f33d90b7c67f8c05bd",
368 | "sha256:c3313b3fa1b6b722866eda370c14fd8f4962b6bcd1f6d43f42d6818a8b29d998",
369 | "sha256:c9342947e5f3480473d836754d69965a12ac2237d99ae85d1e3fdd1c1722669f",
370 | "sha256:cb1c7e5b3195103f5a784db7969fc55463cfae9b354e3b97cc219d32293d5e65",
371 | "sha256:d2d2cce74165cae2663167c921e331fb0eecfff2e93254dfdb16beb99716e519",
372 | "sha256:d6fc3b9fbf67d556223aa5493501022e1d585b9a1892fa87ba1257627763c461",
373 | "sha256:fa4eafaa57074958f065c2a6222d8f11162739f8c9db125472a1f04794a0b91d"
374 | ],
375 | "version": "==1.1.2"
376 | },
377 | "typeguard": {
378 | "hashes": [
379 | "sha256:5b90905662970cb47029cd5800b17b81608162ea2fcab7e5fd19bcc04a7d0b42",
380 | "sha256:5ecab47551c42a8090dcb914c550287a09caf599b4d47958445494f2822165aa"
381 | ],
382 | "index": "pypi",
383 | "version": "==2.5.0"
384 | },
385 | "uritemplate": {
386 | "hashes": [
387 | "sha256:01c69f4fe8ed503b2951bef85d996a9d22434d2431584b5b107b2981ff416fbd",
388 | "sha256:1b9c467a940ce9fb9f50df819e8ddd14696f89b9a8cc87ac77952ba416e0a8fd",
389 | "sha256:c02643cebe23fc8adb5e6becffe201185bf06c40bda5c0b4028a93f1527d011d"
390 | ],
391 | "version": "==3.0.0"
392 | },
393 | "urllib3": {
394 | "hashes": [
395 | "sha256:8ed6d5c1ff9d6ba84677310060d6a3a78ca3072ce0684cb3c645023009c114b1",
396 | "sha256:b14486978518ca0901a76ba973d7821047409d7f726f22156b24e83fd71382a5"
397 | ],
398 | "version": "==1.21.1"
399 | },
400 | "wcwidth": {
401 | "hashes": [
402 | "sha256:3df37372226d6e63e1b1e1eda15c594bca98a22d33a23832a90998faa96bc65e",
403 | "sha256:f4ebe71925af7b40a864553f761ed559b43544f8f71746c2d756c7fe788ade7c"
404 | ],
405 | "version": "==0.1.7"
406 | }
407 | },
408 | "develop": {}
409 | }
410 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | **Status:** Archive (code is provided as-is, no updates expected)
2 |
3 | **Status:** All references to `gs://lm-human-preferences/` were updated to `https://openaipublic.blob.core.windows.net/lm-human-preferences`, as we migrated from GCP to Azure. The code, provided as-is, may no longer work. Pull requests are welcome.
4 |
5 | # lm-human-preferences
6 |
7 | This repository contains code for the paper [Fine-Tuning Language Models from Human Preferences](https://arxiv.org/abs/1909.08593). See also our [blog post](https://openai.com/blog/fine-tuning-gpt-2/).
8 |
9 | We provide code for:
10 | - Training reward models from human labels
11 | - Fine-tuning language models using those reward models
12 |
13 | This repo does not contain code for generating labels. However, we have released the human labels collected for our experiments at `https://openaipublic.blob.core.windows.net/lm-human-preferences/labels`.
14 | For those interested, the question and label schemas are simple and documented in [`label_types.py`](./lm_human_preferences/label_types.py).
15 |
16 | The code has only been tested using the smallest GPT-2 model (124M parameters).
17 |
18 | ## Instructions
19 |
20 | This code has only been tested using Python 3.7.3. Training has been tested on GCE machines with 8 V100s, running Ubuntu 16.04, but development also works on Mac OS X.
21 |
22 | ### Installation
23 |
24 | - Install [pipenv](https://github.com/pypa/pipenv#installation).
25 |
26 | - Install [tensorflow](https://www.tensorflow.org/install/gpu): Install CUDA 10.0 and cuDNN 7.6.2, then `pipenv install tensorflow-gpu==1.13.1`. The code may technically run with tensorflow on CPU but will be very slow.
27 |
28 | - Install [`gsutil`](https://cloud.google.com/storage/docs/gsutil_install)
29 |
30 | - Clone this repo. Then:
31 | ```
32 | pipenv install
33 | ```
34 |
35 | - (Recommended) Install [`horovod`](https://github.com/horovod/horovod#install) to speed up the code, or otherwise substitute some fast implementation in the `mpi_allreduce_sum` function of [`core.py`](./lm_human_preferences/utils/core.py). Make sure to use pipenv for the install, e.g. `pipenv install horovod==0.18.1`.
36 |
37 | ### Running
38 |
39 | The following examples assume we are aiming to train a model to continue text in a physically descriptive way.
40 | You can read [`launch.py`](./launch.py) to see how the `descriptiveness` experiments and others are defined.
41 |
42 | Note that we provide pre-trained models, so you can skip directly to RL fine-tuning or even to sampling from a trained policy, if desired.
43 |
44 | #### Training a reward model
45 |
46 | To train a reward model, use a command such as
47 | ```
48 | experiment=descriptiveness
49 | reward_experiment_name=testdesc-$(date +%y%m%d%H%M)
50 | pipenv run ./launch.py train_reward $experiment $reward_experiment_name
51 | ```
52 |
53 | This will save outputs (and tensorboard event files) to the directory `/tmp/save/train_reward/$reward_experiment_name`. The directory can be changed via the `--save_dir` flag.
54 |
55 | #### Finetuning a language model
56 |
57 | Once you have trained a reward model, you can finetune against it.
58 |
59 | First, set
60 | ```
61 | trained_reward_model=/tmp/save/train_reward/$reward_experiment_name
62 | ```
63 | or if using our pretrained model,
64 | ```
65 | trained_reward_model=https://openaipublic.blob.core.windows.net/lm-human-preferences/runs/descriptiveness/reward_model
66 | ```
67 |
68 | Then,
69 | ```
70 | experiment=descriptiveness
71 | policy_experiment_name=testdesc-$(date +%y%m%d%H%M)
72 | pipenv run ./launch.py train_policy $experiment $policy_experiment_name --rewards.trained_model $trained_reward_model --rewards.train_new_model 'off'
73 | ```
74 |
75 | This will save outputs (and tensorboard event files) to the directory `/tmp/save/train_policy/$policy_experiment_name`. The directory can be changed via the `--save_dir` flag.
76 |
77 | #### Both steps at once
78 |
79 | You can run a single command to train a reward model and then finetune against it:
80 | ```
81 | experiment=descriptiveness
82 | experiment_name=testdesc-$(date +%y%m%d%H%M)
83 | pipenv run ./launch.py train_policy $experiment $experiment_name
84 | ```
85 |
86 | In this case, outputs are in the directory `/tmp/save/train_policy/$experiment_name`, and the reward model is saved to a subdirectory `reward_model`. The directory can be changed via the `--save_dir` flag.
87 |
88 | #### Sampling from a trained policy
89 |
90 | Specify the policy to load:
91 | ```
92 | save_dir=/tmp/save/train_policy/$policy_experiment_name
93 | ```
94 | or if using our pretrained model,
95 | ```
96 | save_dir=https://openaipublic.blob.core.windows.net/lm-human-preferences/runs/descriptiveness
97 | ```
98 |
99 | Then run:
100 | ```
101 | pipenv run ./sample.py sample --save_dir $save_dir --savescope policy
102 | ```
103 |
104 | Note that this script can run on fewer than 8 GPUs. You can pass the flag `--mpi 1`, for example, if you only have one GPU.
105 |
106 | ## LICENSE
107 |
108 | [MIT](./LICENSE)
109 |
110 | ## Citation
111 |
112 | Please cite the paper with the following bibtex entry:
113 | ```
114 | @article{ziegler2019finetuning,
115 | title={Fine-Tuning Language Models from Human Preferences},
116 | author={Ziegler, Daniel M. and Stiennon, Nisan and Wu, Jeffrey and Brown, Tom B. and Radford, Alec and Amodei, Dario and Christiano, Paul and Irving, Geoffrey},
117 | journal={arXiv preprint arXiv:1909.08593},
118 | url={https://arxiv.org/abs/1909.08593},
119 | year={2019}
120 | }
121 | ```
122 |
--------------------------------------------------------------------------------
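
Note on the Horovod bullet in the README above: if Horovod is not installed, the suggested alternative is to substitute a fast allreduce into `mpi_allreduce_sum` in `lm_human_preferences/utils/core.py` (that file is not reproduced in this dump). A minimal sketch of such a substitute, assuming the function takes a float tensor plus an mpi4py communicator (an assumed signature, not the verified one):

```
# Hypothetical mpi4py-based stand-in for mpi_allreduce_sum (assumed signature).
import numpy as np
import tensorflow as tf
from mpi4py import MPI

def mpi_allreduce_sum(value, *, comm=MPI.COMM_WORLD):
    if comm.Get_size() == 1:
        return value  # single rank: nothing to reduce
    def _allreduce(x):
        out = np.empty_like(x)
        comm.Allreduce(np.ascontiguousarray(x), out, op=MPI.SUM)  # elementwise sum across ranks
        return out
    result = tf.py_func(_allreduce, [value], value.dtype)
    result.set_shape(value.shape)  # tf.py_func drops static shape information
    return result
```

Horovod's fused allreduce is substantially faster on multi-GPU machines, which is why the README recommends it.
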
/launch.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | from lm_human_preferences.utils import launch
4 | from lm_human_preferences.utils.combos import bind, combos, each, label, options_shortdesc, bind_nested
5 | from lm_human_preferences import train_policy, train_reward
6 |
7 |
8 | books_task = combos(
9 | bind('query_length', 64),
10 | bind('query_dataset', 'books'),
11 | bind('response_length', 24),
12 | bind('start_text', '.'), # Start the context at the beginning of a sentence
13 | bind('end_text', '.'), # End the context at the end of a sentence.
14 | bind('truncate_token', 13), # Encoding of '.' -- end completions at the end of a sentence.
15 | bind('truncate_after', 16), # Make sure completions are at least 16 tokens long.
16 |
17 | bind('policy.temperature', 0.7),
18 | bind('policy.initial_model', '124M'),
19 | )
20 |
21 | summarize_cnndm_task = combos(
22 | bind('query_prefix', 'Article:\n\n'),
23 | bind('query_suffix', '\n\nTL;DR:'),
24 | bind('end_text', '\n'),
25 | bind('query_dataset', 'cnndm'),
26 | bind('query_length', 500),
27 | bind('response_length', 75),
28 | bind('start_text', None),
29 | bind('truncate_after', 55),
30 | bind('truncate_token', 198), # '\n'
31 |
32 | bind('policy.temperature', 0.5),
33 | bind('policy.initial_model', '124M'),
34 | )
35 |
36 | summarize_tldr_task = combos(
37 | bind('query_suffix', '\n\nTL;DR:'),
38 | bind('query_dataset', 'tldr'),
39 | bind('query_length', 500),
40 | bind('response_length', 75),
41 | bind('start_text', None),
42 | bind('truncate_after', 55),
43 | bind('truncate_token', 198), # '\n'
44 |
45 | bind('policy.temperature', 0.7),
46 | bind('policy.initial_model', '124M'),
47 | )
48 |
49 | def get_train_reward_experiments():
50 | _shared = combos(
51 | bind('labels.type', 'best_of_4'),
52 | bind('normalize_after', True),
53 | bind('normalize_before', True),
54 | bind('normalize_samples', 256),
55 | )
56 |
57 |
58 | _books_task = combos(
59 | bind_nested('task', books_task),
60 | _shared,
61 | bind('batch_size', 32),
62 | bind('lr', 5e-5),
63 | bind('rollout_batch_size', 512),
64 | )
65 |
66 | sentiment = combos(
67 | _books_task,
68 |
69 | bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/sentiment/offline_5k.json'),
70 | bind('labels.num_train', 4_992),
71 | bind('run.seed', 1)
72 | )
73 |
74 |
75 | descriptiveness = combos(
76 | _books_task,
77 |
78 | bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/descriptiveness/offline_5k.json'),
79 | bind('labels.num_train', 4_992),
80 | bind('run.seed', 1)
81 | )
82 |
83 | cnndm = combos(
84 | bind_nested('task', summarize_cnndm_task),
85 | _shared,
86 |
87 | # bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/cnndm/offline_60k.json'),
88 | # bind('labels.num_train', 60_000),
89 | bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/cnndm/online_45k.json'),
90 | bind('labels.num_train', 46_000),
91 |
92 | bind('batch_size', 2 * 8),
93 | bind('lr', 2.5e-5),
94 | bind('rollout_batch_size', 128),
95 | bind('run.seed', 1)
96 | )
97 |
98 | tldr = combos(
99 | bind_nested('task', summarize_tldr_task),
100 | _shared,
101 |
102 | # bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/tldr/offline_60k.json'),
103 | # bind('labels.num_train', 60_000),
104 | bind('labels.source', 'https://openaipublic.blob.core.windows.net/lm-human-preferences/labels/tldr/online_45k.json'),
105 | bind('labels.num_train', 46_000),
106 |
107 | bind('batch_size', 2 * 8),
108 | bind('lr', 2.5e-5),
109 | bind('rollout_batch_size', 128),
110 | bind('run.seed', 1)
111 | )
112 |
113 | return locals()
114 |
115 |
116 | def get_experiments():
117 | train_reward_experiments = get_train_reward_experiments()
118 |
119 | _books_task = combos(
120 | bind_nested('task', books_task),
121 |
122 | bind('ppo.lr', 1e-5),
123 | bind('ppo.total_episodes', 1_000_000),
124 | bind('ppo.batch_size', 512),
125 | )
126 |
127 | sentiment = combos(
128 | _books_task,
129 | bind('rewards.kl_coef', 0.15),
130 | bind('rewards.adaptive_kl', 'on'),
131 | bind('rewards.adaptive_kl.target', 6.0),
132 |
133 | bind('rewards.train_new_model', 'on'),
134 | bind_nested('rewards.train_new_model', train_reward_experiments['sentiment']),
135 | # bind('rewards.trained_model', '/your/directory/here/reward_model/'),
136 |
137 | bind('run.seed', 1)
138 | )
139 |
140 | descriptiveness = combos(
141 | _books_task,
142 | bind('rewards.kl_coef', 0.15),
143 | bind('rewards.adaptive_kl', 'on'),
144 | bind('rewards.adaptive_kl.target', 6.0),
145 |
146 | bind('rewards.train_new_model', 'on'),
147 | bind_nested('rewards.train_new_model', train_reward_experiments['descriptiveness']),
148 | # bind('rewards.trained_model', '/your/directory/here/reward_model/'),
149 |
150 | bind('run.seed', 1)
151 | )
152 |
153 | cnndm = combos(
154 | bind_nested('task', summarize_cnndm_task),
155 |
156 | bind('rewards.train_new_model', 'on'),
157 | bind_nested('rewards.train_new_model', train_reward_experiments['cnndm']),
158 | # bind('rewards.trained_model', '/your/directory/here/reward_model/'),
159 |
160 | bind('ppo.total_episodes', 1_000_000),
161 | bind('ppo.lr', 2e-6),
162 | bind('rewards.kl_coef', 0.01),
163 | # bind('rewards.adaptive_kl', 'on'),
164 | # bind('rewards.adaptive_kl.target', 18.0),
165 | bind('ppo.batch_size', 32),
166 | bind('rewards.whiten', False),
167 |
168 | bind('run.seed', 1)
169 | )
170 |
171 | tldr = combos(
172 | bind_nested('task', summarize_tldr_task),
173 |
174 | bind('rewards.train_new_model', 'on'),
175 | bind_nested('rewards.train_new_model', train_reward_experiments['tldr']),
176 | # bind('rewards.trained_model', '/your/directory/here/reward_model/'),
177 |
178 | bind('ppo.total_episodes', 1_000_000),
179 | bind('ppo.lr', 2e-6),
180 | bind('rewards.kl_coef', 0.03), # 0.01 too low
181 | # bind('rewards.adaptive_kl', 'on'),
182 | # bind('rewards.adaptive_kl.target', 18.0),
183 | bind('ppo.batch_size', 32),
184 | bind('rewards.whiten', False),
185 |
186 | bind('run.seed', 1)
187 | )
188 |
189 | return locals()
190 |
191 |
192 | def launch_train_policy(exp, name, dry_run=False, mpi=8, mode='local', save_dir='/tmp/save/train_policy', **extra_hparams):
193 | experiment_dict = get_experiments()
194 | try:
195 | trials = experiment_dict[exp]
196 | except KeyError:
197 | raise ValueError(f"Couldn't find experiment '{exp}'")
198 |
199 | launch.launch_trials(
200 | name, fn=train_policy.train, trials=trials, mpi=mpi, mode=mode, save_dir=save_dir,
201 | hparam_class=train_policy.HParams, extra_hparams=extra_hparams, dry_run=dry_run)
202 |
203 |
204 | def launch_train_reward(exp, name, dry_run=False, mpi=8, mode='local', save_dir='/tmp/save/train_reward', **extra_hparams):
205 | experiment_dict = get_train_reward_experiments()
206 | try:
207 | trials = experiment_dict[exp]
208 | except KeyError:
209 | raise ValueError(f"Couldn't find experiment '{exp}'")
210 |
211 | launch.launch_trials(
212 | name, fn=train_reward.train, trials=trials, mpi=mpi, mode=mode, save_dir=save_dir,
213 | hparam_class=train_reward.HParams, extra_hparams=extra_hparams, dry_run=dry_run)
214 |
215 |
216 | if __name__ == '__main__':
217 | launch.main(dict(
218 | train_policy=launch_train_policy,
219 | train_reward=launch_train_reward
220 | ))
221 |
--------------------------------------------------------------------------------
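
A usage sketch (not part of the repo) based on the signatures above: `launch_train_reward` and `launch_train_policy` look the experiment name up in the dicts returned by `get_train_reward_experiments()` / `get_experiments()`, and extra keyword arguments are forwarded as hyperparameter overrides, just like the dotted `--rewards.*` flags shown in the README.

```
# Illustrative only; equivalent to `pipenv run ./launch.py train_reward descriptiveness my-run --mpi 1`.
from launch import launch_train_reward, launch_train_policy

# Train the descriptiveness reward model on a single GPU; outputs go to /tmp/save/train_reward/my-run.
launch_train_reward('descriptiveness', 'my-run', mpi=1)

# Fine-tune a policy against that reward model instead of training a new one
# (the same overrides the README passes as --rewards.trained_model / --rewards.train_new_model).
launch_train_policy('descriptiveness', 'my-policy-run', mpi=1,
                    **{'rewards.trained_model': '/tmp/save/train_reward/my-run',
                       'rewards.train_new_model': 'off'})
```
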
/lm_human_preferences/datasets/books.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | from lm_human_preferences.utils import gcs
5 |
6 |
7 | def books_generator(mode, seed=0, shuffle=False, comm=None):
8 | datas = [
9 | json.loads(line) for line in
10 | open(gcs.download_file_cached(f'https://openaipublic.blob.core.windows.net/lm-human-preferences/datasets/book_passages/{mode}.jsonl', comm=comm))
11 | ]
12 | if shuffle:
13 | random.seed(seed)
14 | random.shuffle(datas)
15 |
16 | for x in datas:
17 | yield x
18 |
--------------------------------------------------------------------------------
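
A quick sanity check for the generator above (a sketch, not repo code): each JSONL line is assumed to decode to a plain passage string, which is how `language/datasets.py` consumes it; `gcs.download_file_cached` needs network access to the public bucket on first use, and the `'train'` split name is an assumption based on the `{mode}.jsonl` path format.

```
# Print the beginnings of a few shuffled training passages (illustrative sketch).
from itertools import islice
from lm_human_preferences.datasets.books import books_generator

for passage in islice(books_generator('train', seed=0, shuffle=True), 3):
    print(passage[:80].replace('\n', ' '))
```
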
/lm_human_preferences/datasets/cnndm.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import random
4 | import re
5 |
6 | import ftfy
7 |
8 | from lm_human_preferences.utils import gcs
9 |
10 | dm_single_close_quote = u'\u2019' # unicode
11 | dm_double_close_quote = u'\u201d'
12 | END_TOKENS = ['.', '!', '?', '...', "'", "`", '"', dm_single_close_quote, dm_double_close_quote, ")"] # acceptable ways to end a sentence
13 |
14 | def read_text_file(text_file):
15 | lines = []
16 | with open(text_file, "r") as f:
17 | for line in f:
18 | lines.append(line.strip())
19 | return lines
20 |
21 | def fix_missing_period(line):
22 | """Adds a period to a line that is missing a period"""
23 | if "@highlight" in line:
24 | return line
25 | if line=="":
26 | return line
27 | if line[-1] in END_TOKENS:
28 | return line
29 | # print line[-1]
30 | return line + "."
31 |
32 | def get_art_abs(story_file):
33 | lines = read_text_file(story_file)
34 | # lines = [fix_missing_period(line) for line in lines]
35 | article_lines = []
36 | highlights = []
37 | next_is_highlight = False
38 | for line in lines:
39 | if line == "":
40 | continue # empty line
41 | elif line.startswith("@highlight"):
42 | next_is_highlight = True
43 | elif next_is_highlight:
44 | highlights.append(line)
45 | else:
46 | article_lines.append(line)
47 | article = '\n\n'.join(article_lines)
48 |
49 | # Make abstract into a single string, putting and tags around the sentences
50 | highlights = [fix_missing_period(sent) for sent in highlights]
51 | # abstract = ' '.join(["%s %s %s" % (SENTENCE_START, sent, SENTENCE_END) for sent in highlights])
52 | # abstract = ' '.join(highlights)
53 | return article, highlights
54 |
55 | def hashhex(s):
56 | """Returns a heximal formated SHA1 hash of the input string."""
57 | h = hashlib.sha1()
58 | h.update(s)
59 | return h.hexdigest()
60 |
61 | def get_path_of_url(url):
62 | if 'dailymail.co.uk' in url or 'mailonsunday.ie' in url or 'lib.store.yahoo.net' in url:
63 | site = 'dailymail'
64 | else:
65 | assert 'cnn.com' in url or 'cnn.hk' in url, url
66 | site = 'cnn'
67 | url_hash = hashhex(url.encode('utf-8'))
68 | return f'{site}/stories/{url_hash}.story'
69 |
70 | def clean_up_start(text):
71 | if text[:2] == 'By':
72 | text = '\n'.join(text.split('\n')[2:])
73 | text = re.split(r'\(CNN\) +--', text)[-1]
74 | text = re.split(r"\(CNN\)", text[:100])[-1]+text[100:]
75 | text = re.sub(r"^and \w+\n", "", text)
76 | text = re.split(r".*UPDATED:\s+[0-9]{2}:[0-9]{2}.*[2011|2012|2013|2014|2015]", text)[-1]
77 | text = text.replace('’', "'")
78 | text = text.replace('‘', "'")
79 | return text.strip()
80 |
81 | def cnndm_generator(mode, seed=0, shuffle=False, comm=None):
82 | # data originally from https://github.com/abisee/cnn-dailymail
83 | if mode == 'valid':
84 | mode = 'val'
85 | with open(gcs.download_file_cached(f'https://openaipublic.blob.core.windows.net/lm-human-preferences/datasets/cnndm/url_lists/all_{mode}.txt', comm=comm)) as f:
86 | urls = [line.strip() for line in f]
87 | if shuffle:
88 | random.seed(seed)
89 | random.shuffle(urls)
90 | # if n_eval > 0:
91 | # urls = urls[:n_eval]
92 |
93 | urls_dir = gcs.download_directory_cached(f'https://openaipublic.blob.core.windows.net/lm-human-preferences/datasets/cnndm/cache_{mode}', comm=comm)
94 |
95 | for i, url in enumerate(urls):
96 | path = os.path.join(urls_dir, get_path_of_url(url))
97 | text = open(path).read()
98 | text = clean_up_start(text)
99 | text = ftfy.fix_text(text)
100 |
101 | text = re.sub(r"\n{3,}", "\n\n", text)
102 | text = text.split('@highlight')[0].strip()
103 | yield text
104 | # _, ref_sents = get_art_abs(path)
105 |
--------------------------------------------------------------------------------
/lm_human_preferences/datasets/tldr.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import re
4 |
5 | import ftfy
6 |
7 | from lm_human_preferences.utils import gcs
8 |
9 |
10 | def tldr_generator(mode, seed=0, shuffle=False, comm=None):
11 | random.seed(seed)
12 |
13 | if mode == 'test':
14 | mode = 'valid' # validation set serves as training set, since we don't have access..
15 | assert mode in ['train', 'valid']
16 |
17 | with open(gcs.download_file_cached(f'https://openaipublic.blob.core.windows.net/lm-human-preferences/tldr/{mode}-subset.json', comm=comm)) as f:
18 | datas = json.load(f)
19 |
20 | if shuffle:
21 | random.seed(seed)
22 | random.shuffle(datas)
23 |
24 | for data in datas:
25 | text = data['content']
26 | text = ftfy.fix_text(text)
27 | text = re.sub(r"\n{3,}", "\n\n", text)
28 | text = text.strip()
29 | yield text
30 |
--------------------------------------------------------------------------------
/lm_human_preferences/label_types.py:
--------------------------------------------------------------------------------
1 | """Interface and implementations of label types for a reward model."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Optional, Dict
5 |
6 | import tensorflow as tf
7 |
8 | from lm_human_preferences.utils.core import Schema, pearson_r
9 |
10 |
11 | class LabelType(ABC):
12 | @abstractmethod
13 | def label_schemas(self) -> Dict[str, Schema]:
14 | """Schema for the human annotations."""
15 |
16 | @abstractmethod
17 | def target_scales(self, labels: Dict[str, tf.Tensor]) -> Optional[tf.Tensor]:
18 | """Extracts scalars out of labels whose scale corresponds to the reward model's output.
19 | May be none if the labels have no such information."""
20 |
21 | @abstractmethod
22 | def loss(self, reward_model, labels: Dict[str, tf.Tensor]) -> Dict[str, tf.Tensor]:
23 | """
24 | :param labels: the questions with their labels
25 | :returns: a dict of stats, including 'loss' for the actual loss
26 | """
27 |
28 | @abstractmethod
29 | def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
30 | """Schema for the questions associated with this LabelType."""
31 |
32 |
33 | class PickBest(LabelType):
34 | """Pick best response amongst N."""
35 | def __init__(self, num_responses):
36 | self.num_responses = num_responses
37 |
38 | def label_schemas(self):
39 | return dict(best=Schema(tf.int32, ()))
40 |
41 | def target_scales(self, labels):
42 | return None
43 |
44 | def loss(self, reward_model, labels):
45 | logits = tf.stack([reward_model(labels['query'], labels[f'sample{i}'])
46 | for i in range(self.num_responses)], axis=1)
47 | error = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
48 | labels=labels['best'], logits=logits))
49 | return dict(loss=error, error=error)
50 |
51 | def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
52 | return dict(
53 | query=Schema(tf.int32, (query_length,)),
54 | **{f"sample{i}": Schema(tf.int32, (response_length,)) for i in range(self.num_responses)}
55 | )
56 |
57 |
58 | class ScalarRating(LabelType):
59 | """Rate a single number with a scalar score."""
60 | def __init__(self):
61 | pass
62 |
63 | def label_schemas(self):
64 | return dict(
65 | score=Schema(tf.float32, ()))
66 |
67 | def target_scales(self, labels):
68 | return labels['score']
69 |
70 | def loss(self, reward_model, labels):
71 | predicted = reward_model(labels['query'], labels['sample'])
72 | labels = labels['score']
73 | error = tf.reduce_mean((labels - predicted) ** 2, axis=0)
74 | label_mean, label_var = tf.nn.moments(labels, axes=[0])
75 | corr = pearson_r(labels, predicted)
76 | return dict(loss=error, error=error,
77 | label_mean=label_mean, label_var=label_var, corr=corr)
78 |
79 | def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
80 | return dict(
81 | query=Schema(tf.int32, (query_length,)),
82 | sample=Schema(tf.int32, (response_length,)),
83 | )
84 |
85 |
86 | class ScalarComparison(LabelType):
87 | """Give a scalar indicating difference between two responses."""
88 | def label_schemas(self):
89 | return dict(difference=Schema(tf.float32, ()))
90 |
91 | def target_scales(self, labels):
92 | # Divide by two to get something with the same variance as the trained reward model output
93 | return labels['difference']/2
94 |
95 | def loss(self, reward_model, labels):
96 | outputs0 = reward_model(labels['query'], labels['sample0'])
97 | outputs1 = reward_model(labels['query'], labels['sample1'])
98 |
99 | differences = labels['difference']
100 | predicted_differences = outputs1 - outputs0
101 | error = tf.reduce_mean((differences - predicted_differences)**2, axis=0)
102 | return dict(loss=error, error=error)
103 |
104 | def question_schemas(self, *, query_length, response_length) -> Dict[str, Schema]:
105 | return dict(
106 | query=Schema(tf.int32, (query_length,)),
107 | sample0=Schema(tf.int32, (response_length,)),
108 | sample1=Schema(tf.int32, (response_length,)),
109 | )
110 |
111 |
112 | def get(label_type: str) -> LabelType:
113 | if label_type == 'scalar_rating':
114 | return ScalarRating()
115 | if label_type == 'scalar_compare':
116 | return ScalarComparison()
117 | if label_type.startswith('best_of_'):
118 | n = int(label_type[len('best_of_'):])
119 | return PickBest(n)
120 | raise ValueError(f"Unexpected label type {label_type}")
121 |
--------------------------------------------------------------------------------
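
The README points to this file for the question and label schemas of the released labels. A minimal sketch (not in the repo) of inspecting the `best_of_4` type used by the sentiment and descriptiveness experiments, with the query/response lengths from the books task in `launch.py`:

```
# Inspect the schemas for the 'best_of_4' label type (illustrative sketch).
from lm_human_preferences import label_types

lt = label_types.get('best_of_4')   # -> PickBest(num_responses=4)

# Each question holds the query plus four sampled responses, all as token-id vectors.
print(lt.question_schemas(query_length=64, response_length=24))

# Each human annotation is a single integer: the index of the preferred response.
print(lt.label_schemas())           # {'best': Schema(tf.int32, ())}
```
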
/lm_human_preferences/language/datasets.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Dict
3 |
4 | import tensorflow as tf
5 |
6 | from lm_human_preferences.datasets.books import books_generator
7 | from lm_human_preferences.datasets.cnndm import cnndm_generator
8 | from lm_human_preferences.datasets.tldr import tldr_generator
9 |
10 | _registry: Dict[str, "Dataset"] = {}
11 |
12 | class Dataset:
13 | def __init__(
14 | self,
15 | name,
16 | *,
17 | generator=None,
18 | ):
19 | global _registry
20 | assert name not in _registry
21 | _registry[name] = self
22 |
23 | self.name = name
24 |
25 | self.generator = generator
26 |
27 | def tf_dataset(
28 | self,
29 | sequence_length,
30 | *,
31 | mode,
32 | encoder=None,
33 | seed=0,
34 | comm=None,
35 | shuffle=True,
36 | repeat_count=None, # Defaults to infinite repeat
37 | # trims so that it starts right after start token
38 | start_token=None,
39 | # trims off last end_token
40 | end_token=None,
41 | padding_token=None,
42 | ):
43 | if padding_token is None:
44 | padding_token = encoder.padding_token
45 | def _generator():
46 | inner_gen = self.generator(mode, seed=seed, shuffle=shuffle, comm=comm)
47 | for text in inner_gen:
48 | tokens = encoder.encode(text)
49 | if start_token is not None:
50 | try:
51 | first_index = tokens.index(start_token)+1
52 | if first_index < len(tokens):
53 | tokens = tokens[first_index:]
54 | except:
55 | continue
56 |
57 | tokens = tokens[:sequence_length]
58 |
59 | if end_token is not None:
60 | try:
61 | last_index = len(tokens)-tokens[::-1].index(end_token)
62 | tokens = tokens[:last_index]
63 | except:
64 | continue
65 |
66 | if len(tokens) < sequence_length:
67 | tokens = tokens + [padding_token] * (sequence_length - len(tokens))
68 |
69 | assert len(tokens) == sequence_length
70 |
71 | yield dict(tokens=tokens)
72 |
73 | tf_dataset = tf.data.Dataset.from_generator(
74 | _generator,
75 | output_types=dict(tokens=tf.int32),
76 | output_shapes=dict(tokens=(sequence_length,)),
77 | )
78 | tf_dataset = tf_dataset.repeat(repeat_count)
79 |
80 | if comm is not None:
81 | num_shards = comm.Get_size()
82 | shard_idx = comm.Get_rank()
83 | if num_shards > 1:
84 | assert seed is not None
85 | tf_dataset = tf_dataset.shard(num_shards, shard_idx)
86 |
87 | return tf_dataset
88 |
89 |
90 | def get_dataset(name) -> Dataset:
91 | global _registry
92 | return _registry[name]
93 |
94 | CnnDm = Dataset(
95 | "cnndm",
96 | generator=cnndm_generator,
97 | )
98 |
99 | Tldr = Dataset(
100 | "tldr",
101 | generator=tldr_generator,
102 | )
103 |
104 | Books = Dataset(
105 | "books",
106 | generator=books_generator,
107 | )
108 |
109 | def test_generator(mode, seed=0, shuffle=False, comm=None):
110 | while True:
111 | yield ''.join([random.choice('abcdefghijklmnopqrstuvwxyz.') for _ in range(40)])
112 |
113 | Test = Dataset(
114 | "test",
115 | generator=test_generator
116 | )
117 |
118 |
119 | """
120 | import tensorflow as tf
121 | from lm_human_preferences.language.datasets import Books as ds
122 | from lm_human_preferences.language.encodings import Main as encoding
123 |
124 | e = encoding.get_encoder()
125 | x = ds.tf_dataset(16, mode='test', encoder=e)
126 | op = x.make_one_shot_iterator().get_next()
127 | s = tf.Session()
128 |
129 | while True:
130 | print(e.decode(s.run(op)['tokens']))
131 | input()
132 | """
133 |
--------------------------------------------------------------------------------
/lm_human_preferences/language/encodings.py:
--------------------------------------------------------------------------------
1 | """Byte pair encoding utilities"""
2 |
3 | import json
4 | import os
5 | from functools import lru_cache
6 |
7 | import tensorflow as tf
8 | import regex as re
9 |
10 | @lru_cache()
11 | def bytes_to_unicode():
12 | """
13 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
14 | The reversible bpe codes work on unicode strings.
15 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
16 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
17 | This is a significant percentage of your normal, say, 32K bpe vocab.
18 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings,
19 | and we avoid mapping to whitespace/control characters that the bpe code barfs on.
20 | """
21 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
22 | cs = bs[:]
23 | n = 0
24 | for b in range(2 ** 8):
25 | if b not in bs:
26 | bs.append(b)
27 | cs.append(2 ** 8 + n)
28 | n += 1
29 | cs = [chr(n) for n in cs]
30 | return dict(zip(bs, cs))
31 |
32 |
33 | def get_pairs(word):
34 | """Return set of symbol pairs in a word.
35 |
36 | Word is represented as tuple of symbols (symbols being variable-length strings).
37 | """
38 | pairs = set()
39 | prev_char = word[0]
40 | for char in word[1:]:
41 | pairs.add((prev_char, char))
42 | prev_char = char
43 | return pairs
44 |
45 |
46 | class ReversibleEncoder:
47 | def __init__(self, encoder, bpe_merges, errors="replace", eot_token=None):
48 | self.encoder = encoder
49 | self.decoder = {v: k for k, v in self.encoder.items()}
50 | self.errors = errors # how to handle errors in decoding
51 | self.byte_encoder = bytes_to_unicode()
52 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
53 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
54 | self.eot_token = eot_token
55 | self.cache = {}
56 | self.padding_token = len(encoder) + 2 # +2 unnecessary, for historical reasons
57 | self.decoder[self.padding_token] = ''
58 |
59 | # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
60 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
61 |
62 | def bpe(self, token):
63 | if token in self.cache:
64 | return self.cache[token]
65 | word = tuple(token)
66 | pairs = get_pairs(word)
67 |
68 | if not pairs:
69 | return token
70 |
71 | while True:
72 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
73 | if bigram not in self.bpe_ranks:
74 | break
75 | first, second = bigram
76 | new_word = []
77 | i = 0
78 | while i < len(word):
79 | try:
80 | j = word.index(first, i)
81 | new_word.extend(word[i:j])
82 | i = j
83 | except:
84 | new_word.extend(word[i:])
85 | break
86 |
87 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
88 | new_word.append(first + second)
89 | i += 2
90 | else:
91 | new_word.append(word[i])
92 | i += 1
93 | new_word = tuple(new_word)
94 | word = new_word
95 | if len(word) == 1:
96 | break
97 | else:
98 | pairs = get_pairs(word)
99 | word = " ".join(word)
100 | self.cache[token] = word
101 | return word
102 |
103 | def encode(self, text):
104 | bpe_tokens = []
105 | for token in re.findall(self.pat, text):
106 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
107 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
108 | return bpe_tokens
109 |
110 | def decode(self, tokens, pretty=False):
111 | del pretty
112 | text = "".join([self.decoder[token] for token in tokens])
113 | text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
114 | return text
115 |
116 |
117 | def read_file(path):
118 | with tf.gfile.Open(path, "rb") as fh:
119 | return fh.read()
120 |
121 |
122 | class Encoding:
123 | def __init__(
124 | self,
125 | name,
126 | *,
127 | n_vocab=0,
128 | eot_token=None,
129 | encoder_path="encoder.json",
130 | bpe_path="vocab.bpe",
131 | base_path=None,
132 | ):
133 | self.name = name
134 | self.eot_token = eot_token
135 | self.n_vocab = n_vocab
136 |
137 | if base_path is None:
138 | base_path = os.path.join("gs://gpt-2/encodings", name)
139 |
140 | self.base_path = base_path
141 | if name != "test":
142 | self.encoder_path = os.path.join(self.base_path, encoder_path)
143 | self.bpe_path = os.path.join(self.base_path, bpe_path)
144 |
145 | def get_encoder(self):
146 | if self.name == "test":
147 | vocab = "abcdefghijklmnopqrstuvwxyz."
148 | assert len(vocab) == self.n_vocab
149 |
150 | class TestEncoder(ReversibleEncoder):
151 | def __init__(self):
152 | super().__init__(encoder={w: i for i, w in enumerate(vocab)}, bpe_merges=list())
153 | self.padding_token = len(vocab)
154 | def encode(self, text):
155 | return [self.encoder.get(x, len(vocab) - 1) for x in text]
156 | def decode(self, tokens, pretty=False):
157 | return ''.join([self.decoder.get(t, '') for t in tokens])
158 |
159 | return TestEncoder()
160 |
161 | encoder_dict = json.loads(read_file(self.encoder_path).decode())
162 | bpe_data = read_file(self.bpe_path).decode()
163 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
164 | assert len(encoder_dict) == self.n_vocab
165 | encoder = ReversibleEncoder(encoder=encoder_dict, bpe_merges=bpe_merges, eot_token=self.eot_token)
166 | assert encoder.padding_token >= self.n_vocab
167 | return encoder
168 |
169 |
170 | Main = Encoding("main", n_vocab=50257, eot_token=50256)
171 |
172 | Test = Encoding("test", n_vocab=27, eot_token=26)
173 |
--------------------------------------------------------------------------------
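
A round-trip sketch for the encoders defined above (not repo code). The `main` encoding fetches `encoder.json` / `vocab.bpe` from `gs://gpt-2/encodings/main` via `tf.gfile`, so it needs access to that bucket; the `test` encoding works offline:

```
# Encode and decode with the toy test encoder (illustrative sketch).
from lm_human_preferences.language.encodings import Main, Test

enc = Test.get_encoder()             # character-level vocab 'a'..'z' plus '.'
tokens = enc.encode('hello world.')  # unknown chars (the space) map to the last vocab entry '.'
print(tokens)
print(enc.decode(tokens))            # 'hello.world.'

# The real GPT-2 BPE encoder (requires the encodings bucket above):
# enc = Main.get_encoder()
# assert enc.decode(enc.encode('Fine-tuning language models')) == 'Fine-tuning language models'
```
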
/lm_human_preferences/language/model.py:
--------------------------------------------------------------------------------
1 | """Alec's transformer model."""
2 |
3 | from functools import partial
4 | from typing import Optional
5 | from dataclasses import dataclass
6 |
7 | import tensorflow as tf
8 | import numpy as np
9 | from tensorflow.python.framework import function
10 |
11 | from lm_human_preferences.utils import core as utils
12 | from lm_human_preferences.utils import hyperparams
13 |
14 | @dataclass
15 | class HParams(hyperparams.HParams):
16 | # Encoding (set during loading process)
17 | n_vocab: int = 0
18 |
19 | # Model parameters
20 | n_ctx: int = 512
21 | n_embd: int = 768
22 | n_head: int = 12
23 | n_layer: int = 12
24 |
25 | embd_pdrop: float = 0.1
26 | attn_pdrop: float = 0.1
27 | resid_pdrop: float = 0.1
28 | head_pdrop: float = 0.1
29 |
30 |
31 | def parse_comma_separated_int_list(s):
32 | return [int(i) for i in s.split(",")] if s else []
33 |
34 |
35 | def gelu(x):
36 | with tf.name_scope('gelu'):
37 | return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3))))
38 |
39 |
40 | def dropout(x, pdrop, *, do_dropout, stateless=True, seed=None, name):
41 | """Like tf.nn.dropout but stateless.
42 | """
43 | if stateless:
44 | assert seed is not None
45 | def _dropout():
46 | with tf.name_scope(name):
47 | noise_shape = tf.shape(x)
48 |
49 | if stateless:
50 | r = tf.random.stateless_uniform(noise_shape, seed, dtype=x.dtype)
51 | # floor uniform [keep_prob, 1.0 + keep_prob)
52 | mask = tf.floor(1 - pdrop + r)
53 | return x * (mask * (1 / (1 - pdrop)))
54 | else:
55 | return tf.nn.dropout(x, rate=pdrop, noise_shape=noise_shape)
56 | if pdrop == 0 or not do_dropout:
57 | return x
58 | else:
59 | return _dropout()
60 |
61 |
62 | def norm(x, scope, *, axis=-1, epsilon=1e-5):
63 | """Normalize to mean = 0, std = 1, then do a diagonal affine transform."""
64 | with tf.variable_scope(scope):
65 | n_state = x.shape[-1].value
66 | g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1))
67 | s = tf.reduce_mean(tf.square(x), axis=axis, keepdims=True)
68 | b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0))
69 | u = tf.reduce_mean(x, axis=axis, keepdims=True)
70 | s = s - tf.square(u)
71 | x = (x - u) * tf.rsqrt(s + epsilon)
72 | x = x*g + b
73 | return x
74 |
75 |
76 | def split_states(x, n):
77 | """Reshape the last dimension of x into [n, x.shape[-1]/n]."""
78 | *start, m = utils.shape_list(x)
79 | return tf.reshape(x, start + [n, m//n])
80 |
81 |
82 | def merge_states(x):
83 | """Smash the last two dimensions of x into a single dimension."""
84 | *start, a, b = utils.shape_list(x)
85 | return tf.reshape(x, start + [a*b])
86 |
87 |
88 | def conv1x1(x, scope, nf, *, w_init_stdev=0.02):
89 | with tf.variable_scope(scope):
90 | *start, nx = utils.shape_list(x)
91 |
92 | # Don't cast params until just prior to use -- saves a lot of memory for large models
93 | with tf.control_dependencies([x]):
94 | w = tf.squeeze(tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev)), axis=0)
95 | b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0))
96 | c = tf.matmul(tf.reshape(x, [-1, nx]), w) + b
97 | c = tf.reshape(c, start+[nf])
98 | return c
99 |
100 |
101 | def attention_mask(nd, ns, *, dtype):
102 | """1's in the lower triangle, counting from the lower right corner.
103 |
104 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs.
105 | """
106 | i = tf.range(nd)[:,None]
107 | j = tf.range(ns)
108 | m = i >= j - ns + nd
109 | # to ignore first parts of context (useful for sampling with static shapes)
110 | # m = tf.math.logical_and(m, tf.math.logical_or(j >= ignore, i < ignore - ns + nd))
111 | return tf.cast(m, dtype)
112 |
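A pure-numpy illustration of the mask above for nd=2 new tokens attending over ns=4 total positions (sketch only; the TF version is what the model uses):

    import numpy as np

    nd, ns = 2, 4
    i = np.arange(nd)[:, None]
    j = np.arange(ns)
    print((i >= j - ns + nd).astype(int))
    # [[1 1 1 0]
    #  [1 1 1 1]]   each new token attends to all earlier positions plus itself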
113 |
114 | def softmax(x, axis=-1):
115 | x = x - tf.reduce_max(x, axis=axis, keepdims=True)
116 | ex = tf.exp(x)
117 | return ex / tf.reduce_sum(ex, axis=axis, keepdims=True)
118 |
119 |
120 | def attn(x, scope, n_state, *, past, mask, do_dropout, scale=False, hparams, seed):
121 | assert x.shape.ndims == 3 # Should be [batch, sequence, features]
122 | if past is not None:
123 | assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v]
124 |
125 | def split_heads(x):
126 | # From [batch, sequence, features] to [batch, heads, sequence, features]
127 | return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3])
128 |
129 | def merge_heads(x):
130 | # Reverse of split_heads
131 | return merge_states(tf.transpose(x, [0, 2, 1, 3]))
132 |
133 | def mask_attn_weights(w):
134 | # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst.
135 | bs, _, nd, ns = utils.shape_list(w)
136 | b = attention_mask(nd, ns, dtype=w.dtype)
137 | b = tf.reshape(b, [1, 1, nd, ns])
138 | if mask is not None:
139 | b *= tf.reshape(tf.cast(mask, w.dtype), [bs, 1, 1, ns])
140 | w = w*b - tf.cast(1e10, w.dtype)*(1-b)
141 | return w
142 |
143 | def multihead_attn(q, k, v, *, seed):
144 | orig_dtype = v.dtype
145 | q, k, v = map(partial(tf.cast, dtype=tf.float32), (q, k, v))
146 | # q, k, v have shape [batch, heads, sequence, features]
147 | w = tf.matmul(q, k, transpose_b=True)
148 |
149 | if scale:
150 | n_state = v.shape[-1].value
151 | w = w * tf.rsqrt(tf.cast(n_state, w.dtype))
152 |
153 | w = mask_attn_weights(w)
154 | w = softmax(w)
155 | w = dropout(w, hparams.attn_pdrop,
156 | do_dropout=do_dropout, name='attn_drop', stateless=True, seed=seed)
157 | a = tf.matmul(w, v)
158 | a = tf.cast(a, dtype=orig_dtype, name='a_cast')
159 | return a
160 |
161 | with tf.variable_scope(scope):
162 | attn_seed, resid_seed = split_seed(seed, 2)
163 |
164 | assert n_state % hparams.n_head == 0
165 | w_init_stdev = 1/np.sqrt(n_state)
166 | c = conv1x1(x, 'c_attn', n_state * 3, w_init_stdev=w_init_stdev)
167 | q, k, v = map(split_heads, tf.split(c, 3, axis=2))
168 | present = tf.stack([k, v], axis=1)
169 | if past is not None:
170 | pk, pv = tf.unstack(past, axis=1)
171 | k = tf.concat([pk, k], axis=-2)
172 | v = tf.concat([pv, v], axis=-2)
173 | a = multihead_attn(q, k, v, seed=attn_seed)
174 | a = merge_heads(a)
175 | w_init_stdev = 1/np.sqrt(n_state*hparams.n_layer)
176 | a = conv1x1(a, 'c_proj', n_state, w_init_stdev=w_init_stdev)
177 | a = dropout(a, hparams.resid_pdrop, do_dropout=do_dropout, stateless=True, seed=resid_seed, name='attn_resid_drop')
178 | return a, present
179 |
180 |
181 | def mlp(x, scope, n_hidden, *, do_dropout, hparams, seed):
182 | with tf.variable_scope(scope):
183 | nx = x.shape[-1].value
184 | w_init_stdev = 1/np.sqrt(nx)
185 | h = gelu(
186 | conv1x1(x, 'c_fc', n_hidden, w_init_stdev=w_init_stdev))
187 | w_init_stdev = 1/np.sqrt(n_hidden*hparams.n_layer)
188 | h2 = conv1x1(h, 'c_proj', nx, w_init_stdev=w_init_stdev)
189 | h2 = dropout(h2, hparams.resid_pdrop, do_dropout=do_dropout, stateless=True, seed=seed, name='mlp_drop')
190 | return h2
191 |
192 |
193 | def block(x, scope, *, past, mask, do_dropout, scale=False, hparams, seed):
194 | with tf.variable_scope(scope):
195 | attn_seed, mlp_seed = split_seed(seed, 2)
196 |
197 | nx = x.shape[-1].value
198 | a, present = attn(
199 | norm(x, 'ln_1'),
200 | 'attn', nx, past=past, mask=mask, do_dropout=do_dropout, scale=scale, hparams=hparams, seed=attn_seed)
201 | x = x + a
202 |
203 | m = mlp(
204 | norm(x, 'ln_2'),
205 | 'mlp', nx*4, do_dropout=do_dropout, hparams=hparams, seed=mlp_seed)
206 | h = x + m
207 | return h, present
208 |
209 |
210 | @function.Defun(
211 | python_grad_func=lambda x, dy: tf.convert_to_tensor(dy),
212 | shape_func=lambda op: [op.inputs[0].get_shape()])
213 | def convert_gradient_to_tensor(x):
214 | """Force gradient to be a dense tensor.
215 |
216 | It's often faster to do dense embedding gradient on GPU than sparse on CPU.
217 | """
218 | return x
219 |
220 |
221 | def embed(X, we):
222 | """Embedding lookup.
223 |
224 | X has shape [batch, sequence, info]. Currently info = 2 corresponding to [token_id, position].
225 | """
226 | we = convert_gradient_to_tensor(we)
227 | e = tf.gather(we, X)
228 | return e
229 |
230 |
231 | #tensor contraction of the final axes of x with the first axes of y
232 | #need to write it ourselves because tensorflow's tensordot is slow
233 | def tensordot(x, y, num_axes):
234 | split_x_axes_at = x.shape.ndims - num_axes
235 | x_shape = tf.shape(x)[:split_x_axes_at]
236 | y_shape = tf.shape(y)[num_axes:]
237 | rx = tf.reshape(x, [tf.reduce_prod(x_shape), tf.reduce_prod(tf.shape(x)[split_x_axes_at:])])
238 | ry = tf.reshape(y, [-1, tf.reduce_prod(y_shape)])
239 | rresult = tf.matmul(rx, ry)
240 | result = tf.reshape(rresult, tf.concat([x_shape, y_shape], axis=0))
241 | result.set_shape(x.shape[:split_x_axes_at].concatenate(y.shape[num_axes:]))
242 | return result
243 |
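A quick shape-level check of what the custom tensordot above computes, using numpy's tensordot as the reference (the shapes are arbitrary examples):

    import numpy as np

    x = np.random.randn(2, 3, 4)
    y = np.random.randn(4, 5)
    out = np.tensordot(x, y, axes=1)  # contract the last axis of x with the first axis of y
    print(out.shape)                  # (2, 3, 5), matching the TF helper's output shape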
244 |
245 | #more convenient fc layer that avoids manual reshape bookkeeping
246 | #consumes in_axes of x
247 | #produces y of shape outshape
248 | def fc_layer(x, outshape, *, in_axes=1, scale=None):
249 | inshape = tuple([int(d) for d in x.shape[-in_axes:]]) if in_axes>0 else ()
250 | outshape = tuple(outshape)
251 | if scale is None:
252 | scale = 1 / np.sqrt(np.prod(inshape) + 1)
253 | w = tf.get_variable('w', inshape + outshape, initializer=tf.random_normal_initializer(stddev=scale))
254 | b = tf.get_variable('b', outshape, initializer=tf.constant_initializer(0))
255 | # Call the regularizer manually so that it works correctly with GradientTape
256 | regularizer = tf.contrib.layers.l2_regularizer(scale=1/np.prod(outshape)) #so that initial value of regularizer is 1
257 | reg_loss = regularizer(w)
258 | return tensordot(x, w, in_axes) + b, reg_loss
259 |
260 |
261 | def past_shape(*, hparams, batch_size=None, sequence=None):
262 | return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, utils.exact_div(hparams.n_embd, hparams.n_head)]
263 |
264 |
265 | def positions_for(*, batch, sequence, past_length, mask):
266 | if mask is None:
267 | return utils.expand_tile(past_length + tf.range(sequence), batch, axis=0)
268 | else:
269 | return tf.cumsum(tf.cast(mask, tf.int32), exclusive=True, axis=-1)[:, past_length:]
270 |
271 |
272 | def split_seed(seed, n=2):
273 | if n == 0:
274 | return []
275 | return tf.split(
276 | tf.random.stateless_uniform(dtype=tf.int64, shape=[2*n], minval=-2**63, maxval=2**63-1, seed=seed),
277 | n, name='split_seeds')
278 |
279 |
280 | class Model:
281 | def __init__(self, hparams: HParams, scalar_heads=[], scope=None):
282 | self.hparams = hparams
283 | self.scalar_heads = scalar_heads
284 | with tf.variable_scope(scope, 'model') as scope:
285 | self.scope = scope
286 | self.built = False
287 |
288 | def __call__(self, *, X, Y=None, past=None, past_tokens=None, mask=None,
289 | padding_token: Optional[int]=None, do_dropout=False):
290 | X = tf.convert_to_tensor(X, dtype=tf.int32)
291 | if mask is not None:
292 | mask = tf.convert_to_tensor(mask, dtype=tf.bool)
293 | assert mask.dtype == tf.bool
294 | if padding_token is not None:
295 | assert mask is None, 'At most one of mask and padding_token should be set'
296 | mask = tf.not_equal(X, padding_token)
297 | X = tf.where(mask, X, tf.zeros_like(X))
298 | if past is not None:
299 | assert past_tokens is not None, 'padding_token requires past_tokens'
300 | mask = tf.concat([tf.not_equal(past_tokens, padding_token), mask], axis=1)
301 | with tf.variable_scope(self.scope, reuse=self.built, auxiliary_name_scope=not self.built):
302 | self.built = True
303 | results = {}
304 | batch, sequence = utils.shape_list(X)
305 |
306 | seed = tf.random.uniform(dtype=tf.int64, shape=[2], minval=-2**63, maxval=2**63-1)
307 | wpe_seed, wte_seed, blocks_seed, heads_seed = split_seed(seed, 4)
308 |
309 | wpe = tf.get_variable('wpe', [self.hparams.n_ctx, self.hparams.n_embd],
310 | initializer=tf.random_normal_initializer(stddev=0.01))
311 | wte = tf.get_variable('wte', [self.hparams.n_vocab, self.hparams.n_embd],
312 | initializer=tf.random_normal_initializer(stddev=0.02))
313 | wpe = dropout(wpe, self.hparams.embd_pdrop,
314 | do_dropout=do_dropout, stateless=True, seed=wpe_seed, name='wpe_drop')
315 | wte = dropout(wte, self.hparams.embd_pdrop,
316 | do_dropout=do_dropout, stateless=True, seed=wte_seed, name='wte_drop')
317 |
318 | past_length = 0 if past is None else tf.shape(past)[-2]
319 |
320 | positions = positions_for(batch=batch, sequence=sequence, past_length=past_length, mask=mask)
321 | h = embed(X, wte) + embed(positions, wpe)
322 | # Transformer
323 | presents = []
324 | pasts = tf.unstack(past, axis=1) if past is not None else [None] * self.hparams.n_layer
325 | assert len(pasts) == self.hparams.n_layer
326 | block_seeds = split_seed(blocks_seed, self.hparams.n_layer)
327 | for layer, (past, block_seed) in enumerate(zip(pasts, block_seeds)):
328 | h, present = block(
329 | h, 'h%d' % layer, past=past, mask=mask, do_dropout=do_dropout, scale=True,
330 | hparams=self.hparams, seed=block_seed)
331 | presents.append(present)
332 | results['present'] = tf.stack(presents, axis=1)
333 | h = norm(h, 'ln_f')
334 | if mask is not None:
335 | # For non-present tokens, use the output from the last present token instead.
336 | present_indices = utils.where(mask[:,past_length:], tf.tile(tf.range(sequence)[None,:], [batch, 1]), -1)
337 | use_indices = utils.cumulative_max(present_indices)
338 | # assert since GPUs don't
339 | with tf.control_dependencies([tf.assert_none_equal(use_indices, -1)]):
340 | h = utils.index_each(h, use_indices)
341 | results['h'] = h
342 |
343 |             # Language model loss. Do tokens <n predict token n?
379 | return params
380 |
--------------------------------------------------------------------------------
/lm_human_preferences/language/sample.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from lm_human_preferences.language import model
4 | from lm_human_preferences.utils import core as utils
5 |
6 |
7 | def sample_sequence(*, step, model_hparams, length, batch_size=None, context=None,
8 | temperature=1, top_k=0, top_p=1.0, extra_outputs={}, cond=None):
9 | """
10 | Sampling from an autoregressive sequence model.
11 |
12 | Inputs:
13 | step: A function which takes model hparams, a tokens Tensor, past, and
14 | returns a dictionary with 'logits' and 'presents', and any extra vars.
15 | context: Includes start tokens.
16 | extra_outputs: Map from extra output key to dtype
17 | Returns:
18 |     A dict with keys 'tokens', 'presents', 'logprobs', and any keys in extra_outputs
19 | """
20 |
21 | with tf.name_scope('sample_seq'):
22 | batch_size, *_ = utils.shape_list(context)
23 |
24 | beta = 1 / tf.maximum(tf.cast(temperature, tf.float32), 1e-10)
25 |
26 | context_output = step(model_hparams, context)
27 | logits = tf.cast(context_output['logits'][:,-1], tf.float32)
28 |
29 | first_output_logits = tf.cast(beta, logits.dtype) * logits
30 | first_outputs = utils.sample_from_logits(first_output_logits)
31 | first_logprobs = utils.logprobs_from_logits(logits=first_output_logits, labels=first_outputs)
32 |
33 | def body(past, prev, output, logprobs, *extras):
34 | next_outputs = step(model_hparams, prev[:, tf.newaxis], past=past,
35 | past_tokens=output[:, :-1])
36 | logits = tf.cast(next_outputs['logits'], tf.float32) * beta
37 | if top_k != 0:
38 | logits = tf.cond(tf.equal(top_k, 0),
39 | lambda: logits,
40 | lambda: utils.take_top_k_logits(logits, top_k))
41 | if top_p != 1.0:
42 | logits = utils.take_top_p_logits(logits, top_p)
43 | next_sample = utils.sample_from_logits(logits, dtype=tf.int32)
44 |
45 | next_logprob = utils.logprobs_from_logits(logits=logits, labels=next_sample)
46 | return [
47 | tf.concat([past, next_outputs['presents']], axis=-2),
48 | tf.squeeze(next_sample, axis=[1]),
49 | tf.concat([output, next_sample], axis=1),
50 | tf.concat([logprobs, next_logprob], axis=1),
51 | *[tf.concat([prev, next_outputs[k]], axis=1) for k, prev in zip(extra_outputs, extras)],
52 | ]
53 |
54 | try:
55 | shape_batch_size = int(batch_size)
56 | except TypeError:
57 | shape_batch_size = None
58 | if cond is None:
59 | def always_true(*args):
60 | return True
61 | cond = always_true
62 | presents, _, tokens, logprobs, *extras = tf.while_loop(
63 | body=body,
64 | cond=cond,
65 | loop_vars=[
66 | context_output['presents'], # past
67 | first_outputs, # prev
68 | tf.concat([context, first_outputs[:, tf.newaxis]], axis=1), # output
69 | first_logprobs[:, tf.newaxis], #logprobs
70 | *[context_output[k][:, -1:] for k in extra_outputs] # extras
71 | ],
72 | shape_invariants=[
73 | tf.TensorShape(model.past_shape(hparams=model_hparams, batch_size=shape_batch_size)),
74 | tf.TensorShape([shape_batch_size]),
75 | tf.TensorShape([shape_batch_size, None]),
76 | tf.TensorShape([shape_batch_size, None]),
77 | *[tf.TensorShape([shape_batch_size, None]) for _ in extra_outputs]
78 | ],
79 | maximum_iterations=length-1,
80 | back_prop=False,
81 | parallel_iterations=2,
82 | )
83 |
84 | return dict(tokens=tokens, presents=presents, logprobs=logprobs, **dict(zip(extra_outputs, extras)))
85 |
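utils.take_top_k_logits and utils.take_top_p_logits live in utils/core.py and are not shown here; the sketch below is a plain-numpy guess at the standard top-k filtering they implement, based only on the names and call sites above:

    import numpy as np

    def take_top_k_logits_np(logits, k):
        # Keep the k largest logits per row and push the rest to -inf.
        kth = np.sort(logits, axis=-1)[..., -k][..., None]
        return np.where(logits < kth, -np.inf, logits)

    print(take_top_k_logits_np(np.array([[1.0, 3.0, 2.0, 0.5]]), k=2))
    # [[-inf   3.   2. -inf]]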
--------------------------------------------------------------------------------
/lm_human_preferences/language/test_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Transformer model tests."""
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from lm_human_preferences.utils import core as utils
8 | from lm_human_preferences.language import model
9 |
10 | def test_incremental():
11 | hparams = model.HParams()
12 | hparams.override_from_dict(dict(
13 | n_vocab=10,
14 | n_ctx=5,
15 | n_embd=9,
16 | n_head=3,
17 | n_layer=2,
18 | ))
19 | batch_size = 2
20 | steps = 5
21 | np.random.seed(7)
22 | tf.set_random_seed(7)
23 |
24 | # Transformer model
25 | m = model.Model(hparams=hparams)
26 | X = tf.placeholder(shape=[batch_size, None], dtype=tf.int32)
27 | logits = m(X=X)['lm_logits']
28 | past_p = tf.placeholder(shape=model.past_shape(hparams=hparams, batch_size=batch_size), dtype=tf.float32)
29 | # Test reusing it in a different variable scope
30 | with tf.variable_scope('other_scope'):
31 | past_lm = m(X=X[:,-1:], past=past_p)
32 | past_logits = past_lm['lm_logits']
33 | future = tf.concat([past_p, past_lm['present']], axis=-2)
34 |
35 | # Data
36 | ids = np.random.randint(hparams.n_vocab, size=[batch_size, steps]).astype(np.int32)
37 | past = np.zeros(model.past_shape(hparams=hparams, batch_size=batch_size, sequence=0), dtype=np.float32)
38 |
39 | # Evaluate
40 | with tf.Session() as sess:
41 | tf.global_variables_initializer().run()
42 | for step in range(steps):
43 | logits_v, past_logits_v, past = sess.run([logits, past_logits, future],
44 | feed_dict={X: ids[:,:step+1], past_p: past})
45 | assert np.allclose(logits_v[:,-1:], past_logits_v, atol=1e-3, rtol=1e-3)
46 |
47 |
48 | def test_mask():
49 | np.random.seed(7)
50 | tf.set_random_seed(7)
51 |
52 | # Make a transformer
53 | hparams = model.HParams()
54 | hparams.override_from_dict(dict(
55 | n_vocab=10,
56 | n_ctx=8,
57 | n_embd=3,
58 | n_head=3,
59 | n_layer=2,
60 | ))
61 | batch_size = 4  # 64
62 | policy = model.Model(hparams=hparams)
63 |
64 | # Random pasts and tokens
65 | past_length = 4
66 | length = 3
67 | past = np.random.randn(*model.past_shape(
68 | hparams=hparams, batch_size=batch_size, sequence=past_length)).astype(np.float32)
69 | X = np.random.randint(hparams.n_vocab, size=[batch_size, length])
70 |
71 | # Run model without gaps
72 | logits = policy(past=past, X=X)['lm_logits']
73 |
74 | # Run the same thing, but with gaps randomly inserted
75 | gap_past_length = 7
76 | gap_length = 5
77 | def random_subsequence(*, n, size):
78 | # Always make the first token be present, since the model tries to fill gaps with the previous states
79 | sub = [
80 | np.concatenate(([0], np.random.choice(np.arange(1,n), size=size-1, replace=False)))
81 | for _ in range(batch_size)
82 | ]
83 | return np.sort(sub, axis=-1)
84 | past_sub = random_subsequence(n=gap_past_length, size=past_length)
85 | X_sub = random_subsequence(n=gap_length, size=length)
86 | past_gap = np.random.randn(*model.past_shape(
87 | hparams=hparams, batch_size=batch_size, sequence=gap_past_length)).astype(np.float32)
88 | X_gap = np.random.randint(hparams.n_vocab, size=[batch_size, gap_length])
89 | mask = np.zeros([batch_size, gap_past_length + gap_length], dtype=np.bool)
90 | for b in range(batch_size):
91 | for i in range(past_length):
92 | past_gap[b,:,:,:,past_sub[b,i]] = past[b,:,:,:,i]
93 | for i in range(length):
94 | X_gap[b,X_sub[b,i]] = X[b,i]
95 | mask[b, past_sub[b]] = mask[b, gap_past_length + X_sub[b]] = 1
96 | gap_logits = policy(past=past_gap, X=X_gap, mask=mask)['lm_logits']
97 | sub_logits = utils.index_each(gap_logits, X_sub)
98 |
99 | # Compare
100 | with tf.Session() as sess:
101 | tf.global_variables_initializer().run()
102 | logits, sub_logits = sess.run([logits, sub_logits])
103 | assert logits.shape == sub_logits.shape
104 | assert np.allclose(logits, sub_logits, atol=1e-5)
105 |
106 |
107 | def test_attention_mask():
108 | with tf.Session() as sess:
109 | for nd in 1, 2, 3:
110 | for ns in range(nd, 4):
111 | ours = model.attention_mask(nd, ns, dtype=tf.int32)
112 | theirs = tf.matrix_band_part(tf.ones([nd, ns], dtype=tf.int32), tf.cast(-1, tf.int32), ns-nd)
113 | ours, theirs = sess.run([ours, theirs])
114 | print(ours)
115 | print(theirs)
116 | assert np.all(ours == theirs)
117 |
118 |
119 | if __name__ == '__main__':
120 | test_mask()
121 | test_attention_mask()
122 | test_incremental()
123 |
--------------------------------------------------------------------------------
/lm_human_preferences/language/test_sample.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Test sample_sequence()."""
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 | from tensorflow.contrib.training import HParams
7 |
8 | from lm_human_preferences.language import sample
9 |
10 | n_vocab = 10
11 | batch_size = 2
12 | hparams = HParams(
13 | n_layer=0,
14 | n_head=1,
15 | n_embd=0,
16 | n_attn=0,
17 | )
18 |
19 | # A step function for a policy that deterministically chooses previous token + 1.
20 | def step(hparams, tokens, past=None, past_tokens=None):
21 | logits = tf.one_hot(tokens + 1, n_vocab, on_value=0., off_value=-np.inf, dtype=tf.float32)
22 | ret = {
23 | 'logits': logits,
24 | 'presents': tf.zeros(shape=[2, 0, 2, 1, 0, 0]),
25 | }
26 | return ret
27 |
28 | def test_sample_sequence():
29 | output = sample.sample_sequence(step=step, model_hparams=hparams, length=4, batch_size=batch_size,
30 | context=tf.constant([[5, 0], [4, 3]]))
31 | expected = np.array([[5, 0, 1, 2, 3, 4], [4, 3, 4, 5, 6, 7]])
32 |
33 | with tf.Session() as sess:
34 | np.testing.assert_array_equal(sess.run(output)['tokens'], expected)
35 |
36 |
37 | if __name__ == '__main__':
38 | test_sample_sequence()
39 |
--------------------------------------------------------------------------------
/lm_human_preferences/language/trained_models.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 |
4 | import tensorflow as tf
5 |
6 | from lm_human_preferences.language import encodings, model
7 |
8 |
9 | class TrainedModel():
10 | def __init__(self, name, *, savedir=None, scope=None):
11 | self.name = name
12 | self.scope = scope
13 | self.savedir = savedir if savedir else os.path.join('gs://gpt-2/models/', name)
14 | if name == 'test':
15 | self.encoding = encodings.Test
16 | else:
17 | self.encoding = encodings.Main
18 | self._hparams = None
19 |
20 | def checkpoint(self):
21 | if self.name == 'test':
22 | return None
23 | ckpt = tf.train.latest_checkpoint(self.savedir)
24 | if ckpt is not None:
25 | return ckpt
26 | return tf.train.latest_checkpoint(os.path.join(self.savedir, 'checkpoints'))
27 |
28 | def hparams(self):
29 | if self._hparams is None:
30 | if self.name == 'test':
31 | hparams = test_hparams()
32 | else:
33 | hparams = load_hparams(
34 | os.path.join(self.savedir, 'hparams.json')
35 | )
36 | self._hparams = hparams
37 | return copy.deepcopy(self._hparams)
38 |
39 | def init_op(self, params, new_scope):
40 | assert params
41 | params = dict(**params)
42 | checkpoint = self.checkpoint()
43 | available = tf.train.list_variables(checkpoint)
44 | unchanged = {}
45 |
46 | for name, shape in available:
47 | our_name = name
48 | if self.scope:
49 | if name.startswith(self.scope):
50 | our_name = name[len(self.scope):].lstrip('/')
51 | else:
52 | continue
53 | # Annoying hack since some code uses 'scope/model' as the scope and other code uses just 'scope'
54 | our_name = '%s/%s' % (new_scope, our_name)
55 | if our_name not in params:
56 | # NOTE: this happens for global_step and optimizer variables
57 | # (e.g. beta1_power, beta2_power, blah/Adam, blah/Adam_1)
58 | # print(f'{name} is missing for scope {new_scope}')
59 | continue
60 | var = params[our_name]
61 | del params[our_name]
62 | assert var.shape == shape, 'Shape mismatch: %s.shape = %s != %s' % (var.op.name, var.shape, shape)
63 | unchanged[name] = var
64 | for name in params.keys():
65 | print(f'Param {name} is missing from checkpoint {checkpoint}')
66 | tf.train.init_from_checkpoint(checkpoint, unchanged)
67 |
68 | def load_hparams(file):
69 | hparams = model.HParams()
70 | hparams.override_from_json_file(file)
71 | return hparams
72 |
73 | def test_hparams():
74 | hparams = model.HParams()
75 | hparams.override_from_dict(dict(
76 | n_vocab=27, # Corresponds to random encoding length
77 | n_ctx=8,
78 | n_layer=2,
79 | n_embd=7,
80 | n_head=1,
81 | ))
82 | return hparams
83 |
--------------------------------------------------------------------------------
/lm_human_preferences/lm_tasks.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 | import tensorflow as tf
5 |
6 | from lm_human_preferences.language import datasets
7 | from lm_human_preferences.utils import core as utils
8 | from lm_human_preferences.utils import hyperparams
9 |
10 |
11 | @dataclass
12 | class PolicyHParams(hyperparams.HParams):
13 | temperature: float = 1.0
14 | initial_model: str = None
15 |
16 | @dataclass
17 | class TaskHParams(hyperparams.HParams):
18 | # Query params
19 | query_length: int = None
20 | query_dataset: str = None
21 | query_prefix: str = ''
22 | query_suffix: str = ''
23 | start_text: Optional[str] = '.'
24 | end_text: Optional[str] = None
25 |
26 | # Response params
27 | response_length: int = None
28 |
29 |     # Truncate the response after the first occurrence of this token at or after index truncate_after when sampling.
30 | truncate_token: Optional[int] = None
31 | truncate_after: int = 0
32 | penalty_reward_value: int = -1
33 |
34 | policy: PolicyHParams = field(default_factory=PolicyHParams)
35 |
36 | #returns a postprocessing function
37 | #it is applied to responses before they are scored
38 | #central example: replace all tokens after truncate_token with padding_token
39 | def postprocess_fn_from_hparams(hparams: TaskHParams, padding_token: int):
40 | def get_mask(responses, truncate_token, truncate_after):
41 | # We want to truncate at the first occurrence of truncate_token that appears at or after
42 | # position truncate_after in the responses
43 | mask = tf.cast(tf.equal(responses, truncate_token), tf.int32)
44 | mask = tf.concat([tf.zeros_like(mask)[:,:truncate_after], mask[:,truncate_after:]], axis=1)
45 | return tf.cast(tf.cumsum(mask, axis=1) - mask, tf.bool)
46 | if hparams.truncate_token is not None:
47 | def truncate(responses):
48 | mask = get_mask(responses, hparams.truncate_token, hparams.truncate_after)
49 | return tf.where(mask, padding_token * tf.ones_like(responses), responses)
50 | return truncate
51 | else:
52 | return lambda responses: responses
53 |
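A worked numpy example of the truncation mask above (the token values, truncate_token and padding_token are all made up):

    import numpy as np

    responses = np.array([[7, 13, 9, 13, 4]])
    truncate_token, truncate_after, padding_token = 13, 0, 0

    eq = (responses == truncate_token).astype(int)
    eq[:, :truncate_after] = 0                       # ignore matches before truncate_after
    mask = (np.cumsum(eq, axis=1) - eq) > 0          # True strictly after the first match
    print(np.where(mask, padding_token, responses))  # [[ 7 13  0  0  0]]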
54 | #returns a filter function
55 | #responses not passing that function will receive a low (fixed) score
56 | #only query humans on responses that pass that function
57 | #central example: ensure that the sample contains truncate_token
58 | def filter_fn_from_hparams(hparams: TaskHParams):
59 | def filter(responses):
60 | if hparams.truncate_token is not None:
61 | matches_token = tf.equal(responses[:, hparams.truncate_after:], hparams.truncate_token)
62 | return tf.reduce_any(matches_token, axis=-1)
63 | else:
64 | return tf.ones(tf.shape(responses)[0], dtype=tf.bool)
65 | return filter
66 |
67 |
68 | def query_formatter(hparams: TaskHParams, encoder):
69 | """Turns a query into a context to feed to the language model
70 |
71 |     NOTE: Both the query and the returned context are lists of tokens
72 | """
73 | def query_formatter(queries):
74 | batch_size = tf.shape(queries)[0]
75 | prefix_tokens = tf.constant(encoder.encode(hparams.query_prefix), dtype=tf.int32)
76 | tiled_prefix = utils.expand_tile(prefix_tokens, batch_size, axis=0)
77 | suffix_tokens = tf.constant(encoder.encode(hparams.query_suffix), dtype=tf.int32)
78 | tiled_suffix = utils.expand_tile(suffix_tokens, batch_size, axis=0)
79 | return tf.concat([tiled_prefix, queries, tiled_suffix], 1)
80 | return query_formatter
81 |
82 |
83 | def make_query_sampler(*, hparams: TaskHParams, encoder, batch_size: int, mode='train', comm=None):
84 | if hparams.start_text:
85 | start_token, = encoder.encode(hparams.start_text)
86 | else:
87 | start_token = None
88 |
89 | if hparams.end_text:
90 | end_token, = encoder.encode(hparams.end_text)
91 | else:
92 | end_token = None
93 |
94 | data = datasets.get_dataset(hparams.query_dataset).tf_dataset(
95 | sequence_length=hparams.query_length, mode=mode, comm=comm, encoder=encoder,
96 | start_token=start_token, end_token=end_token,
97 | )
98 | data = data.map(lambda d: tf.cast(d['tokens'], tf.int32))
99 | data = data.batch(batch_size, drop_remainder=True)
100 |
101 | context_iterator = data.make_one_shot_iterator()
102 |
103 | def sampler(scope=None):
104 | with tf.name_scope(scope, 'sample_corpus'):
105 | context_tokens = context_iterator.get_next()
106 | return dict(tokens=context_tokens)
107 | return sampler
108 |
--------------------------------------------------------------------------------
/lm_human_preferences/policy.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from lm_human_preferences.language import model, sample
4 | from lm_human_preferences.utils import core as utils
5 | from lm_human_preferences.utils.core import Schema
6 |
7 |
8 | class Policy:
9 | def __init__(
10 | self,
11 | trained_model, *,
12 | scope=None, use_resource=False,
13 | embed_queries=lambda queries: queries,
14 | temperature=1.0, is_root=True,
15 | build_respond=True,
16 | ):
17 | self.trained_model = trained_model
18 | self.model_hparams = trained_model.hparams()
19 | self.is_root = is_root
20 |
21 | self.use_resource = use_resource
22 | self.encoder = self.trained_model.encoding.get_encoder()
23 |
24 | with tf.variable_scope(scope, 'transformer_policy', use_resource=self.use_resource) as s:
25 | self.scope = s
26 | self.model = model.Model(
27 | hparams=self.model_hparams,
28 | scalar_heads=['value'])
29 |
30 | self.built = False
31 | self.embed_queries = embed_queries
32 | self.temperature = temperature
33 | self.padding_token = self.encoder.padding_token
34 |
35 | if build_respond:
36 | self.respond = utils.graph_function(
37 | queries=Schema(tf.int32, (None, None)),
38 | length=Schema(tf.int32, ()),
39 | )(self.respond_op)
40 | self.analyze_responses = utils.graph_function(
41 | queries=Schema(tf.int32, (None, None)),
42 | responses=Schema(tf.int32, (None, None)),
43 | )(self.analyze_responses_op)
44 |
45 | def get_encoder(self):
46 | return self.encoder
47 |
48 | def step_core(self, model_hparams, tokens, past=None, past_tokens=None, do_dropout=False, name=None):
49 | with tf.name_scope(name, 'step'):
50 | with tf.variable_scope(
51 | self.scope,
52 | reuse=self.built,
53 | auxiliary_name_scope=not self.built,
54 | use_resource=self.use_resource):
55 | lm_output = self.model(X=tokens, past=past, past_tokens=past_tokens,
56 | do_dropout=do_dropout, padding_token=self.padding_token)
57 |
58 | # need to slice logits since we don't want to generate special tokens
59 | logits = lm_output['lm_logits'][:,:,:self.model_hparams.n_vocab]
60 | presents = lm_output['present']
61 | value = lm_output['value']
62 | if not self.built:
63 | self._set_initializers()
64 | self.built = True
65 | return {
66 | 'logits': logits,
67 | 'values': value,
68 | 'presents': presents,
69 | }
70 |
71 | def ensure_built(self):
72 | if not self.built:
73 | with tf.name_scope('dummy'):
74 | self.step_core(self.model_hparams, tokens=tf.zeros([0,0], dtype=tf.int32))
75 |
76 | def get_params(self):
77 | self.ensure_built()
78 | params = utils.find_trainable_variables(self.scope.name)
79 | assert len(params) > 0
80 | return params
81 |
82 | def _set_initializers(self):
83 | """Change initializers to load a language model from a tensorflow checkpoint."""
84 | # Skip if
85 | # 1. We're not rank 0. Values will be copied from there.
86 | # 2. We want random initialization. Normal initialization will do the work.
87 | if not self.is_root or self.trained_model.name == 'test':
88 | return
89 |
90 | with tf.init_scope():
91 | scope = self.scope.name
92 |
93 | # Initialize!
94 | params = {v.op.name: v for v in utils.find_trainable_variables(scope)}
95 | self.trained_model.init_op(params, new_scope=scope)
96 |
97 | def respond_op(self, queries, length):
98 | contexts = self.embed_queries(queries)
99 | context_length = tf.shape(contexts)[1]
100 | result = sample.sample_sequence(
101 | step=self.step_core,
102 | context=contexts,
103 | length=length,
104 | model_hparams=self.model_hparams,
105 | temperature=self.temperature,
106 | extra_outputs={'values':tf.float32},
107 | )
108 | return dict(
109 | responses=result['tokens'][:, context_length:],
110 | logprobs=result['logprobs'],
111 | values=result['values'],
112 | )
113 |
114 | def analyze_responses_op(self, queries, responses):
115 | contexts = self.embed_queries(queries)
116 | context_length = tf.shape(contexts)[1]
117 | tokens = tf.concat([contexts, responses], axis=1)
118 | result = self.step_core(self.model_hparams, tokens)
119 | logits = result['logits'][:, context_length-1:-1]
120 |
121 | logits /= self.temperature
122 | return dict(
123 | logprobs = utils.logprobs_from_logits(logits=logits, labels=responses),
124 | entropies = utils.entropy_from_logits(logits),
125 | values = result['values'][:, context_length-1:-1],
126 | )
127 |
128 |
--------------------------------------------------------------------------------
/lm_human_preferences/rewards.py:
--------------------------------------------------------------------------------
1 | """Synthetic scores."""
2 |
3 | import os
4 |
5 | import tensorflow as tf
6 | from mpi4py import MPI
7 |
8 | from lm_human_preferences.language import trained_models, model
9 | from lm_human_preferences.utils import core as utils
10 | from lm_human_preferences.utils.core import Schema
11 |
12 |
13 | # TODO: combine this with TrainedRewardModel
14 | class RewardModelTrainer:
15 | def __init__(
16 | self,
17 | trained_model, *,
18 | scope='reward_model', use_resource=False,
19 | is_root=True,
20 | ):
21 | self.trained_model = trained_model
22 | self.hparams = trained_model.hparams()
23 | self.is_root = is_root
24 |
25 | self.use_resource = use_resource
26 | self.encoder = self.trained_model.encoding.get_encoder()
27 |
28 | self.scope = scope
29 | self.model = model.Model(hparams=self.hparams, scope=f'{scope}/model', scalar_heads=['reward'])
30 |
31 | self.built = False
32 | self.padding_token = self.encoder.padding_token
33 |
34 | self.get_rewards = utils.graph_function(
35 | queries=Schema(tf.int32, (None, None)),
36 | responses=Schema(tf.int32, (None, None)),
37 | )(self.get_rewards_op)
38 |
39 |
40 | def get_encoder(self):
41 | return self.encoder
42 |
43 | def _build(self, tokens, do_dropout=False, name=None):
44 | with tf.variable_scope(self.scope, reuse=self.built, auxiliary_name_scope=not self.built, use_resource=self.use_resource):
45 | lm_output = self.model(X=tokens, do_dropout=do_dropout, padding_token=self.padding_token)
46 |
47 | reward = lm_output['reward'][:, -1]
48 | with tf.variable_scope('reward_norm'):
49 | if not self.built:
50 | self.reward_gain = tf.get_variable('gain', shape=(), initializer=tf.constant_initializer(1))
51 | self.reward_bias = tf.get_variable('bias', shape=(), initializer=tf.constant_initializer(0))
52 | self._reward_gain_p = tf.placeholder(name='gain_p', dtype=tf.float32, shape=())
53 | self._reward_bias_p = tf.placeholder(name='bias_p', dtype=tf.float32, shape=())
54 | self._set_reward_norm = tf.group(self.reward_gain.assign(self._reward_gain_p),
55 | self.reward_bias.assign(self._reward_bias_p))
56 | if reward is not None:
57 | reward = self.reward_gain * reward + self.reward_bias
58 | if not self.built:
59 | self._set_initializers()
60 | self.built = True
61 | return reward
62 |
63 | def ensure_built(self):
64 | if self.built:
65 | return
66 | with tf.name_scope('dummy'):
67 | self._build(tokens=tf.zeros([0,0], dtype=tf.int32))
68 |
69 | def get_params(self):
70 | self.ensure_built()
71 | return self.model.get_params() + [self.reward_gain, self.reward_bias]
72 |
73 | def reset_reward_scale(self):
74 | sess = tf.get_default_session()
75 | sess.run(self._set_reward_norm, feed_dict={self._reward_gain_p: 1, self._reward_bias_p: 0})
76 |
77 | def set_reward_norm(self, *, old_mean, old_std, new_mean, new_std):
78 | """Given old_mean+-old_std of reward_model, change gain and bias to get N(new_mean,new_std)."""
79 | sess = tf.get_default_session()
80 | old_gain, old_bias = sess.run((self.reward_gain, self.reward_bias))
81 | assert old_gain == 1 and old_bias == 0,\
82 | f'set_reward_norm expects gain = 1 and bias = 0, not {old_gain}, {old_bias}'
83 | # gain * N(old_mean,old_std) + bias = N(gain * old_mean, gain * old_std) + bias
84 | # = N(gain * old_mean + bias, gain * old_std)
85 | # gain * old_std = new_std, gain = new_std / old_std
86 | # gain * old_mean + bias = new_mean, bias = new_mean - gain * old_mean
87 | gain = new_std / old_std
88 | bias = new_mean - gain * old_mean
89 | sess.run(self._set_reward_norm, feed_dict={self._reward_gain_p: gain, self._reward_bias_p: bias})
90 |
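A numeric instance of the gain/bias derivation in the comment above (the numbers are made up):

    old_mean, old_std = 2.0, 4.0
    new_mean, new_std = 0.0, 1.0
    gain = new_std / old_std            # 0.25
    bias = new_mean - gain * old_mean   # -0.5
    # gain * N(2, 4) + bias has mean 0.25 * 2 - 0.5 = 0 and std 0.25 * 4 = 1, as intended.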
91 | def _set_initializers(self):
92 | """Change initializers to load a language model from a tensorflow checkpoint."""
93 | # Skip if
94 | # 1. We're not rank 0. Values will be copied from there.
95 | # 2. We want random initialization. Normal initialization will do the work.
96 | if not self.is_root or self.trained_model.name == 'test':
97 | return
98 |
99 | with tf.init_scope():
100 | # Initialize!
101 | params = {v.op.name: v for v in utils.find_trainable_variables(self.scope)}
102 | assert params
103 | self.trained_model.init_op(params, new_scope=self.scope)
104 |
105 | def get_rewards_op(self, queries, responses):
106 | tokens = tf.concat([queries, responses], axis=1)
107 | return self._build(tokens)
108 |
109 |
110 | class TrainedRewardModel():
111 | def __init__(self, train_dir, encoding, *, scope='reward_model', comm=MPI.COMM_WORLD):
112 | self.train_dir = train_dir
113 | self.comm = comm
114 |
115 | self.encoding = encoding
116 | encoder = encoding.get_encoder()
117 | if train_dir != 'test':
118 | self.hparams = trained_models.load_hparams(os.path.join(train_dir, 'hparams.json'))
119 | assert self.hparams.n_vocab == encoding.n_vocab, f'{self.hparams.n_vocab} != {encoding.n_vocab}'
120 | else:
121 | self.hparams = trained_models.test_hparams()
122 |
123 | self.padding_token = encoder.padding_token
124 |
125 | self.encoder = encoder
126 |
127 | self.scope = scope
128 | self.model = model.Model(hparams=self.hparams, scope=f'{scope}/model', scalar_heads=['reward'])
129 |
130 | def _build(self, X):
131 | results = self.model(X=X, padding_token=self.padding_token)
132 | reward = results['reward'][:, -1]
133 | with tf.variable_scope(f'{self.scope}/reward_norm'):
134 | self.reward_gain = tf.get_variable('gain', shape=(), initializer=tf.constant_initializer(1))
135 | self.reward_bias = tf.get_variable('bias', shape=(), initializer=tf.constant_initializer(0))
136 | reward = self.reward_gain * reward + self.reward_bias
137 | self._set_initializers()
138 | return reward
139 |
140 | def ensure_built(self):
141 | if self.model.built:
142 | return
143 | with tf.name_scope('dummy'):
144 | self._build(X=tf.zeros([0,0], dtype=tf.int32))
145 |
146 | def _set_initializers(self):
147 | """Change initializers to load a model from a tensorflow checkpoint."""
148 | if self.comm.Get_rank() > 0 or self.train_dir == 'test':
149 | return
150 |
151 | assert self.model.built
152 | checkpoint_scope = 'reward_model'
153 |
154 | with tf.init_scope():
155 | # Initialize!
156 | params = {v.op.name: v for v in self.get_params()}
157 | checkpoint = tf.train.latest_checkpoint(os.path.join(self.train_dir, 'checkpoints/'))
158 | available = tf.train.list_variables(checkpoint)
159 | unchanged = {}
160 |
161 | for name, shape in available:
162 | if not name.startswith(checkpoint_scope + '/'):
163 | # print('skipping', name)
164 | continue
165 | if name.endswith('adam') or name.endswith('adam_1'):
166 | # print('skipping', name)
167 | continue
168 | print('setting', name)
169 | var = params[self.scope + name[len(checkpoint_scope):]]
170 | assert var.shape == shape, 'Shape mismatch: %s.shape = %s != %s' % (var.op.name, var.shape, shape)
171 | unchanged[name] = var
172 | tf.train.init_from_checkpoint(checkpoint, unchanged)
173 |
174 | def get_params(self):
175 | return self.model.get_params() + [self.reward_gain, self.reward_bias]
176 |
177 | def score_fn(self, queries, responses):
178 | tokens = tf.concat([queries, responses], axis=1)
179 | return self._build(tokens)
180 |
--------------------------------------------------------------------------------
/lm_human_preferences/test_train_policy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import tempfile
4 | from lm_human_preferences import train_policy
5 |
6 | def hparams_for_test():
7 | hparams = train_policy.HParams()
8 | hparams.ppo.batch_size = 8
9 | hparams.noptepochs = 1
10 | hparams.task.policy.initial_model = 'test'
11 | hparams.task.query_length = 2
12 | hparams.task.response_length = 3
13 | hparams.task.query_dataset = 'test'
14 | hparams.rewards.trained_model = 'test'
15 | hparams.ppo.total_episodes = 8
16 | hparams.run.log_interval = 1
17 |
18 | return hparams
19 |
20 |
21 | def train_policy_test(override_params):
22 | hparams = hparams_for_test()
23 | hparams.override_from_dict(override_params)
24 | hparams.validate()
25 | train_policy.train(hparams=hparams)
26 |
27 |
28 | def test_truncation():
29 | train_policy_test({
30 | 'task.truncate_token': 13,
31 | 'task.truncate_after': 2,
32 | })
33 |
34 | def test_defaults():
35 | train_policy_test({})
36 |
37 | def test_affixing():
38 | train_policy_test({
39 | 'task.query_prefix': 'a',
40 | 'task.query_suffix': 'b'
41 | })
42 |
43 | def test_adaptive_kl():
44 | train_policy_test({
45 | 'rewards.trained_model': 'test', # not sure why needed
46 | 'rewards.adaptive_kl': 'on',
47 | 'rewards.adaptive_kl.target': 3.0,
48 | 'rewards.adaptive_kl.horizon': 100,
49 | })
50 |
51 | def test_save():
52 | train_policy_test({
53 |         'run.save_dir': tempfile.mkdtemp(),
54 | 'run.save_interval': 1
55 | })
56 |
57 | def test_reward_training():
58 | train_policy_test({
59 | 'rewards.trained_model': None,
60 | 'rewards.train_new_model': 'on',
61 | 'rewards.train_new_model.task.policy.initial_model': 'test',
62 | 'rewards.train_new_model.task.query_length': 2,
63 | 'rewards.train_new_model.task.response_length': 3,
64 | 'rewards.train_new_model.task.query_dataset': 'test',
65 | 'rewards.train_new_model.labels.source': 'test',
66 | 'rewards.train_new_model.labels.num_train': 16,
67 | 'rewards.train_new_model.batch_size': 8,
68 | 'rewards.train_new_model.labels.type': 'best_of_4',
69 | })
70 |
--------------------------------------------------------------------------------
/lm_human_preferences/test_train_reward.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import tempfile
4 | from lm_human_preferences import train_reward
5 |
6 | def hparams_for_test():
7 | hparams = train_reward.HParams()
8 | hparams.rollout_batch_size = 8
9 | hparams.task.query_length = 2
10 | hparams.task.response_length = 3
11 | hparams.noptepochs = 1
12 | hparams.task.policy.initial_model = 'test'
13 | hparams.task.query_dataset = 'test'
14 | hparams.task.start_text = None
15 | hparams.run.log_interval = 1
16 |
17 | hparams.labels.source = 'test'
18 | hparams.labels.num_train = 16
19 | hparams.labels.type = 'best_of_4'
20 |
21 | hparams.batch_size = 8
22 |
23 | return hparams
24 |
25 |
26 | def train_reward_test(override_params):
27 | hparams = hparams_for_test()
28 | hparams.override_from_dict(override_params)
29 | hparams.validate()
30 | train_reward.train(hparams=hparams)
31 |
32 |
33 | def test_basic():
34 | train_reward_test({})
35 |
36 |
37 | def test_scalar_compare():
38 | train_reward_test({'labels.type': 'scalar_compare'})
39 |
40 |
41 | def test_scalar_rating():
42 | train_reward_test({'labels.type': 'scalar_rating'})
43 |
44 |
45 | def test_normalize_before():
46 | train_reward_test({
47 | 'normalize_before': True,
48 | 'normalize_after': False,
49 | 'normalize_samples': 1024,
50 | 'debug_normalize': 1024,
51 | })
52 |
53 |
54 | def test_normalize_both():
55 | train_reward_test({
56 | 'normalize_before': True,
57 | 'normalize_after': True,
58 | 'normalize_samples': 1024,
59 | 'debug_normalize': 1024,
60 | })
61 |
62 | def test_save():
63 | train_reward_test({
64 |         'run.save_dir': tempfile.mkdtemp(),
65 | 'run.save_interval': 1
66 | })
67 |
--------------------------------------------------------------------------------
/lm_human_preferences/train_policy.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import json
4 | import os
5 | import sys
6 | import time
7 | from dataclasses import dataclass, field
8 | from functools import partial
9 | from typing import Optional
10 |
11 | import numpy as np
12 | import tensorflow as tf
13 | from mpi4py import MPI
14 | from tensorflow.contrib import summary
15 |
16 | from lm_human_preferences import lm_tasks, train_reward
17 | from lm_human_preferences.language import trained_models
18 | from lm_human_preferences.policy import Policy
19 | from lm_human_preferences.rewards import TrainedRewardModel
20 | from lm_human_preferences.utils import core as utils
21 | from lm_human_preferences.utils import hyperparams
22 | from lm_human_preferences.utils.core import Schema
23 |
24 |
25 | @dataclass
26 | class AdaptiveKLParams(hyperparams.HParams):
27 | target: float = None
28 | horizon: int = 10000 # in episodes
29 |
30 |
31 | @dataclass
32 | class RewardHParams(hyperparams.HParams):
33 | kl_coef: float = 0.2
34 | adaptive_kl: Optional[AdaptiveKLParams] = None
35 |
36 | trained_model: Optional[str] = None
37 |
38 | train_new_model: Optional[train_reward.HParams] = None
39 |
40 | def validate(self, *, prefix=''):
41 | super().validate(prefix=prefix)
42 | assert self.trained_model is None or self.train_new_model is None, 'Cannot use trained_model and train new model'
43 | assert self.trained_model is not None or self.train_new_model is not None, 'Need either trained_model or to train a new model'
44 |
45 |
46 | @dataclass
47 | class PpoHParams(hyperparams.HParams):
48 | total_episodes: int = 2000000
49 | batch_size: int = 64
50 | nminibatches: int = 1
51 | noptepochs: int = 4
52 | lr: float = 5e-6
53 | vf_coef: float = .1
54 | cliprange: float = .2
55 | cliprange_value: float = .2
56 | gamma: float = 1
57 | lam: float = 0.95
58 | whiten_rewards: bool = True
59 |
60 |
61 | @dataclass
62 | class HParams(hyperparams.HParams):
63 | run: train_reward.RunHParams = field(default_factory=train_reward.RunHParams)
64 |
65 | task: lm_tasks.TaskHParams = field(default_factory=lm_tasks.TaskHParams)
66 | rewards: RewardHParams = field(default_factory=RewardHParams)
67 | ppo: PpoHParams = field(default_factory=PpoHParams)
68 |
69 | def validate(self, *, prefix=''):
70 | super().validate(prefix=prefix)
71 | # NOTE: must additionally divide by # ranks
72 | minibatch_size = utils.exact_div(self.ppo.batch_size, self.ppo.nminibatches)
73 | if self.ppo.whiten_rewards:
74 | assert minibatch_size >= 8, \
75 | f"Minibatch size {minibatch_size} is insufficient for whitening in PPOTrainer.loss"
76 |
77 |
78 | def nupdates(hparams):
79 | return utils.ceil_div(hparams.ppo.total_episodes, hparams.ppo.batch_size)
80 |
81 |
82 | def policy_frac(hparams):
83 | """How far we are through policy training."""
84 | return tf.cast(tf.train.get_global_step(), tf.float32) / nupdates(hparams)
85 |
86 |
87 | def tf_times():
88 | """Returns (time since start, time since last) as a tensorflow op."""
89 | # Keep track of start and last times
90 | with tf.init_scope():
91 | init = tf.timestamp()
92 |
93 | def make(name):
94 | return tf.Variable(init, name=name, trainable=False, use_resource=True)
95 |
96 | start = make('start_time')
97 | last = make('last_time')
98 |
99 | # Get new time and update last
100 | now = tf.timestamp()
101 | prev = last.read_value()
102 | with tf.control_dependencies([prev]):
103 | with tf.control_dependencies([last.assign(now)]):
104 | return tf.cast(now - start.read_value(), tf.float32), tf.cast(now - prev, tf.float32)
105 |
106 |
107 | class FixedKLController:
108 | def __init__(self, kl_coef):
109 | self.value = kl_coef
110 |
111 | def update(self, current, n_steps):
112 | pass
113 |
114 |
115 | class AdaptiveKLController:
116 | def __init__(self, init_kl_coef, hparams):
117 | self.value = init_kl_coef
118 | self.hparams = hparams
119 |
120 | def update(self, current, n_steps):
121 | target = self.hparams.target
122 | proportional_error = np.clip(current / target - 1, -0.2, 0.2)
123 | mult = 1 + proportional_error * n_steps / self.hparams.horizon
124 | self.value *= mult
125 |
126 |
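AdaptiveKLController nudges the KL coefficient in proportion to how far the measured KL is from target, with the relative error clipped to ±20%. A numeric sketch with made-up values:

    import numpy as np

    value, target, horizon = 0.2, 3.0, 10000
    current, n_steps = 6.0, 64

    proportional_error = np.clip(current / target - 1, -0.2, 0.2)  # clipped at +0.2 here
    value *= 1 + proportional_error * n_steps / horizon            # 0.2 * 1.00128 ≈ 0.2003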
127 |
128 | class PPOTrainer():
129 | def __init__(self, *, policy, ref_policy, query_sampler, score_fn, hparams, comm):
130 | self.comm = comm
131 | self.policy = policy
132 | self.ref_policy = ref_policy
133 | self.score_fn = score_fn
134 | self.hparams = hparams
135 |
136 | if hparams.rewards.adaptive_kl is None:
137 | self.kl_ctl = FixedKLController(hparams.rewards.kl_coef)
138 | else:
139 | self.kl_ctl = AdaptiveKLController(hparams.rewards.kl_coef, hparams=hparams.rewards.adaptive_kl)
140 |
141 | response_length = hparams.task.response_length
142 | query_length = hparams.task.query_length
143 |
144 | @utils.graph_function()
145 | def sample_queries():
146 | return query_sampler()['tokens']
147 | self.sample_queries = sample_queries
148 |
149 | def compute_rewards(scores, logprobs, ref_logprobs):
150 | kl = logprobs - ref_logprobs
151 | non_score_reward = -self.kl_ctl.value * kl
152 | rewards = non_score_reward.copy()
153 | rewards[:, -1] += scores
154 | return rewards, non_score_reward, self.kl_ctl.value
155 | self.compute_rewards = compute_rewards
156 |
157 | # per rank sizes
158 | per_rank_rollout_batch_size = utils.exact_div(hparams.ppo.batch_size, comm.Get_size())
159 | per_rank_minibatch_size = utils.exact_div(per_rank_rollout_batch_size, hparams.ppo.nminibatches)
160 |
161 | @utils.graph_function(
162 | rollouts=dict(
163 | queries=Schema(tf.int32, (per_rank_minibatch_size, query_length)),
164 | responses=Schema(tf.int32, (per_rank_minibatch_size, response_length)),
165 | values=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
166 | logprobs=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
167 | rewards=Schema(tf.float32, (per_rank_minibatch_size, response_length)),
168 | ))
169 | def train_minibatch(rollouts):
170 | """One step of PPO training."""
171 |
172 | left = 1 - policy_frac(hparams)
173 | lrnow = hparams.ppo.lr * left
174 |
175 | ppo_loss, stats = self.loss(rollouts)
176 | ppo_train_op = utils.minimize(
177 | loss=ppo_loss, lr=lrnow, params=policy.get_params(), name='ppo_opt', comm=self.comm)
178 | return ppo_train_op, stats
179 |
180 | def train(rollouts):
181 | stat_list = []
182 |
183 | # Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
184 | for ppo_epoch_idx in range(hparams.ppo.noptepochs):
185 | order = np.random.permutation(per_rank_rollout_batch_size)
186 | for mb_start in range(0, per_rank_rollout_batch_size, per_rank_minibatch_size):
187 | mb_data = {k: v[order[mb_start:mb_start+per_rank_minibatch_size]]
188 | for k, v in rollouts.items()}
189 |
190 | step = tf.train.get_global_step().eval()
191 |
192 | _, stats = train_minibatch(mb_data)
193 | stat_list.append(stats)
194 |
195 | # Collect the stats. (They will be averaged later.)
196 | return {k: [s[k] for s in stat_list] for k in stat_list[0].keys()}
197 | self.train = train
198 |
199 | # NOTE: must line up with stats created in self.loss (TODO: better solution?)
200 | scalar_batch = Schema(tf.float32, (None,))
201 | ppo_stat_schemas = utils.flatten_dict(dict(
202 | loss=dict(policy=scalar_batch, value=scalar_batch, total=scalar_batch),
203 | policy=dict(entropy=scalar_batch, approxkl=scalar_batch, clipfrac=scalar_batch),
204 | returns=dict(mean=scalar_batch, var=scalar_batch),
205 | val=dict(vpred=scalar_batch, error=scalar_batch, clipfrac=scalar_batch, mean=scalar_batch, var=scalar_batch),
206 | ), sep='/')
207 | stat_data_schemas = dict(
208 | logprobs=Schema(tf.float32, (None, hparams.task.response_length)),
209 | ref_logprobs=Schema(tf.float32, (None, hparams.task.response_length)),
210 | scores=scalar_batch,
211 | non_score_reward=Schema(tf.float32, (None, hparams.task.response_length)),
212 | score_stats=score_fn.stat_schemas,
213 | train_stats=ppo_stat_schemas,
214 | )
215 | @utils.graph_function(
216 | **stat_data_schemas, kl_coef=Schema(tf.float32, ()))
217 | def record_step_stats(*, kl_coef, **data):
218 | ppo_summary_writer = utils.get_summary_writer(self.hparams.run.save_dir, subdir='ppo', comm=self.comm)
219 |
220 | kl = data['logprobs'] - data['ref_logprobs']
221 | mean_kl = tf.reduce_mean(tf.reduce_sum(kl, axis=1))
222 | mean_entropy = tf.reduce_mean(tf.reduce_sum(-data['logprobs'], axis=1))
223 | mean_non_score_reward = tf.reduce_mean(tf.reduce_sum(data['non_score_reward'], axis=1))
224 | stats = {
225 | 'objective/kl': mean_kl,
226 | 'objective/kl_coef': kl_coef,
227 | 'objective/entropy': mean_entropy,
228 | }
229 | for k, v in data['train_stats'].items():
230 | stats[f'ppo/{k}'] = tf.reduce_mean(v, axis=0)
231 | for k, v in data['score_stats'].items():
232 | mean = tf.reduce_mean(v, axis=0)
233 | stats[f'objective/{k}'] = mean
234 | stats[f'objective/{k}_total'] = mean + mean_non_score_reward
235 |
236 | stats = utils.FlatStats.from_dict(stats).map_flat(
237 | partial(utils.mpi_allreduce_mean, comm=self.comm)).as_dict()
238 |
239 | # Add more statistics
240 | step = tf.train.get_global_step().read_value()
241 | stats['ppo/val/var_explained'] = 1 - stats['ppo/val/error'] / stats['ppo/returns/var']
242 | steps = step + 1
243 | stats.update({
244 | 'elapsed/updates': steps,
245 | 'elapsed/steps/serial': steps * hparams.task.response_length,
246 | 'elapsed/steps/total': steps * hparams.ppo.batch_size * hparams.task.response_length,
247 | 'elapsed/episodes': steps * hparams.ppo.batch_size,
248 | })
249 |
250 | # Time statistics
251 | total, delta = tf_times()
252 | stats.update({
253 | 'elapsed/fps': tf.cast(hparams.ppo.batch_size * hparams.task.response_length / delta, tf.int32),
254 | 'elapsed/time': total,
255 | })
256 | if ppo_summary_writer:
257 | record_op = utils.record_stats(
258 | stats=stats, summary_writer=ppo_summary_writer, step=step, log_interval=hparams.run.log_interval, name='ppo_stats', comm=self.comm)
259 | else:
260 | record_op = tf.no_op()
261 | return record_op, stats
262 | self.record_step_stats = record_step_stats
263 |
264 | def print_samples(self, queries, responses, scores, logprobs, ref_logprobs):
265 | if self.comm.Get_rank() != 0:
266 | return
267 | if tf.train.get_global_step().eval() % self.hparams.run.log_interval != 0:
268 | return
269 |
270 | encoder = self.policy.encoder
271 |
272 | # Log samples
273 | for i in range(min(3, len(queries))):
274 | sample_kl = np.sum(logprobs[i] - ref_logprobs[i])
275 | print(encoder.decode(queries[i][:self.hparams.task.query_length]).replace("\n", "⏎"))
276 | print(encoder.decode(responses[i]).replace("\n", "⏎"))
277 | print(f" score = {scores[i]:+.2f}")
278 | print(f" kl = {sample_kl:+.2f}")
279 | print(f" total = {scores[i] - self.hparams.rewards.kl_coef * sample_kl:+.2f}")
280 |
281 | def step(self):
282 | step_started_at = time.time()
283 |
284 | queries = self.sample_queries()
285 | rollouts = self.policy.respond(queries, length=self.hparams.task.response_length)
286 |
287 | responses = rollouts['responses']
288 | logprobs = rollouts['logprobs']
289 | rollouts['queries'] = queries
290 | ref_logprobs = self.ref_policy.analyze_responses(queries, responses)['logprobs']
291 | scores, postprocessed_responses, score_stats = self.score_fn(queries, responses)
292 |
293 | rewards, non_score_reward, kl_coef = self.compute_rewards(
294 | scores=scores,
295 | logprobs=logprobs,
296 | ref_logprobs=ref_logprobs)
297 | rollouts['rewards'] = rewards
298 |
299 | train_stats = self.train(rollouts=rollouts)
300 |
301 | _, stats = self.record_step_stats(
302 | scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs, non_score_reward=non_score_reward,
303 | train_stats=train_stats, score_stats=score_stats, kl_coef=kl_coef)
304 |
305 | self.kl_ctl.update(stats['objective/kl'], self.hparams.ppo.batch_size)
306 |
307 | self.print_samples(queries=queries, responses=postprocessed_responses,
308 | scores=scores, logprobs=logprobs, ref_logprobs=ref_logprobs)
309 |
310 | # Record profiles of the step times
311 | step = tf.get_default_session().run(tf.train.get_global_step())
312 | step_time = time.time() - step_started_at
313 | eps_per_second = float(self.hparams.ppo.batch_size) / step_time
314 | if self.comm.Get_rank() == 0:
315 | print(f"[ppo_step {step}] step_time={step_time:.2f}s, "
316 | f"eps/s={eps_per_second:.2f}")
317 |
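The rewards fed into train above come from compute_rewards (defined in __init__), which adds a per-token KL penalty and puts the score on the final token. A tiny numpy sketch with made-up numbers:

    import numpy as np

    kl_coef = 0.2
    logprobs = np.array([[-1.0, -2.0, -0.5]])
    ref_logprobs = np.array([[-1.5, -2.0, -1.0]])
    scores = np.array([1.0])

    non_score_reward = -kl_coef * (logprobs - ref_logprobs)  # [[-0.1, 0.0, -0.1]]
    rewards = non_score_reward.copy()
    rewards[:, -1] += scores                                  # [[-0.1, 0.0, 0.9]]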
318 |
319 | def loss(self, rollouts):
320 | values = rollouts['values']
321 | old_logprob = rollouts['logprobs']
322 | rewards = rollouts['rewards']
323 | with tf.name_scope('ppo_loss'):
324 | if self.hparams.ppo.whiten_rewards:
325 | rewards = utils.whiten(rewards, shift_mean=False)
326 |
327 | lastgaelam = 0
328 | advantages_reversed = []
329 | gen_length = self.hparams.task.response_length
330 | for t in reversed(range(gen_length)):
331 | nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
332 | delta = rewards[:, t] + self.hparams.ppo.gamma * nextvalues - values[:, t]
333 | lastgaelam = delta + self.hparams.ppo.gamma * self.hparams.ppo.lam * lastgaelam
334 | advantages_reversed.append(lastgaelam)
335 | advantages = tf.stack(advantages_reversed[::-1], axis=1)
336 | returns = advantages + values
337 |
338 | advantages = utils.whiten(advantages)
339 | advantages = tf.stop_gradient(advantages) # Shouldn't do anything, but better not to think about it
340 |
341 | outputs = self.policy.analyze_responses_op(rollouts['queries'], rollouts['responses'])
342 |
343 | vpred = outputs['values']
344 | vpredclipped = tf.clip_by_value(vpred, values - self.hparams.ppo.cliprange_value, values + self.hparams.ppo.cliprange_value)
345 | vf_losses1 = tf.square(vpred - returns)
346 | vf_losses2 = tf.square(vpredclipped - returns)
347 | vf_loss = .5 * tf.reduce_mean(tf.maximum(vf_losses1, vf_losses2))
348 | vf_clipfrac = tf.reduce_mean(tf.cast(tf.greater(vf_losses2, vf_losses1), tf.float32))
349 |
350 | logprob = outputs['logprobs']
351 | ratio = tf.exp(logprob - old_logprob)
352 | pg_losses = -advantages * ratio
353 | pg_losses2 = -advantages * tf.clip_by_value(ratio, 1.0 - self.hparams.ppo.cliprange, 1.0 + self.hparams.ppo.cliprange)
354 | pg_loss = tf.reduce_mean(tf.maximum(pg_losses, pg_losses2))
355 | pg_clipfrac = tf.reduce_mean(tf.cast(tf.greater(pg_losses2, pg_losses), tf.float32))
356 |
357 | loss = pg_loss + self.hparams.ppo.vf_coef * vf_loss
358 |
359 | entropy = tf.reduce_mean(outputs['entropies'])
360 | approxkl = .5 * tf.reduce_mean(tf.square(logprob - old_logprob))
361 |
362 | return_mean, return_var = tf.nn.moments(returns, axes=list(range(returns.shape.ndims)))
363 | value_mean, value_var = tf.nn.moments(values, axes=list(range(values.shape.ndims)))
364 |
365 | stats = dict(
366 | loss=dict(policy=pg_loss, value=vf_loss, total=loss),
367 | policy=dict(entropy=entropy, approxkl=approxkl, clipfrac=pg_clipfrac),
368 | returns=dict(mean=return_mean, var=return_var),
369 | val=dict(vpred=tf.reduce_mean(vpred), error=tf.reduce_mean((vpred - returns) ** 2),
370 | clipfrac=vf_clipfrac, mean=value_mean, var=value_var)
371 | )
372 | return loss, utils.flatten_dict(stats, sep='/')
373 |
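# A small worked example of the GAE recursion in loss() above (editorial sketch, not
# part of the original file). With delta_t = r_t + gamma * V_{t+1} - V_t and
# A_t = delta_t + gamma * lam * A_{t+1}, take gamma = lam = 1, rewards = [0, 1],
# values = [0.5, 0.5]:
#   t=1: delta = 1 + 0   - 0.5 = 0.5   ->  A_1 = 0.5
#   t=0: delta = 0 + 0.5 - 0.5 = 0.0   ->  A_0 = 0.0 + 0.5 = 0.5
# so advantages = [0.5, 0.5] and returns = advantages + values = [1.0, 1.0].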
374 |
375 | def make_score_fn(hparams, score_model):
376 | padding_token = score_model.padding_token
377 |
378 | postprocess_fn = lm_tasks.postprocess_fn_from_hparams(hparams, padding_token)
379 | # decorate requires a named function; postprocess_fn can be anonymous
380 | @utils.graph_function(responses=Schema(tf.int32, (None, None)))
381 | def postprocess(responses):
382 | return postprocess_fn(responses)
383 |
384 | filter_fn = lm_tasks.filter_fn_from_hparams(hparams)
385 | @utils.graph_function(
386 | responses=Schema(tf.int32, (None, None)),
387 | rewards=Schema(tf.float32, (None,)))
388 | def penalize(responses, rewards):
389 | valid = filter_fn(responses)
390 | return tf.where(valid, rewards, hparams.penalty_reward_value * tf.ones_like(rewards))
391 |
392 | @utils.graph_function(
393 | queries=Schema(tf.int32, (None, None)),
394 | responses=Schema(tf.int32, (None, None))
395 | )
396 | def unpenalized_score_fn(queries, responses):
397 | return score_model.score_fn(queries, responses)
398 |
399 | def score_fn(queries, responses):
400 | responses = postprocess(responses)
401 | score = penalize(responses, unpenalized_score_fn(queries, responses))
402 | return score, responses, dict(score=score)
403 | score_fn.stat_schemas = dict(score=Schema(tf.float32, (None,)))
404 | return score_fn
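# Editorial note (not part of the original file): the returned score_fn chains three
# graph functions -- postprocess (task-specific cleanup of responses), the trained
# reward model's score, and penalize (which overwrites the score with
# hparams.penalty_reward_value for responses rejected by filter_fn) -- and returns
# (scores, postprocessed_responses, stats). PPOTrainer.step consumes it above as
#   scores, postprocessed_responses, score_stats = self.score_fn(queries, responses)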
405 |
406 |
407 |
408 | def train(hparams: HParams):
409 | save_dir = hparams.run.save_dir
410 | if hparams.rewards.train_new_model:
411 | assert hparams.task == hparams.rewards.train_new_model.task, f'{hparams.task} != {hparams.rewards.train_new_model.task}'
412 | hparams.rewards.train_new_model.run.save_dir = save_dir
413 | train_reward.train(hparams.rewards.train_new_model)
414 | if 'pytest' in sys.modules:
415 | hparams.rewards.trained_model = 'test'
416 | elif save_dir:
417 | hparams.rewards.trained_model = None if save_dir is None else os.path.join(save_dir, 'reward_model')
418 |
419 | comm = MPI.COMM_WORLD
420 |
421 | with tf.Graph().as_default():
422 | hyperparams.dump(hparams)
423 |
424 | m = trained_models.TrainedModel(hparams.task.policy.initial_model)
425 | encoder = m.encoding.get_encoder()
426 | hyperparams.dump(m.hparams(), name='model_hparams')
427 |
428 | if save_dir:
429 | if not save_dir.startswith('gs://'):
430 | os.makedirs(os.path.join(save_dir, 'policy'), exist_ok=True)
431 | with tf.gfile.Open(os.path.join(save_dir, 'train_policy_hparams.json'), 'w') as f:
432 | json.dump(hparams.to_nested_dict(), f, indent=2)
433 | with tf.gfile.Open(os.path.join(save_dir, 'policy', 'hparams.json'), 'w') as f:
434 | json.dump(m.hparams().to_nested_dict(), f, indent=2)
435 | with tf.gfile.Open(os.path.join(save_dir, 'policy', 'encoding'), 'w') as f:
436 | json.dump(m.encoding.name, f, indent=2)
437 | utils.set_mpi_seed(hparams.run.seed)
438 |
439 | score_model = TrainedRewardModel(hparams.rewards.trained_model, m.encoding, comm=comm)
440 |
441 | ref_policy = Policy(
442 | m, scope='ref_policy',
443 | is_root=comm.Get_rank() == 0,
444 | embed_queries=lm_tasks.query_formatter(hparams.task, encoder),
445 | temperature=hparams.task.policy.temperature,
446 | build_respond=False)
447 |
448 | policy = Policy(
449 | m, scope='policy',
450 | is_root=comm.Get_rank() == 0,
451 | embed_queries=lm_tasks.query_formatter(hparams.task, encoder),
452 | temperature=hparams.task.policy.temperature)
453 |
454 | query_sampler = lm_tasks.make_query_sampler(
455 | hparams=hparams.task, encoder=encoder, comm=comm,
456 | batch_size=utils.exact_div(hparams.ppo.batch_size, comm.Get_size()),
457 | )
458 |
459 | per_rank_minibatch_size = utils.exact_div(hparams.ppo.batch_size, hparams.ppo.nminibatches * comm.Get_size())
460 | if hparams.ppo.whiten_rewards:
461 | assert per_rank_minibatch_size >= 8, \
462 | f"Per-rank minibatch size {per_rank_minibatch_size} is insufficient for whitening"
463 |
464 | global_step = tf.train.get_or_create_global_step()
465 | increment_global_step = tf.group(global_step.assign_add(1))
466 |
467 | with utils.variables_on_gpu():
468 |
469 | ppo_trainer = PPOTrainer(
470 | policy=policy, ref_policy=ref_policy, query_sampler=query_sampler,
471 | score_fn=make_score_fn(hparams.task, score_model=score_model),
472 | hparams=hparams, comm=comm)
473 |
474 | if comm.Get_rank() == 0 and save_dir:
475 | print(f"Will save to {save_dir}")
476 | saver = tf.train.Saver(max_to_keep=20, save_relative_paths=True)
477 | checkpoint_dir = os.path.join(save_dir, 'policy/checkpoints/model.ckpt')
478 | else:
479 | saver = None
480 | checkpoint_dir = None
481 |
482 | @utils.graph_function()
483 | def sync_models():
484 | score_model.ensure_built()
485 | return utils.variable_synchronizer(comm, vars=score_model.get_params() + ref_policy.get_params() + policy.get_params())
486 |
487 | init_ops = tf.group(
488 | tf.global_variables_initializer(),
489 | tf.local_variables_initializer(),
490 | summary.summary_writer_initializer_op())
491 |
492 | with utils.mpi_session() as sess:
493 | init_ops.run()
494 |
495 | sync_models()
496 |
497 | tf.get_default_graph().finalize()
498 |
499 | try:
500 | while global_step.eval() < nupdates(hparams):
501 | ppo_trainer.step()
502 | increment_global_step.run()
503 |
504 | if saver and global_step.eval() % hparams.run.save_interval == 0:
505 | saver.save(sess, checkpoint_dir, global_step=global_step)
506 | finally:
507 | if saver:
508 | saver.save(sess, checkpoint_dir, global_step=global_step)
509 |
--------------------------------------------------------------------------------
/lm_human_preferences/train_reward.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import json
4 | import os
5 | from dataclasses import dataclass, field
6 | from functools import partial
7 | from typing import Optional
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 | from mpi4py import MPI
12 | from tensorflow.contrib import summary
13 |
14 | from lm_human_preferences import label_types, lm_tasks, rewards
15 | from lm_human_preferences.language import trained_models
16 | from lm_human_preferences.policy import Policy
17 | from lm_human_preferences.utils import core as utils
18 | from lm_human_preferences.utils import gcs, hyperparams
19 | from lm_human_preferences.utils.core import Schema
20 |
21 |
22 | @dataclass
23 | class LabelHParams(hyperparams.HParams):
24 | type: str = None
25 | num_train: int = None
26 | source: str = None
27 |
28 |
29 | @dataclass
30 | class RunHParams(hyperparams.HParams):
31 | seed: Optional[int] = None
32 | log_interval: int = 10
33 | save_interval: int = 50
34 | save_dir: Optional[str] = None
35 |
36 | @dataclass
37 | class HParams(hyperparams.HParams):
38 | run: RunHParams = field(default_factory=RunHParams)
39 |
40 | task: lm_tasks.TaskHParams = field(default_factory=lm_tasks.TaskHParams)
41 | labels: LabelHParams = field(default_factory=LabelHParams)
42 |
43 | batch_size: int = 40 # total across ranks
44 | lr: float = 5e-5
45 |
46 | rollout_batch_size: int = 64
47 | normalize_samples: int = 0 # Samples used to estimate reward mean and std
48 | debug_normalize: int = 0 # Samples used to check that normalization worked
49 | # Whether, before training, to normalize the rewards on the policy to the scales on the training buffer.
50 | # (For comparisons, just use mean 0, var 1.)
51 | normalize_before: bool = False
52 | # Whether, after training, to normalize the rewards on the ref policy to mean 0, var 1
53 | # (so the KL coefficient always has the same meaning).
54 | normalize_after: bool = False
55 |
56 | def validate(self, *, prefix=''):
57 | super().validate(prefix=prefix)
58 | utils.exact_div(self.labels.num_train, self.batch_size)
59 |
60 | def round_down_to_multiple(n, divisor):
61 | return n - n % divisor
62 |
63 |
64 | def download_labels(source, label_type, question_schemas, total_labels, comm):
65 | schemas = {**question_schemas, **label_type.label_schemas()}
66 |
67 | """
68 | if self.is_root:
69 | with tf.device('cpu:0'):
70 | self._enqueue_phs = {
71 | name: tf.placeholder(name=name, dtype=schema.dtype, shape=(None,) + schema.shape)
72 | for name, schema in self.schemas.items()
73 | }
74 | self._enqueue_answers = self.answer_queue.enqueue_many(self._enqueue_phs)
75 | else:
76 | self._enqueue_phs = None
77 | self._enqueue_answers = None
78 | """
79 |
80 | # TODO: download on just one rank? then do: labels = utils.mpi_bcast_tensor_dict(labels, comm=comm)
81 | if source != 'test':
82 | with open(gcs.download_file_cached(source, comm=comm)) as f:
83 | results = json.load(f)
84 | print('Num labels found in source:', len(results))
85 | else:
86 | results = [
87 | {
88 | name: np.zeros(schema.shape, dtype=schema.dtype.as_numpy_dtype)
89 | for name, schema in schemas.items()
90 | }
91 | for _ in range(50)
92 | ]
93 |
94 | assert len(results) >= total_labels
95 | results = results[:total_labels]
96 | return {k: [a[k] for a in results] for k in schemas.keys()}
97 |
98 |
99 | class RewardModelTrainer():
100 | def __init__(self, *, reward_model, policy, query_sampler, hparams, comm):
101 | self.reward_model = reward_model
102 |
103 | self.policy = policy
104 | self.hparams = hparams
105 | self.num_ranks = comm.Get_size()
106 | self.rank = comm.Get_rank()
107 | self.comm = comm
108 |
109 | self.label_type = label_types.get(hparams.labels.type)
110 | self.question_schemas = self.label_type.question_schemas(
111 | query_length=hparams.task.query_length,
112 | response_length=hparams.task.response_length,
113 | )
114 |
115 | data_schemas = {
116 | **self.question_schemas,
117 | **self.label_type.label_schemas(),
118 | }
119 |
120 | with tf.device(None), tf.device('/cpu:0'):
121 | with tf.variable_scope('label_buffer', use_resource=True, initializer=tf.zeros_initializer):
122 | self.train_buffer = utils.SampleBuffer(capacity=hparams.labels.num_train, schemas=data_schemas)
123 |
124 | with tf.name_scope('train_reward'):
125 | summary_writer = utils.get_summary_writer(self.hparams.run.save_dir, subdir='reward_model', comm=comm)
126 |
127 | @utils.graph_function(
128 | indices=Schema(tf.int32, (None,)),
129 | lr=Schema(tf.float32, ()))
130 | def train_batch(indices, lr):
131 | with tf.name_scope('minibatch'):
132 | minibatch = self.train_buffer.read(indices)
133 | stats = self.label_type.loss(reward_model=self.reward_model.get_rewards_op, labels=minibatch)
134 |
135 | train_op = utils.minimize(
136 | loss=stats['loss'], lr=lr, params=self.reward_model.get_params(), name='opt', comm=self.comm)
137 |
138 | with tf.control_dependencies([train_op]):
139 | step_var = tf.get_variable(name='train_step', dtype=tf.int64, shape=(), trainable=False, use_resource=True)
140 | step = step_var.assign_add(1) - 1
141 |
142 | stats = utils.FlatStats.from_dict(stats).map_flat(partial(utils.mpi_allreduce_mean, comm=comm)).as_dict()
143 |
144 | train_stat_op = utils.record_stats(stats=stats, summary_writer=summary_writer, step=step, log_interval=hparams.run.log_interval, comm=comm)
145 |
146 | return train_stat_op
147 | self.train_batch = train_batch
148 |
149 | if self.hparams.normalize_before or self.hparams.normalize_after:
150 | @utils.graph_function()
151 | def target_mean_std():
152 | """Returns the mean and std to target for the reward model"""
153 | # Should be the same on all ranks because the train_buf should be the same
154 | scales = self.label_type.target_scales(self.train_buffer.data())
155 | if scales is None:
156 | return tf.zeros([]), tf.ones([])
157 | else:
158 | mean, var = tf.nn.moments(scales, axes=[0])
159 | return mean, tf.sqrt(var)
160 | self.target_mean_std = target_mean_std
161 |
162 | def stats(query_responses):
163 | rewards = np.concatenate([self.reward_model.get_rewards(qs, rs) for qs, rs in query_responses], axis=0)
164 | assert len(rewards.shape) == 1, f'{rewards.shape}'
165 | sums = np.asarray([rewards.sum(axis=0), np.square(rewards).sum(axis=0)])
166 | means, sqr_means = self.comm.allreduce(sums, op=MPI.SUM) / (self.num_ranks * rewards.shape[0])
167 | stds = np.sqrt(sqr_means - means ** 2)
168 | return means, stds
169 | self.stats = stats
170 |
171 | def log_stats_after_normalize(stats):
172 | if comm.Get_rank() != 0:
173 | return
174 | means, stds = stats
175 | print(f'after normalize: {means} +- {stds}')
176 | self.log_stats_after_normalize = log_stats_after_normalize
177 |
178 | def reset_reward_scales():
179 | self.reward_model.reset_reward_scale()
180 | self.reset_reward_scales = reset_reward_scales
181 |
182 | def set_reward_norms(mean, std, new_mean, new_std):
183 | print(f'targets: {new_mean} +- {new_std}')
184 | print(f'before normalize: {mean} +- {std}')
185 | assert np.isfinite((mean, std, new_mean, new_std)).all()
186 | self.reward_model.set_reward_norm(old_mean=mean, old_std=std, new_mean=new_mean, new_std=new_std)
187 | self.set_reward_norms = set_reward_norms
188 |
189 | if self.hparams.normalize_before or self.hparams.normalize_after:
190 | @utils.graph_function()
191 | def sample_policy_batch():
192 | queries = query_sampler('ref_queries')['tokens']
193 | responses = policy.respond_op(
194 | queries=queries, length=hparams.task.response_length)['responses']
195 | return queries, responses
196 |
197 | def sample_policy_responses(n_samples):
198 | n_batches = utils.ceil_div(n_samples, hparams.rollout_batch_size)
199 | return [sample_policy_batch() for _ in range(n_batches)]
200 | self.sample_policy_responses = sample_policy_responses
201 |
202 | @utils.graph_function(labels=utils.add_batch_dim(data_schemas))
203 | def add_to_buffer(labels):
204 | return self.train_buffer.add(**labels)
205 | self.add_to_buffer = add_to_buffer
206 |
207 | def normalize(self, sample_fn, target_means, target_stds):
208 | if not self.hparams.normalize_samples:
209 | return
210 |
211 | self.reset_reward_scales()
212 | query_responses = sample_fn(self.hparams.normalize_samples)
213 | means, stds = self.stats(query_responses)
214 |
215 | self.set_reward_norms(means, stds, target_means, target_stds)
216 | if self.hparams.debug_normalize:
217 | query_responses = sample_fn(self.hparams.debug_normalize)
218 | stats = self.stats(query_responses)
219 | self.log_stats_after_normalize(stats)
220 |
221 | def train(self):
222 | labels = download_labels(
223 | self.hparams.labels.source,
224 | label_type=self.label_type,
225 | question_schemas=self.question_schemas,
226 | total_labels=self.hparams.labels.num_train,
227 | comm=self.comm
228 | )
229 |
230 | self.add_to_buffer(labels)
231 |
232 | if self.hparams.normalize_before:
233 | target_mean, target_std = self.target_mean_std()
234 | self.normalize(self.sample_policy_responses, target_mean, target_std)
235 |
236 | # Collect training data for reward model training. train_indices will include the indices
237 | # trained on across all ranks, and its size must be a multiple of minibatch_size.
238 | per_rank_batch_size = utils.exact_div(self.hparams.batch_size, self.num_ranks)
239 |
240 | # Make sure each rank gets the same shuffle so we train on each point exactly once
241 | train_indices = self.comm.bcast(np.random.permutation(self.hparams.labels.num_train))
242 |
243 | # Train on train_indices
244 | print(self.rank, "training on", self.hparams.labels.num_train, "in batches of", per_rank_batch_size)
245 | for start_index in range(0, self.hparams.labels.num_train, self.hparams.batch_size):
246 | end_index = start_index + self.hparams.batch_size
247 | all_ranks_indices = train_indices[start_index:end_index]
248 | our_indices = all_ranks_indices[self.rank::self.num_ranks]
249 | lr = (1 - start_index / self.hparams.labels.num_train) * self.hparams.lr
250 | self.train_batch(our_indices, lr)
251 |
252 | if self.hparams.normalize_after:
253 | target_mean, target_std = np.zeros([]), np.ones([])
254 | self.normalize(self.sample_policy_responses, target_mean, target_std)
255 |
256 |
257 |
258 | def train(hparams: HParams):
259 | with tf.Graph().as_default():
260 | hyperparams.dump(hparams)
261 | utils.set_mpi_seed(hparams.run.seed)
262 |
263 | m = trained_models.TrainedModel(hparams.task.policy.initial_model)
264 | encoder = m.encoding.get_encoder()
265 | hyperparams.dump(m.hparams(), name='model_hparams')
266 |
267 | comm = MPI.COMM_WORLD
268 | ref_policy = Policy(
269 | m, scope='ref_policy',
270 | is_root=comm.Get_rank() == 0,
271 | embed_queries=lm_tasks.query_formatter(hparams.task, encoder),
272 | temperature=hparams.task.policy.temperature,
273 | build_respond=False)
274 |
275 | reward_model = rewards.RewardModel(m, is_root=comm.Get_rank() == 0)
276 |
277 | query_sampler = lm_tasks.make_query_sampler(
278 | hparams=hparams.task, encoder=encoder, comm=comm,
279 | batch_size=utils.exact_div(hparams.rollout_batch_size, comm.Get_size())
280 | )
281 |
282 | tf.train.create_global_step()
283 |
284 | reward_trainer = RewardModelTrainer(
285 | reward_model=reward_model,
286 | policy=ref_policy,
287 | query_sampler=query_sampler,
288 | hparams=hparams,
289 | comm=comm,
290 | )
291 |
292 | save_dir = hparams.run.save_dir
293 | if comm.Get_rank() == 0 and save_dir:
294 | print(f"Will save to {save_dir}")
295 | saver = tf.train.Saver(max_to_keep=20, save_relative_paths=True)
296 | checkpoint_dir = os.path.join(save_dir, 'reward_model/checkpoints/model.ckpt')
297 |
298 | if not save_dir.startswith('gs://'):
299 | os.makedirs(os.path.join(save_dir, 'reward_model'), exist_ok=True)
300 | with tf.gfile.Open(os.path.join(save_dir, 'train_reward_hparams.json'), 'w') as f:
301 | json.dump(hparams.to_nested_dict(), f, indent=2)
302 | with tf.gfile.Open(os.path.join(save_dir, 'reward_model', 'hparams.json'), 'w') as f:
303 | json.dump(reward_model.hparams.to_nested_dict(), f, indent=2)
304 | with tf.gfile.Open(os.path.join(save_dir, 'reward_model', 'encoding'), 'w') as f:
305 | json.dump(reward_model.trained_model.encoding.name, f, indent=2)
306 | else:
307 | saver = None
308 | checkpoint_dir = None
309 |
310 | with utils.variables_on_gpu():
311 | init_ops = tf.group(
312 | tf.global_variables_initializer(),
313 | tf.local_variables_initializer(),
314 | summary.summary_writer_initializer_op())
315 |
316 | @utils.graph_function()
317 | def sync_models():
318 | return utils.variable_synchronizer(comm, vars=ref_policy.get_params() + reward_model.get_params())
319 |
320 | tf.get_default_graph().finalize()
321 |
322 | with utils.mpi_session() as sess:
323 | init_ops.run()
324 | sync_models()
325 |
326 | reward_trainer.train()
327 |
328 | if saver:
329 | saver.save(sess, checkpoint_dir)
330 |
--------------------------------------------------------------------------------
/lm_human_preferences/utils/combos.py:
--------------------------------------------------------------------------------
1 | def combos(*xs):
2 | if xs:
3 | return [x + combo for x in xs[0] for combo in combos(*xs[1:])]
4 | else:
5 | return [()]
6 |
7 | def each(*xs):
8 | return [y for x in xs for y in x]
9 |
10 | def bind(var, val, descriptor=''):
11 | extra = {}
12 | if descriptor:
13 | extra['descriptor'] = descriptor
14 | return [((var, val, extra),)]
15 |
16 | def label(descriptor):
17 | return bind(None, None, descriptor)
18 |
19 | def labels(*descriptors):
20 | return each(*[label(d) for d in descriptors])
21 |
22 | def options(var, opts_with_descs):
23 | return each(*[bind(var, val, descriptor) for val, descriptor in opts_with_descs])
24 |
25 | def _shortstr(v):
26 | if isinstance(v, float):
27 | s = f"{v:.03}"
28 | if '.' in s:
29 | s = s.lstrip('0').replace('.','x')
30 | else:
31 | s = str(v)
32 | return s
33 |
34 | def options_shortdesc(var, desc, opts):
35 | return each(*[bind(var, val, desc + _shortstr(val)) for val in opts])
36 |
37 | def options_vardesc(var, opts):
38 | return options_shortdesc(var, var, opts)
39 |
40 | def repeat(n):
41 | return each(*[label(i) for i in range(n)])
42 |
43 | # list monad bind; passes descriptors to body
44 | def foreach(inputs, body):
45 | return [inp + y for inp in inputs for y in body(*[extra['descriptor'] for var, val, extra in inp])]
46 |
47 | def bind_nested(prefix, binds):
48 | return [
49 | tuple([ (var if var is None else prefix + '.' + var, val, extra) for (var, val, extra) in x ])
50 | for x in binds
51 | ]
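# Editorial usage sketch (not part of the original file): these helpers build lists of
# trials, where each trial is a tuple of (var, val, extra) bindings. For example
#   combos(bind('lr', 0.1), each(bind('seed', 1), bind('seed', 2)))
# evaluates to
#   [(('lr', 0.1, {}), ('seed', 1, {})),
#    (('lr', 0.1, {}), ('seed', 2, {}))]
# i.e. the cross product of the per-axis options, which launch.launch_trials then turns
# into hyperparameter overrides (the 'lr' and 'seed' names here are just placeholders).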
52 |
--------------------------------------------------------------------------------
/lm_human_preferences/utils/core.py:
--------------------------------------------------------------------------------
1 | """Utilities."""
2 |
3 | import collections
4 | import contextlib
5 | import inspect
6 | import os
7 | import platform
8 | import shutil
9 | import subprocess
10 | from dataclasses import dataclass
11 | from functools import lru_cache, partial, wraps
12 | from typing import Any, Dict, Tuple, Optional
13 |
14 | import numpy as np
15 | import tensorflow as tf
16 | from mpi4py import MPI
17 | from tensorflow.contrib import summary
18 |
19 | try:
20 | import horovod.tensorflow as hvd
21 | hvd.init()
22 | except:
23 | hvd = None
24 |
25 |
26 | nest = tf.contrib.framework.nest
27 |
28 |
29 | def nvidia_gpu_count():
30 | """
31 | Count the GPUs on this machine.
32 | """
33 | if shutil.which('nvidia-smi') is None:
34 | return 0
35 | try:
36 | output = subprocess.check_output(['nvidia-smi', '--query-gpu=gpu_name', '--format=csv'])
37 | except subprocess.CalledProcessError:
38 | # Probably no GPUs / no driver running.
39 | return 0
40 | return max(0, len(output.split(b'\n')) - 2)
41 |
42 |
43 | def get_local_rank_size(comm):
44 | """
45 | Returns the rank of each process on its machine
46 | The processes on a given machine will be assigned ranks
47 | 0, 1, 2, ..., N-1,
48 | where N is the number of processes on this machine.
49 | Useful if you want to assign one gpu per machine
50 | """
51 | this_node = platform.node()
52 | ranks_nodes = comm.allgather((comm.Get_rank(), this_node))
53 | node2rankssofar = collections.defaultdict(int)
54 | local_rank = None
55 | for (rank, node) in ranks_nodes:
56 | if rank == comm.Get_rank():
57 | local_rank = node2rankssofar[node]
58 | node2rankssofar[node] += 1
59 | assert local_rank is not None
60 | return local_rank, node2rankssofar[this_node]
61 |
62 |
63 | @lru_cache()
64 | def gpu_devices():
65 | if 'CUDA_VISIBLE_DEVICES' in os.environ:
66 | raise ValueError('CUDA_VISIBLE_DEVICES should not be set (it will cause nccl slowdowns). Use VISIBLE_DEVICES instead!')
67 | devices_str = os.environ.get('VISIBLE_DEVICES')
68 | if devices_str is not None:
69 | return list(map(int, filter(len, devices_str.split(','))))
70 | else:
71 | return list(range(nvidia_gpu_count()))
72 |
73 | @lru_cache()
74 | def gpu_count():
75 | return len(gpu_devices()) or None
76 |
77 |
78 | @lru_cache()
79 | def _our_gpu():
80 | """Figure out which GPU we should be using in an MPI context."""
81 | gpus = gpu_devices()
82 | if not gpus:
83 | return None
84 | rank = MPI.COMM_WORLD.Get_rank()
85 | local_rank, local_size = get_local_rank_size(MPI.COMM_WORLD)
86 | if gpu_count() not in (0, local_size):
87 | raise ValueError('Expected one GPU per rank, got gpus %s, local size %d' % (gpus, local_size))
88 | gpu = gpus[local_rank]
89 | print('rank %d: gpus = %s, our gpu = %d' % (rank, gpus, gpu))
90 | return gpu
91 |
92 |
93 | def mpi_session_config():
94 | """Make a tf.ConfigProto to use only the GPU assigned to this MPI session."""
95 | config = tf.ConfigProto()
96 | gpu = _our_gpu()
97 | if gpu is not None:
98 | config.gpu_options.visible_device_list = str(gpu)
99 | config.gpu_options.allow_growth = True
100 | return config
101 |
102 |
103 | def mpi_session():
104 | """Create a session using only the GPU assigned to this MPI process."""
105 | return tf.Session(config=mpi_session_config())
106 |
107 |
108 | def set_mpi_seed(seed: Optional[int]):
109 | if seed is not None:
110 | rank = MPI.COMM_WORLD.Get_rank()
111 | seed = seed + rank * 100003 # Prime offset so each MPI rank gets a different seed
112 | np.random.seed(seed)
113 | tf.set_random_seed(seed)
114 |
115 |
116 | def exact_div(a, b):
117 | q = a // b
118 | if tf.contrib.framework.is_tensor(q):
119 | with tf.control_dependencies([tf.debugging.Assert(tf.equal(a, q * b), [a, b])]):
120 | return tf.identity(q)
121 | else:
122 | if a != q * b:
123 | raise ValueError('Inexact division: %s / %s = %s' % (a, b, a / b))
124 | return q
125 |
126 |
127 | def ceil_div(a, b):
128 | return (a - 1) // b + 1
129 |
130 |
131 | def expand_tile(value, size, *, axis, name=None):
132 | """Add a new axis of given size."""
133 | with tf.name_scope(name, 'expand_tile', [value, size, axis]) as scope:
134 | value = tf.convert_to_tensor(value, name='value')
135 | size = tf.convert_to_tensor(size, name='size')
136 | ndims = value.shape.rank
137 | if axis < 0:
138 | axis += ndims + 1
139 | return tf.tile(tf.expand_dims(value, axis=axis), [1]*axis + [size] + [1]*(ndims - axis), name=scope)
140 |
141 |
142 | def index_each(a, ix):
143 | """Do a batched indexing operation: index row i of a by ix[i]
144 |
145 | In the simple case (a is >=2D and ix is 1D), returns [row[i] for row, i in zip(a, ix)].
146 |
147 | If ix has more dimensions, multiple lookups will be done at each batch index.
148 | For instance, if ix is 2D, returns [[row[i] for i in ix_row] for row, ix_row in zip(a, ix)].
149 |
150 | Always indexes into dimension 1 of a.
151 | """
152 | a = tf.convert_to_tensor(a, name='a')
153 | ix = tf.convert_to_tensor(ix, name='ix', dtype=tf.int32)
154 | with tf.name_scope('index_each', values=[a, ix]) as scope:
155 | a.shape[:1].assert_is_compatible_with(ix.shape[:1])
156 | i0 = tf.range(tf.shape(a)[0], dtype=ix.dtype)
157 | if ix.shape.rank > 1:
158 | i0 = tf.tile(tf.reshape(i0, (-1,) + (1,)*(ix.shape.rank - 1)), tf.concat([[1], tf.shape(ix)[1:]], axis=0))
159 | return tf.gather_nd(a, tf.stack([i0, ix], axis=-1), name=scope)
160 |
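# Editorial example (not part of the original file): for a 2-D input, index_each picks
# one element per row, e.g.
#   index_each([[10, 20, 30], [40, 50, 60]], [2, 0])
# evaluates (in a session) to [30, 40].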
161 | def cumulative_max(x):
162 | """Takes the (inclusive) cumulative maximum along the last axis of x. (Not efficient.)"""
163 | x = tf.convert_to_tensor(x)
164 | with tf.name_scope('cumulative_max', values=[x]) as scope:
165 | repeated = tf.tile(
166 | tf.expand_dims(x, axis=-1),
167 | tf.concat([tf.ones(x.shape.rank, dtype=tf.int32), tf.shape(x)[-1:]], axis=0))
168 | trues = tf.ones_like(repeated, dtype=tf.bool)
169 | upper_triangle = tf.matrix_band_part(trues, 0, -1)
170 | neg_inf = tf.ones_like(repeated) * tf.dtypes.saturate_cast(-np.inf, dtype=x.dtype)
171 | prefixes = tf.where(upper_triangle, repeated, neg_inf)
172 | return tf.math.reduce_max(prefixes, axis=-2, name=scope)
173 |
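# Editorial example (not part of the original file):
#   cumulative_max([3., 1., 4., 1., 5.])
# evaluates (in a session) to [3., 3., 4., 4., 5.], the running maximum along the last axis.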
174 |
175 | def flatten_dict(nested, sep='.'):
176 | def rec(nest, prefix, into):
177 | for k, v in nest.items():
178 | if sep in k:
179 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
180 | if isinstance(v, collections.Mapping):
181 | rec(v, prefix + k + sep, into)
182 | else:
183 | into[prefix + k] = v
184 | flat = {}
185 | rec(nested, '', flat)
186 | return flat
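# Editorial example (not part of the original file):
#   flatten_dict({'loss': {'policy': 0.1, 'value': 0.2}, 'kl': 3.0}, sep='/')
# returns {'loss/policy': 0.1, 'loss/value': 0.2, 'kl': 3.0}; train_policy.py uses this
# to flatten its nested PPO stats into scalar summary keys.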
187 |
188 | @dataclass
189 | class Schema:
190 | dtype: Any
191 | shape: Tuple[Optional[int],...]
192 |
193 |
194 | def add_batch_dim(schemas, batch_size=None):
195 | def add_dim(schema):
196 | return Schema(dtype=schema.dtype, shape=(batch_size,)+schema.shape)
197 | return nest.map_structure(add_dim, schemas)
198 |
199 |
200 | class SampleBuffer:
201 | """A circular buffer for storing and sampling data.
202 |
203 | Data can be added to the buffer with `add`, and old data will be dropped. If you need to
204 | control where the buffer is stored, wrap the constructor call in a `with tf.device` block:
205 |
206 | with tf.device('cpu:0'):
207 | buffer = SampleBuffer(...)
208 | """
209 |
210 | def __init__(self, *, capacity: int, schemas: Dict[str,Schema], name=None) -> None:
211 | with tf.variable_scope(name, 'buffer', use_resource=True, initializer=tf.zeros_initializer):
212 | self._capacity = tf.constant(capacity, dtype=tf.int32, name='capacity')
213 | self._total = tf.get_variable(
214 | 'total', dtype=tf.int32, shape=(), trainable=False, collections=[tf.GraphKeys.LOCAL_VARIABLES],
215 | )
216 | self._vars = {
217 | n: tf.get_variable(
218 | n, dtype=s.dtype, shape=(capacity,) + s.shape, trainable=False,
219 | collections=[tf.GraphKeys.LOCAL_VARIABLES],
220 | )
221 | for n,s in schemas.items()
222 | }
223 |
224 | def add(self, **data):
225 | """Add new data to the end of the buffer, dropping old data if we exceed capacity."""
226 | # Check input shapes
227 | if data.keys() != self._vars.keys():
228 | raise ValueError('data.keys() = %s != %s' % (sorted(data.keys()), sorted(self._vars.keys())))
229 | first = next(iter(data.values()))
230 | pre = first.shape[:1]
231 | for k, d in data.items():
232 | try:
233 | d.shape.assert_is_compatible_with(pre.concatenate(self._vars[k].shape[1:]))
234 | except ValueError as e:
235 | raise ValueError('%s, key %s' % (e, k))
236 | # Enqueue
237 | n = tf.shape(first)[0]
238 | capacity = self._capacity
239 | i0 = (self._total.assign_add(n) - n) % capacity
240 | i0n = i0 + n
241 | i1 = tf.minimum(i0n, capacity)
242 | i2 = i1 % capacity
243 | i3 = i0n % capacity
244 | slices = slice(i0, i1), slice(i2, i3)
245 | sizes = tf.stack([i1 - i0, i3 - i2])
246 | assigns = [self._vars[k][s].assign(part)
247 | for k,d in data.items()
248 | for s, part in zip(slices, tf.split(d, sizes))]
249 | return tf.group(assigns)
250 |
251 | def total(self):
252 | """Total number of entries ever added, including those already discarded."""
253 | return self._total.read_value()
254 |
255 | def size(self):
256 | """Current number of entries."""
257 | return tf.minimum(self.total(), self._capacity)
258 |
259 | def read(self, indices):
260 | """indices: A 1-D Tensor of indices to read from. Each index must be less than
261 | capacity."""
262 | return {k: v.sparse_read(indices) for k,v in self._vars.items()}
263 |
264 | def data(self):
265 | return {k: v[:self.size()] for k,v in self._vars.items()}
266 |
267 | def sample(self, n, seed=None):
268 | """Sample n entries with replacement."""
269 | size = self.size()
270 | indices = tf.random_uniform([n], maxval=size, dtype=tf.int32, seed=seed)
271 | return self.read(indices)
272 |
273 | def write(self, indices, updates):
274 | """
275 | indices: A 1-D Tensor of indices to write to. Each index must be less than `capacity`.
277 | updates: A dictionary of new values, where each entry is a tensor with the same length as `indices`.
277 | """
278 | ops = []
279 | for k, v in updates.items():
280 | ops.append(self._vars[k].scatter_update(tf.IndexedSlices(v, tf.cast(indices, dtype=tf.int32))))
281 | return tf.group(*ops)
282 |
283 | def write_add(self, indices, deltas):
284 | ops = []
285 | for k, d in deltas.items():
286 | ops.append(self._vars[k].scatter_add(tf.IndexedSlices(d, tf.cast(indices, dtype=tf.int32))))
287 | return tf.group(*ops)
288 |
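# Editorial usage sketch (not part of the original file); SampleBuffer backs the label
# buffer in train_reward.py. A minimal standalone use:
#   buf = SampleBuffer(capacity=4, schemas=dict(x=Schema(tf.float32, ())))
#   add_op = buf.add(x=tf.constant([1., 2., 3., 4., 5.]))
#   # After initializing local variables and running add_op in a session, buf.size()
#   # evaluates to 4 and buf.data()['x'] holds the four most recent values (the oldest
#   # entry, 1., has been overwritten because the buffer is circular).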
289 |
290 | def entropy_from_logits(logits):
291 | pd = tf.nn.softmax(logits, axis=-1)
292 | return tf.math.reduce_logsumexp(logits, axis=-1) - tf.reduce_sum(pd*logits, axis=-1)
293 |
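# Editorial note (not part of the original file): entropy_from_logits uses the identity
#   H = -sum_i p_i * log p_i
#     = -sum_i p_i * (logit_i - logsumexp(logits))
#     = logsumexp(logits) - sum_i p_i * logit_i
# with p = softmax(logits), which avoids forming log p_i explicitly.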
294 |
295 | def logprobs_from_logits(*, logits, labels):
296 | return -tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
297 |
298 |
299 | def sample_from_logits(logits, dtype=tf.int32):
300 | with tf.name_scope('sample_from_logits', values=[logits]) as scope:
301 | shape = tf.shape(logits)
302 | flat_logits = tf.reshape(logits, [-1, shape[-1]])
303 | flat_samples = tf.random.categorical(flat_logits, num_samples=1, dtype=dtype)
304 | return tf.reshape(flat_samples, shape[:-1], name=scope)
305 |
306 |
307 | def take_top_k_logits(logits, k):
308 | values, _ = tf.nn.top_k(logits, k=k)
309 | min_values = values[:, :, -1, tf.newaxis]
310 | return tf.where(
311 | logits < min_values,
312 | tf.ones_like(logits) * -1e10,
313 | logits,
314 | )
315 |
316 |
317 | def take_top_p_logits(logits, p):
318 | """Nucleus sampling"""
319 | batch, sequence, _ = logits.shape.as_list()
320 | sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
321 | cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
322 | indices = tf.stack([
323 | tf.range(0, batch)[:, tf.newaxis],
324 | tf.range(0, sequence)[tf.newaxis, :],
325 | # number of indices to include
326 | tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
327 | ], axis=-1)
328 | min_values = tf.gather_nd(sorted_logits, indices)
329 | return tf.where(
330 | logits < min_values,
331 | tf.ones_like(logits) * -1e10,
332 | logits,
333 | )
334 |
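# Editorial example (not part of the original file) of the nucleus cutoff above: for one
# position with token probabilities [0.5, 0.3, 0.15, 0.05] and p = 0.9, the cumulative
# probabilities are [0.5, 0.8, 0.95, 1.0], so the cutoff index is
# sum(cumulative_probs <= p) - 1 = 1; only the two most likely tokens keep their logits
# and the rest are masked to -1e10.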
335 |
336 | def whiten(values, shift_mean=True):
337 | mean, var = tf.nn.moments(values, axes=list(range(values.shape.rank)))
338 | whitened = (values - mean) * tf.rsqrt(var + 1e-8)
339 | if not shift_mean:
340 | whitened += mean
341 | return whitened
342 |
343 |
344 |
345 | def where(cond, true, false, name=None):
346 | """Similar to tf.where, but broadcasts scalar values."""
347 | with tf.name_scope(name, 'where', [cond, true, false]) as name:
348 | cond = tf.convert_to_tensor(cond, name='cond', dtype=tf.bool)
349 | true = tf.convert_to_tensor(true, name='true',
350 | dtype=false.dtype if isinstance(false, tf.Tensor) else None)
351 | false = tf.convert_to_tensor(false, name='false', dtype=true.dtype)
352 | if true.shape.rank == false.shape.rank == 0:
353 | shape = tf.shape(cond)
354 | true = tf.fill(shape, true)
355 | false = tf.fill(shape, false)
356 | elif true.shape.rank == 0:
357 | true = tf.fill(tf.shape(false), true)
358 | elif false.shape.rank == 0:
359 | false = tf.fill(tf.shape(true), false)
360 | return tf.where(cond, true, false, name=name)
361 |
362 |
363 | def map_flat(f, values):
364 | """Apply the function f to flattened, concatenated values, then split and reshape back to original shapes."""
365 | values = tuple(values)
366 | for v in values:
367 | assert not isinstance(v, tf.IndexedSlices)
368 | values = [tf.convert_to_tensor(v) for v in values]
369 | flat = tf.concat([tf.reshape(v, [-1]) for v in values], axis=0)
370 | flat = f(flat)
371 | parts = tf.split(flat, [tf.size(v) for v in values])
372 | return [tf.reshape(p, tf.shape(v)) for p, v in zip(parts, values)]
373 |
374 |
375 | def map_flat_chunked(f, values, *, limit=1<<29):
376 | """
377 | Apply the function f to chunked, flattened, concatenated values, then split and reshape back to original shapes.
378 | """
379 | values = tuple(values)
380 | for v in values:
381 | assert not isinstance(v, tf.IndexedSlices)
382 | values = [tf.convert_to_tensor(v) for v in values]
383 | chunks = chunk_tensors(values, limit=limit)
384 | mapped_values = [v for chunk in chunks for v in map_flat(f, chunk)]
385 | return mapped_values
386 |
387 |
388 | def map_flat_bits(f, values):
389 | """Apply the function f to bit-concatenated values, then convert back to original shapes and dtypes."""
390 | values = [tf.convert_to_tensor(v) for v in values]
391 | def maybe_bitcast(v, dtype):
392 | cast = tf.cast if tf.bool in (v.dtype, dtype) else tf.bitcast
393 | return cast(v, dtype)
394 | bits = [maybe_bitcast(v, tf.uint8) for v in values]
395 | flat = tf.concat([tf.reshape(b, [-1]) for b in bits], axis=0)
396 | flat = f(flat)
397 | parts = tf.split(flat, [tf.size(b) for b in bits])
398 | return [maybe_bitcast(tf.reshape(p, tf.shape(b)), v.dtype)
399 | for p, v, b in zip(parts, values, bits)]
400 |
401 | def mpi_bcast_tensor_dict(d, comm):
402 | sorted_keys = sorted(d.keys())
403 | values = map_flat_bits(partial(mpi_bcast, comm), [d[k] for k in sorted_keys])
404 | return {k: v for k, v in zip(sorted_keys, values)}
405 |
406 | def mpi_bcast(comm, value, root=0):
407 | """Broadcast value from root to other processes via a TensorFlow py_func."""
408 | value = tf.convert_to_tensor(value)
409 | if comm.Get_size() == 1:
410 | return value
411 | comm = comm.Dup() # Allow parallelism at graph execution time
412 | if comm.Get_rank() == root:
413 | out = tf.py_func(partial(comm.bcast, root=root), [value], value.dtype)
414 | else:
415 | out = tf.py_func(partial(comm.bcast, None, root=root), [], value.dtype)
416 | out.set_shape(value.shape)
417 | return out
418 |
419 |
420 | def chunk_tensors(tensors, *, limit=1 << 28):
421 | """Chunk the list of tensors into groups of size at most `limit` bytes.
422 |
423 | The tensors must have a static shape.
424 | """
425 | total = 0
426 | batches = []
427 | for v in tensors:
428 | size = v.dtype.size * v.shape.num_elements()
429 | if not batches or total + size > limit:
430 | total = 0
431 | batches.append([])
432 | total += size
433 | batches[-1].append(v)
434 | return batches
435 |
436 |
437 | def variable_synchronizer(comm, vars, *, limit=1<<28):
438 | """Synchronize `vars` from the root to the other processes."""
439 | if comm.Get_size() == 1:
440 | return tf.no_op()
441 |
442 | # Split vars into chunks so that no chunk is over limit bytes
443 | batches = chunk_tensors(sorted(vars, key=lambda v: v.name), limit=limit)
444 |
445 | # Synchronize each batch, using a separate communicator to ensure safety
446 | prev = tf.no_op()
447 | for batch in batches:
448 | with tf.control_dependencies([prev]):
449 | assigns = []
450 | values = map_flat_bits(partial(mpi_bcast, comm), batch)
451 | for var, value in zip(batch, values):
452 | assigns.append(var.assign(value))
453 | prev = tf.group(*assigns)
454 | return prev
455 |
456 |
457 | def mpi_read_file(comm, path):
458 | """Read a file on rank 0 and broadcast the contents to all machines."""
459 | if comm.Get_rank() == 0:
460 | with tf.gfile.Open(path, 'rb') as fh:
461 | data = fh.read()
462 | comm.bcast(data)
463 | else:
464 | data = comm.bcast(None)
465 | return data
466 |
467 |
468 | def mpi_allreduce_sum(values, *, comm):
469 | if comm.Get_size() == 1:
470 | return values
471 | orig_dtype = values.dtype
472 | if hvd is None:
473 | orig_shape = values.shape
474 | def _allreduce(vals):
475 | buf = np.zeros(vals.shape, np.float32)
476 | comm.Allreduce(vals, buf, op=MPI.SUM)
477 | return buf
478 | values = tf.py_func(_allreduce, [values], tf.float32)
479 | values.set_shape(orig_shape)
480 | else:
481 | values = hvd.mpi_ops._allreduce(values)
482 | return tf.cast(values, dtype=orig_dtype)
483 |
484 |
485 | def mpi_allreduce_mean(values, *, comm):
486 | scale = 1 / comm.Get_size()
487 | values = mpi_allreduce_sum(values, comm=comm)
488 | return values if scale == 1 else scale * values
489 |
490 |
491 | class FlatStats:
492 | """A bunch of statistics stored as a single flat tensor."""
493 |
494 | def __init__(self, keys, flat):
495 | keys = tuple(keys)
496 | flat = tf.convert_to_tensor(flat, dtype=tf.float32, name='flat')
497 | assert [len(keys)] == flat.shape.as_list()
498 | self.keys = keys
499 | self.flat = flat
500 |
501 | @staticmethod
502 | def from_dict(stats):
503 | for k, v in stats.items():
504 | if v.dtype != tf.float32:
505 | raise ValueError('Statistic %s has dtype %r, expected %r' % (k, v.dtype, tf.float32))
506 | keys = tuple(sorted(stats.keys()))
507 | flat = tf.stack([stats[k] for k in keys])
508 | return FlatStats(keys, flat)
509 |
510 | def concat(self, more):
511 | dups = set(self.keys) & set(more.keys)
512 | if dups:
513 | raise ValueError('Duplicate statistics: %s' % ', '.join(dups))
514 | return FlatStats(self.keys + more.keys, tf.concat([self.flat, more.flat], axis=0))
515 |
516 | def as_dict(self):
517 | flat = tf.unstack(self.flat, num=len(self.keys))
518 | return dict(safe_zip(self.keys, flat))
519 |
520 | def with_values(self, flat):
521 | return FlatStats(self.keys, flat)
522 |
523 | def map_flat(self, f):
524 | return FlatStats(self.keys, f(self.flat))
525 |
526 |
527 | def find_trainable_variables(key):
528 | return [v for v in tf.trainable_variables() if v.op.name.startswith(key + '/')]
529 |
530 |
531 | def variables_on_gpu():
532 | """Prevent variables from accidentally being placed on the CPU.
533 |
534 | This dodges an obscure bug in tf.train.init_from_checkpoint.
535 | """
536 | if _our_gpu() is None:
537 | return contextlib.suppress()
538 | def device(op):
539 | return '/gpu:0' if op.type == 'VarHandleOp' else ''
540 | return tf.device(device)
541 |
542 |
543 |
544 | def graph_function(**schemas: Schema):
545 | def decorate(make_op):
546 | def make_ph(path, schema):
547 | return tf.placeholder(name=f'arg_{make_op.__name__}_{path}', shape=schema.shape, dtype=schema.dtype)
548 | phs = nest.map_structure_with_paths(make_ph, schemas)
549 | op = make_op(**phs)
550 | sig = inspect.signature(make_op)
551 | @wraps(make_op)
552 | def run(*args, **kwargs):
553 | bound: inspect.BoundArguments = sig.bind(*args, **kwargs)
554 | bound.apply_defaults()
555 |
556 | arg_dict = bound.arguments
557 | for name, param in sig.parameters.items():
558 | if param.kind == inspect.Parameter.VAR_KEYWORD:
559 | kwargs = arg_dict[name]
560 | arg_dict.update(kwargs)
561 | del arg_dict[name]
562 | flat_phs = nest.flatten(phs)
563 | flat_arguments = nest.flatten_up_to(phs, bound.arguments)
564 | feed = {ph: arg for ph, arg in zip(flat_phs, flat_arguments)}
565 | run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
566 |
567 | return tf.get_default_session().run(op, feed_dict=feed, options=run_options, run_metadata=None)
568 | return run
569 | return decorate
570 |
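# Editorial usage sketch (not part of the original file), assuming a default session is
# active (e.g. inside `with mpi_session().as_default():`):
#   @graph_function(x=Schema(tf.float32, (None,)))
#   def double(x):               # the op and its placeholder are built at decoration time
#       return 2 * x
#   double(np.array([1., 2.]))   # feeds the placeholder and runs the op -> [2., 4.]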
571 |
572 |
573 | def pearson_r(x: tf.Tensor, y: tf.Tensor):
574 | assert x.shape.rank == 1
575 | assert y.shape.rank == 1
576 | x_mean, x_var = tf.nn.moments(x, axes=[0])
577 | y_mean, y_var = tf.nn.moments(y, axes=[0])
578 | cov = tf.reduce_mean((x - x_mean)*(y - y_mean), axis=0)
579 | return cov / tf.sqrt(x_var * y_var)
580 |
581 | def shape_list(x):
582 | """Deal with dynamic shape in tensorflow cleanly."""
583 | static = x.shape.as_list()
584 | dynamic = tf.shape(x)
585 | return [dynamic[i] if s is None else s for i, s in enumerate(static)]
586 |
587 | def safe_zip(*args):
588 | """Zip, but require all sequences to be the same length."""
589 | args = tuple(map(tuple, args))
590 | for a in args[1:]:
591 | if len(args[0]) != len(a):
592 | raise ValueError(f'Lengths do not match: {[len(a) for a in args]}')
593 | return zip(*args)
594 |
595 |
596 | def get_summary_writer(save_dir, subdir='', comm=MPI.COMM_WORLD):
597 | if comm.Get_rank() != 0:
598 | return None
599 | if save_dir is None:
600 | return None
601 | with tf.init_scope():
602 | return summary.create_file_writer(os.path.join(save_dir, 'tb', subdir))
603 |
604 |
605 | def record_stats(*, stats, summary_writer, step, log_interval, name=None, comm=MPI.COMM_WORLD):
606 | def log_stats(step, *stat_values):
607 | if comm.Get_rank() != 0 or step % log_interval != 0:
608 | return
609 |
610 | for k, v in safe_zip(stats.keys(), stat_values):
611 | print('k = ', k, ', v = ', v)
612 |
613 | summary_ops = [tf.py_func(log_stats, [step] + list(stats.values()), [])]
614 | if summary_writer:
615 | with summary_writer.as_default(), summary.always_record_summaries():
616 | for key, value in stats.items():
617 | summary_ops.append(summary.scalar(key, value, step=step))
618 | return tf.group(*summary_ops, name=name)
619 |
620 |
621 | def minimize(*, loss, params, lr, name=None, comm=MPI.COMM_WORLD):
622 | with tf.name_scope(name, 'minimize'):
623 | with tf.name_scope('grads'):
624 | grads = tf.gradients(loss, params)
625 | grads, params = zip(*[(g, v) for g, v in zip(grads, params) if g is not None])
626 | grads = map_flat_chunked(partial(mpi_allreduce_mean, comm=comm), grads)
627 | optimizer = tf.train.AdamOptimizer(learning_rate=lr, epsilon=1e-5, name='adam')
628 | opt_op = optimizer.apply_gradients(zip(grads, params), name=name)
629 | return opt_op
630 |
--------------------------------------------------------------------------------
/lm_human_preferences/utils/gcs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import subprocess
4 | import time
5 | import traceback
6 | import warnings
7 | from functools import wraps
8 | from urllib.parse import urlparse, unquote
9 |
10 | import requests
11 | from google.api_core.exceptions import InternalServerError, ServiceUnavailable
12 | from google.cloud import storage
13 |
14 | warnings.filterwarnings("ignore", "Your application has authenticated using end user credentials")
15 |
16 |
17 | def exponential_backoff(
18 | retry_on=lambda e: True, *, init_delay_s=1, max_delay_s=600, max_tries=30, factor=2.0,
19 | jitter=0.2, log_errors=True):
20 | """
21 | Returns a decorator which retries the wrapped function as long as retry_on returns True for the exception.
22 | :param init_delay_s: How long to wait to do the first retry (in seconds).
23 | :param max_delay_s: At what duration to cap the retry interval at (in seconds).
24 | :param max_tries: How many total attempts to perform.
25 | :param factor: How much to multiply the delay interval by after each attempt (until it reaches max_delay_s).
26 | :param jitter: How much to jitter by (between 0 and 1) -- each delay will be multiplied by a random value between (1-jitter) and (1+jitter).
27 | :param log_errors: Whether to print tracebacks on every retry.
28 | :param retry_on: A predicate which takes an exception and indicates whether to retry after that exception.
29 | """
30 | def decorate(f):
31 | @wraps(f)
32 | def f_retry(*args, **kwargs):
33 | delay_s = float(init_delay_s)
34 | for i in range(max_tries):
35 | try:
36 | return f(*args, **kwargs)
37 | except Exception as e:
38 | if not retry_on(e) or i == max_tries-1:
39 | raise
40 | if log_errors:
41 | print(f"Retrying after try {i+1}/{max_tries} failed:")
42 | traceback.print_exc()
43 | jittered_delay = random.uniform(delay_s*(1-jitter), delay_s*(1+jitter))
44 | time.sleep(jittered_delay)
45 | delay_s = min(delay_s * factor, max_delay_s)
46 | return f_retry
47 | return decorate
48 |
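# Editorial usage sketch (not part of the original file): the decorated function is
# retried with exponentially growing, jittered delays until retry_on returns False or
# max_tries is exhausted, e.g.
#   @exponential_backoff(lambda e: isinstance(e, requests.exceptions.ConnectionError),
#                        init_delay_s=0.5, max_tries=5)
#   def fetch(url):
#       return requests.get(url).content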
49 |
50 | def _gcs_should_retry_on(e):
51 | # Retry on 500 and 503 errors (and connection errors), as recommended by https://cloud.google.com/apis/design/errors#error_retries
52 | return isinstance(e, (InternalServerError, ServiceUnavailable, requests.exceptions.ConnectionError))
53 |
54 |
55 | def parse_url(url):
56 | """Given a gs:// path, returns bucket name and blob path."""
57 | result = urlparse(url)
58 | if result.scheme == 'gs':
59 | return result.netloc, unquote(result.path.lstrip('/'))
60 | elif result.scheme == 'https':
61 | assert result.netloc == 'storage.googleapis.com'
62 | bucket, rest = result.path.lstrip('/').split('/', 1)
63 | return bucket, unquote(rest)
64 | else:
65 | raise Exception(f'Could not parse {url} as gcs url')
66 |
67 |
68 | @exponential_backoff(_gcs_should_retry_on)
69 | def get_blob(url, client=None):
70 | if client is None:
71 | client = storage.Client()
72 | bucket_name, path = parse_url(url)
73 | bucket = client.get_bucket(bucket_name)
74 | return bucket.get_blob(path)
75 |
76 |
77 | @exponential_backoff(_gcs_should_retry_on)
78 | def download_contents(url, client=None):
79 | """Given a gs:// path, returns contents of the corresponding blob."""
80 | blob = get_blob(url, client)
81 | if not blob: return None
82 | return blob.download_as_string()
83 |
84 |
85 | @exponential_backoff(_gcs_should_retry_on)
86 | def upload_contents(url, contents, client=None):
87 | """Given a gs:// path, uploads `contents` to the corresponding blob."""
88 | if client is None:
89 | client = storage.Client()
90 | bucket_name, path = parse_url(url)
91 | bucket = client.get_bucket(bucket_name)
92 | blob = storage.Blob(path, bucket)
93 | blob.upload_from_string(contents)
94 |
95 |
96 | def download_directory_cached(url, comm=None):
97 | """ Given a GCS path url, caches the contents locally.
98 | WARNING: only use this function if contents under the path won't change!
99 | """
100 | cache_dir = '/tmp/gcs-cache'
101 | bucket_name, path = parse_url(url)
102 | is_master = not comm or comm.Get_rank() == 0
103 | local_path = os.path.join(cache_dir, bucket_name, path)
104 |
105 | sentinel = os.path.join(local_path, 'SYNCED')
106 | if is_master:
107 | if not os.path.exists(local_path):
108 | os.makedirs(os.path.dirname(local_path), exist_ok=True)
109 | cmd = 'gsutil', '-m', 'cp', '-r', url, os.path.dirname(local_path) + '/'
110 | print(' '.join(cmd))
111 | subprocess.check_call(cmd)
112 | open(sentinel, 'a').close()
113 | else:
114 | while not os.path.exists(sentinel):
115 | time.sleep(1)
116 | return local_path
117 |
118 |
119 | def download_file_cached(url, comm=None):
120 | """ Given a GCS path url, caches the contents locally.
121 | WARNING: only use this function if contents under the path won't change!
122 | """
123 | cache_dir = '/tmp/gcs-cache'
124 | bucket_name, path = parse_url(url)
125 | is_master = not comm or comm.Get_rank() == 0
126 | local_path = os.path.join(cache_dir, bucket_name, path)
127 |
128 | sentinel = local_path + '.SYNCED'
129 | if is_master:
130 | if not os.path.exists(local_path):
131 | os.makedirs(os.path.dirname(local_path), exist_ok=True)
132 | cmd = 'gsutil', '-m', 'cp', url, local_path
133 | print(' '.join(cmd))
134 | subprocess.check_call(cmd)
135 | open(sentinel, 'a').close()
136 | else:
137 | while not os.path.exists(sentinel):
138 | time.sleep(1)
139 | return local_path
140 |
--------------------------------------------------------------------------------
/lm_human_preferences/utils/hyperparams.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import typing
4 | from dataclasses import fields, is_dataclass
5 | from functools import lru_cache
6 |
7 | from typeguard import check_type
8 |
9 | from lm_human_preferences.utils import gcs
10 |
11 |
12 | class HParams:
13 | """Used as a base class for hyperparameter structs. They also need to be annotated with @dataclass."""
14 |
15 | def override_from_json_file(self, filename):
16 | if filename.startswith('gs://'):
17 | hparams_str = gcs.download_contents(filename)
18 | else:
19 | hparams_str = open(filename).read()
20 | self.parse_json(hparams_str)
21 |
22 | def override_from_str(self, hparam_str):
23 | """Overrides values from a string like 'x.y=1,name=foobar'.
24 |
25 | Like tensorflow.contrib.training.HParams, this method does not allow specifying string values containing commas.
26 | """
27 | kvp_strs = hparam_str.split(',')
28 | flat_dict = {}
29 | for kvp_str in kvp_strs:
30 | k, sep, v = kvp_str.partition('=')
31 | if not sep:
32 | raise ValueError(f"Malformed hyperparameter value: '{kvp_str}'")
33 | flat_dict[k] = v
34 |
35 | self.override_from_str_dict(flat_dict)
36 |
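    # Editorial example (not part of the original file), using hypothetical field names:
    # for an HParams subclass with a nested HParams field `task` that has an int field
    # `query_length`, and a `run` field with an int `seed`, calling
    #   hp.override_from_str('task.query_length=64,run.seed=1')
    # parses each value according to the annotated field type before assigning it.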
37 | def override_from_str_dict(self, flat_dict, separator='.'):
38 | """Overrides values from a dict like {'x.y': "1", 'name': "foobar"}.
39 |
40 | Treats keys with dots as paths into nested HParams.
41 | Parses values according to the types in the HParams classes.
42 | """
43 | typemap = _type_map(type(self), separator=separator)
44 |
45 | parsed = {}
46 | for flat_k, s in flat_dict.items():
47 | if flat_k not in typemap:
48 | raise AttributeError(f"no field {flat_k} in {typemap}")
49 | parsed[flat_k] = _parse_typed_value(typemap[flat_k], s)
50 |
51 | self.override_from_dict(parsed, separator=separator)
52 |
53 | def parse_json(self, s: str):
54 | self.override_from_nested_dict(json.loads(s))
55 |
56 | def override_from_dict(self, flat_dict, separator='.'):
57 | """Overrides values from a dict like {'x.y': 1, 'name': "foobar"}.
58 |
59 | Treats keys with dots as paths into nested HParams.
60 | Values should be parsed already.
61 | """
62 | # Parse 'on' and 'off' values.
63 | typemap = _type_map(type(self), separator=separator)
64 |
65 | flat_dict_parsed = {}
66 | for flat_k, v in flat_dict.items():
67 | cls = _type_to_class(typemap[flat_k])
68 | if is_hparam_type(cls) and v == 'on':
69 | parsed_v = cls()
70 | elif is_hparam_type(cls) and v == 'off':
71 | parsed_v = None
72 | else:
73 | parsed_v = v
74 | flat_dict_parsed[flat_k] = parsed_v
75 |
76 | # Expand implicit nested 'on' values. For instance, {'x.y': 'on'} should mean {'x': 'on', 'x.y': 'on'}.
77 | flat_dict_expanded = {}
78 | for flat_k, v in flat_dict_parsed.items():
79 | flat_dict_expanded[flat_k] = v
80 | cls = _type_to_class(typemap[flat_k])
81 | if is_hparam_type(cls) and v is not None:
82 | parts = flat_k.split(separator)
83 | prefix = parts[0]
84 | for i in range(1, len(parts)):
85 | if prefix not in flat_dict_expanded:
86 | flat_dict_expanded[prefix] = _type_to_class(typemap[prefix])()
87 | prefix += separator + parts[i]
88 |
89 | # Set all the values. The sort ensures that outer classes get initialized before their fields.
90 | for flat_k in sorted(flat_dict_expanded.keys()):
91 | v = flat_dict_expanded[flat_k]
92 | *ks, f = flat_k.split(separator)
93 | hp = self
94 | for i, k in enumerate(ks):
95 | try:
96 | hp = getattr(hp, k)
97 | except AttributeError:
98 | raise AttributeError(f"{hp} {'(' + separator.join(ks[:i]) + ') ' if i else ''}has no field '{k}'")
99 | try:
100 | setattr(hp, f, v)
101 | except AttributeError:
102 | raise AttributeError(f"{hp} ({separator.join(ks)}) has no field '{f}'")
103 |
104 | def override_from_nested_dict(self, nested_dict):
105 | for k, v in nested_dict.items():
106 | if isinstance(v, dict):
107 | if getattr(self, k) is None:
108 | cls = _type_to_class(_get_field(self, k).type)
109 | setattr(self, k, cls())
110 | getattr(self, k).override_from_nested_dict(v)
111 | else:
112 | setattr(self, k, v)
113 |
114 | def to_nested_dict(self):
115 | d = {}
116 | for f in fields(self):
117 | fieldval = getattr(self, f.name)
118 | if isinstance(fieldval, HParams):
119 | fieldval = fieldval.to_nested_dict()
120 | d[f.name] = fieldval
121 | return d
122 |
123 | def validate(self, *, prefix=''):
124 | assert is_dataclass(self), f"You forgot to annotate {type(self)} with @dataclass"
125 | for f in fields(self):
126 | fieldval = getattr(self, f.name)
127 | check_type(prefix + f.name, fieldval, f.type)
128 | if isinstance(fieldval, HParams):
129 | fieldval.validate(prefix=prefix + f.name + '.')
130 |
131 |
132 | def is_hparam_type(ty):
133 | if isinstance(ty, type) and issubclass(ty, HParams):
134 | assert is_dataclass(ty)
135 | return True
136 | else:
137 | return False
138 |
139 |
140 | def _is_union_type(ty):
141 | return getattr(ty, '__origin__', None) is typing.Union
142 |
143 |
144 | def dump(hparams, *, name='hparams', out=sys.stdout):
145 | out.write('%s:\n' % name)
146 | def dump_nested(hp, indent):
147 | for f in sorted(fields(hp), key=lambda f: f.name):
148 | v = getattr(hp, f.name)
149 | if isinstance(v, HParams):
150 | out.write('%s%s:\n' % (indent, f.name))
151 | dump_nested(v, indent=indent+' ')
152 | else:
153 | out.write('%s%s: %s\n' % (indent, f.name, v))
154 | dump_nested(hparams, indent=' ')
155 |
156 |
157 | def _can_distinguish_unambiguously(type_set):
158 | """Whether it's always possible to tell which type in type_set a certain value is supposed to be"""
159 | if len(type_set) == 1:
160 | return True
161 | if type(None) in type_set:
162 | return True
163 | if str in type_set:
164 | return False
165 | if int in type_set and float in type_set:
166 | return False
167 | if any(_is_union_type(ty) for ty in type_set):
168 | # Nested unions *might* be unambiguous, but don't support for now
169 | return False
170 | return True
171 |
172 |
173 | def _parse_typed_value(ty, s):
174 | if ty is str:
175 | return s
176 | elif ty in (int, float):
177 | return ty(s)
178 | elif ty is bool:
179 | if s in ('t', 'true', 'True'):
180 | return True
181 | elif s in ('f', 'false', 'False'):
182 | return False
183 | else:
184 | raise ValueError(f"Invalid bool '{s}'")
185 | elif ty is type(None):
186 | if s in ('None', 'none', ''):
187 | return None
188 | else:
189 | raise ValueError(f"Invalid None value '{s}'")
190 | elif is_hparam_type(ty):
191 | if s in ('on', 'off'):
192 | # The class will be constructed later
193 | return s
194 | else:
195 | raise ValueError(f"Invalid hparam class value '{s}'")
196 | elif _is_union_type(ty):
197 | if not _can_distinguish_unambiguously(ty.__args__):
198 | raise TypeError(f"Can't always unambiguously parse a value of union '{ty}'")
199 | for ty_option in ty.__args__:
200 | try:
201 | return _parse_typed_value(ty_option, s)
202 | except ValueError:
203 | continue
204 | raise ValueError(f"Couldn't parse '{s}' as any of the types in '{ty}'")
205 | else:
206 | raise ValueError(f"Unsupported hparam type '{ty}'")
207 |
208 |
209 | def _get_field(data, fieldname):
210 | matching_fields = [f for f in fields(data) if f.name == fieldname]
211 | if len(matching_fields) != 1:
212 | raise AttributeError(f"couldn't find field '{fieldname}' in {data}")
213 | return matching_fields[0]
214 |
215 |
216 | def _update_disjoint(dst: dict, src: dict):
217 | for k, v in src.items():
218 | assert k not in dst
219 | dst[k] = v
220 |
221 |
222 | @lru_cache()
223 | def _type_map(ty, separator):
224 | typemap = {}
225 | for f in fields(ty):
226 | typemap[f.name] = f.type
227 | if is_hparam_type(f.type):
228 | nested = _type_map(f.type, separator=separator)
229 | elif _is_union_type(f.type):
230 | nested = {}
231 | for ty_option in f.type.__args__:
232 | if is_hparam_type(ty_option):
233 | _update_disjoint(nested, _type_map(ty_option, separator=separator))
234 | else:
235 | nested = {}
236 | _update_disjoint(typemap, {f'{f.name}{separator}{k}': t for k, t in nested.items()})
237 | return typemap
238 |
239 |
240 | def _type_to_class(ty):
241 | """Extract a constructible class from a type. For instance, `typing.Optional[int]` gives `int`"""
242 | if _is_union_type(ty):
243 | # Only typing.Optional supported: must be of form typing.Union[ty, None]
244 | assert len(ty.__args__) == 2
245 | assert ty.__args__[1] is type(None)
246 | return ty.__args__[0]
247 | else:
248 | return ty
249 |
250 |
--------------------------------------------------------------------------------
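The tail of hyperparams.py above completes the override machinery: nested dicts recurse into child HParams (constructing them via _type_to_class when the field is still None), validate() type-checks every field with check_type and recurses into nested HParams, and dump() pretty-prints the tree. A minimal sketch of how these pieces compose; OptimizerHParams and TrainHParams are hypothetical classes modeled on the test file further below, while override_from_nested_dict, validate, dump and to_nested_dict are the methods defined above:

    from dataclasses import dataclass, field

    from lm_human_preferences.utils import hyperparams

    @dataclass
    class OptimizerHParams(hyperparams.HParams):
        lr: float = 1e-4
        momentum: float = 0.9

    @dataclass
    class TrainHParams(hyperparams.HParams):
        steps: int = None  # mandatory: validate() raises while this is still None
        optimizer: OptimizerHParams = field(default_factory=OptimizerHParams)

    hp = TrainHParams()
    hp.override_from_nested_dict({'steps': 100, 'optimizer': {'lr': 3e-4}})
    hp.validate()           # type-checks every field, recursing into nested HParams
    hyperparams.dump(hp)    # prints fields sorted by name, nested HParams indented
    assert hp.to_nested_dict() == {'steps': 100, 'optimizer': {'lr': 3e-4, 'momentum': 0.9}}
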
/lm_human_preferences/utils/launch.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import os
3 | import subprocess
4 | from functools import partial
5 |
6 | import cloudpickle
7 | import fire
8 |
9 | def launch(name, f, *, namespace='safety', mode='local', mpi=1) -> None:
10 | if mode == 'local':
11 | with open('/tmp/pickle_fn', 'wb') as file:
12 | cloudpickle.dump(f, file)
13 |
14 | subprocess.check_call(['mpiexec', '-n', str(mpi), 'python', '-c', 'import sys; import pickle; pickle.loads(open("/tmp/pickle_fn", "rb").read())()'])
15 | return
16 | raise Exception('Other modes unimplemented!')
17 |
18 | def parallel(jobs, mode):
19 | if mode == 'local':
20 | assert len(jobs) == 1, "Cannot run jobs in parallel locally"
21 | for job in jobs:
22 | job()
23 | else:
24 | with concurrent.futures.ThreadPoolExecutor() as executor:
25 | futures = [executor.submit(job) for job in jobs]
26 | for f in futures:
27 | f.result()
28 |
29 | def launch_trials(name, fn, trials, hparam_class, extra_hparams=None, dry_run=False, mpi=1, mode='local', save_dir=None):
30 | jobs = []
31 | for trial in trials:
32 | descriptors = []
33 | kwargs = {}
34 | for k, v, s in trial:
35 | if k is not None:
36 | if k in kwargs:
37 | print(f'WARNING: overriding key {k} from {kwargs[k]} to {v}')
38 | kwargs[k] = v
39 | if s.get('descriptor'):
40 | descriptors.append(str(s['descriptor']))
41 | hparams = hparam_class()
42 | hparams.override_from_dict(kwargs)
43 | if extra_hparams:
44 | hparams.override_from_str_dict(extra_hparams)
45 | job_name = (name + '/' + '-'.join(descriptors)).rstrip('/')
46 | hparams.validate()
47 | if dry_run:
48 | print(f"{job_name}: {kwargs}")
49 | else:
50 | if save_dir:
51 | hparams.run.save_dir = os.path.join(save_dir, job_name)
52 | trial_fn = partial(fn, hparams)
53 | jobs.append(partial(launch, job_name, trial_fn, mpi=mpi, mode=mode))
54 |
55 | parallel(jobs, mode=mode)
56 |
57 | def main(commands_dict):
58 | """Similar to fire.Fire, but with support for multiple commands without having a class."""
59 | class _Commands:
60 | def __init__(self):
61 | for name, cmd in commands_dict.items():
62 | setattr(self, name, cmd)
63 | fire.Fire(_Commands)
64 |
--------------------------------------------------------------------------------
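launch.main wraps a plain dict of callables in a throwaway class so that fire.Fire can expose each entry as a subcommand, which is how sample.py below wires itself up. A minimal sketch of a hypothetical driver script (train and evaluate are placeholders; launch.main and its fire.Fire behaviour come from the file above):

    #!/usr/bin/env python3
    from lm_human_preferences.utils import launch

    def train(lr=1e-4):
        print(f'training with lr={lr}')

    def evaluate(checkpoint=''):
        print(f'evaluating checkpoint {checkpoint}')

    if __name__ == '__main__':
        # Each dict entry becomes a subcommand, e.g.
        #   ./driver.py train --lr 3e-4
        #   ./driver.py evaluate --checkpoint /tmp/ckpt
        launch.main(dict(train=train, evaluate=evaluate))
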
/lm_human_preferences/utils/test_core_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """utils tests"""
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 |
7 | from lm_human_preferences.utils import core as utils
8 |
9 |
10 | def test_exact_div():
11 | assert utils.exact_div(12, 4) == 3
12 | assert utils.exact_div(12, 3) == 4
13 | try:
14 | utils.exact_div(7, 3)
15 | assert False
16 | except ValueError:
17 | pass
18 |
19 |
20 | def test_ceil_div():
21 | for b in range(1, 10 + 1):
22 | for a in range(-10, 10 + 1):
23 | assert utils.ceil_div(a, b) == int(np.ceil(a / b))
24 |
25 |
26 | def test_expand_tile():
27 | np.random.seed(7)
28 | size = 11
29 | with tf.Session():
30 | for shape in (), (7,), (3, 5):
31 | data = np.asarray(np.random.randn(*shape), dtype=np.float32)
32 | x = tf.constant(data)
33 | for axis in range(-len(shape) - 1, len(shape) + 1):
34 | y = utils.expand_tile(x, size, axis=axis).eval()
35 | assert np.all(np.expand_dims(data, axis=axis) == y)
36 |
37 |
38 | def test_sample_buffer():
39 | capacity = 100
40 | batch = 17
41 | lots = 100
42 | with tf.Graph().as_default(), tf.Session() as sess:
43 | buffer = utils.SampleBuffer(capacity=capacity, schemas=dict(x=utils.Schema(tf.int32, ())))
44 | tf.variables_initializer(tf.global_variables() + tf.local_variables()).run()
45 | i_p = tf.placeholder(dtype=tf.int32, shape=())
46 | add = buffer.add(x=batch * i_p + tf.range(batch))
47 | sample = buffer.sample(lots, seed=7)['x']
48 | all_data_1 = buffer.data()
49 | all_data_2 = buffer.read(tf.range(buffer.size()))
50 | for i in range(20):
51 | add.run(feed_dict={i_p: i})
52 | samples = sample.eval()
53 | hi = batch * (i + 1)
54 | lo = max(0, hi - capacity)
55 | assert lo <= samples.min() <= lo + 3
56 | assert hi - 5 <= samples.max() < hi
57 | np.testing.assert_equal(sess.run(all_data_1), sess.run(all_data_2))
58 |
59 |
60 | def test_where():
61 | with tf.Session():
62 | assert np.all(utils.where([False, True], 7, 8).eval() == [8, 7])
63 | assert np.all(utils.where([False, True, True], [1, 2, 3], 8).eval() == [8, 2, 3])
64 | assert np.all(utils.where([False, False, True], 8, [1, 2, 3]).eval() == [1, 2, 8])
65 | assert np.all(utils.where([False, True], [[1, 2], [3, 4]], -1).eval() == [[-1, -1], [3, 4]])
66 | assert np.all(utils.where([False, True], -1, [[1, 2], [3, 4]]).eval() == [[1, 2], [-1, -1]])
67 |
68 |
69 | def test_map_flat():
70 | with tf.Session() as sess:
71 | inputs = [2], [3, 5], [[7, 11], [13, 17]]
72 |         inputs = list(map(np.asarray, inputs))  # materialize: the zip below must re-iterate the inputs
73 | outputs = sess.run(utils.map_flat(tf.square, inputs))
74 | for i, o in zip(inputs, outputs):
75 | assert np.all(i * i == o)
76 |
77 |
78 | def test_map_flat_bits():
79 | with tf.Session() as sess:
80 | inputs = [2], [3, 5], [[7, 11], [13, 17]], [True, False, True]
81 | dtypes = np.uint8, np.uint16, np.int32, np.int64, np.bool
82 | inputs = [np.asarray(i, dtype=d) for i, d in zip(inputs, dtypes)]
83 | outputs = sess.run(utils.map_flat_bits(lambda x: x + 1, inputs))
84 |
85 | def tweak(n):
86 | return n + sum(2 ** (8 * i) for i in range(n.dtype.itemsize))
87 |
88 | for i, o in zip(inputs, outputs):
89 | assert np.all(tweak(i) == o)
90 |
91 |
92 | def test_cumulative_max():
93 | np.random.seed(7)
94 | with tf.Session().as_default():
95 | for x in [
96 | np.random.randn(10),
97 | np.random.randn(11, 7),
98 | np.random.randint(-10, 10, size=10),
99 | np.random.randint(-10, 10, size=(12, 8)),
100 | np.random.randint(-10, 10, size=(3, 3, 4)),
101 | ]:
102 | assert np.all(utils.cumulative_max(x).eval() == np.maximum.accumulate(x, axis=-1))
103 |
104 |
105 | def test_index_each():
106 | np.random.seed(7)
107 | x = np.random.randn(7, 11)
108 | i = np.random.randint(x.shape[1], size=x.shape[0])
109 | y = utils.index_each(x, i)
110 |
111 | x2 = np.random.randn(3, 2, 4)
112 | i2 = np.random.randint(x2.shape[1], size=x2.shape[0])
113 | y2 = utils.index_each(x2, i2)
114 |
115 | x3 = np.random.randn(5, 9)
116 | i3 = np.random.randint(x3.shape[1], size=(x3.shape[0], 2))
117 | y3 = utils.index_each(x3, i3)
118 | with tf.Session():
119 | assert np.all(y.eval() == x[np.arange(7), i])
120 | assert np.all(y2.eval() == x2[np.arange(3), i2])
121 | y3val = y3.eval()
122 | assert np.all(y3val[:,0] == x3[np.arange(5), i3[:,0]])
123 | assert np.all(y3val[:,1] == x3[np.arange(5), i3[:,1]])
124 |
125 |
126 | def test_index_each_many():
127 | np.random.seed(7)
128 | x = np.random.randn(7, 11)
129 | i = np.random.randint(x.shape[1], size=[x.shape[0],3])
130 | y = utils.index_each(x, i)
131 | with tf.Session():
132 | assert np.all(y.eval() == x[np.arange(7)[:,None], i])
133 |
134 |
135 | @utils.graph_function(x=utils.Schema(tf.int32, ()), y=utils.Schema(tf.int32, ()))
136 | def tf_sub(x, y=1):
137 | return tf.math.subtract(x, y)
138 |
139 | @utils.graph_function(x=utils.Schema(tf.int32, ()), y=dict(z1=utils.Schema(tf.int32, ()), z2=utils.Schema(tf.int32, ())))
140 | def tf_sub_2(x, y):
141 | return tf.math.subtract(x, y['z1']) - y['z2']
142 |
143 | def test_graph_function():
144 | with tf.Session().as_default():
145 | assert tf_sub(3) == 2
146 | assert tf_sub(x=3) == 2
147 | assert tf_sub(5, 2) == 3
148 | assert tf_sub(y=2, x=5) == 3
149 | assert tf_sub_2(5, dict(z1=1, z2=2)) == 2
150 |
151 | def test_top_k():
152 | with tf.Session().as_default():
153 | logits = tf.constant([[[1,1.01,1.001,0,0,0,2]]], dtype=tf.float32)
154 | np.testing.assert_allclose(
155 | utils.take_top_k_logits(logits, 1).eval(),
156 | [[[-1e10,-1e10,-1e10,-1e10,-1e10,-1e10,2]]]
157 | )
158 | np.testing.assert_allclose(
159 | utils.take_top_k_logits(logits, 2).eval(),
160 | [[[-1e10,1.01,-1e10,-1e10,-1e10,-1e10,2]]]
161 | )
162 | np.testing.assert_allclose(
163 | utils.take_top_k_logits(logits, 3).eval(),
164 | [[[-1e10,1.01,1.001,-1e10,-1e10,-1e10,2]]]
165 | )
166 | np.testing.assert_allclose(
167 | utils.take_top_k_logits(logits, 4).eval(),
168 | [[[1,1.01,1.001,-1e10,-1e10,-1e10,2]]]
169 | )
170 | np.testing.assert_allclose(
171 | utils.take_top_k_logits(logits, 5).eval(),
172 | [[[1,1.01,1.001,0,0,0,2]]]
173 | )
174 |
175 |
176 | def test_top_p():
177 | with tf.Session().as_default():
178 | logits = tf.constant([[[1,1.01,1.001,0,0,0,2]]], dtype=tf.float32)
179 | np.testing.assert_allclose(
180 | utils.take_top_p_logits(logits, 1).eval(),
181 | logits.eval()
182 | )
183 | np.testing.assert_allclose(
184 | utils.take_top_p_logits(logits, 0).eval(),
185 | [[[-1e10,-1e10,-1e10,-1e10,-1e10,-1e10,2]]]
186 | )
187 | np.testing.assert_allclose(
188 | utils.take_top_p_logits(logits, 0.7).eval(),
189 | [[[-1e10,1.01,1.001,-1e10,-1e10,-1e10,2]]]
190 | )
191 | np.testing.assert_allclose(
192 | utils.take_top_p_logits(logits, 0.6).eval(),
193 | [[[-1e10,1.01,-1e10,-1e10,-1e10,-1e10,2]]]
194 | )
195 | np.testing.assert_allclose(
196 | utils.take_top_p_logits(logits, 0.5).eval(),
197 | [[[-1e10,-1e10,-1e10,-1e10,-1e10,-1e10,2]]]
198 | )
199 |
200 | def test_safe_zip():
201 | assert list(utils.safe_zip([1, 2], [3, 4])) == [(1, 3), (2, 4)]
202 | try:
203 | utils.safe_zip([1, 2], [3, 4, 5])
204 | assert False
205 | except ValueError:
206 | pass
207 |
208 |
209 | if __name__ == '__main__':
210 | test_sample_buffer()
211 | test_cumulative_max()
212 | test_where()
213 | test_index_each()
214 | test_graph_function()
215 | test_top_k()
216 | test_top_p()
217 | test_safe_zip()
218 |
--------------------------------------------------------------------------------
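test_top_k and test_top_p above pin down the masking convention for the logit filters: rejected entries are set to -1e10 rather than removed, so shapes are preserved and a softmax puts (numerically) zero mass on them. A small sketch stacking the two filters the way a sampler might; the values are the ones from the tests, and composing the filters in this order is an assumption of mine, not something the repository does here:

    import tensorflow as tf

    from lm_human_preferences.utils import core as utils

    with tf.Session().as_default():
        logits = tf.constant([[[1.0, 1.01, 1.001, 0.0, 0.0, 0.0, 2.0]]])
        # Keep the 4 largest logits, then apply nucleus (top-p) filtering at p=0.7.
        filtered = utils.take_top_p_logits(utils.take_top_k_logits(logits, 4), 0.7)
        print(filtered.eval())  # masked entries show up as -1e10
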
/lm_human_preferences/utils/test_hyperparams.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from dataclasses import dataclass, field
3 | from typing import Optional
4 |
5 | import pytest
6 |
7 | from lm_human_preferences.utils import hyperparams
8 |
9 |
10 | @dataclass
11 | class Simple(hyperparams.HParams):
12 | mandatory_nodefault: int = None
13 | mandatory_withdefault: str = "foo"
14 | optional_nodefault: Optional[int] = None
15 | fun: bool = True
16 |
17 | def test_simple_works():
18 | hp = Simple()
19 | hp.override_from_str("mandatory_nodefault=3,optional_nodefault=None,fun=false")
20 | hp.validate()
21 | assert hp.mandatory_nodefault == 3
22 | assert hp.mandatory_withdefault == "foo"
23 | assert hp.optional_nodefault is None
24 | assert not hp.fun
25 |
26 | def test_simple_failures():
27 | hp = Simple()
28 | with pytest.raises(TypeError):
29 | hp.validate() # mandatory_nodefault unset
30 | with pytest.raises(ValueError):
31 | hp.override_from_str("mandatory_nodefault=abc")
32 | with pytest.raises(AttributeError):
33 | hp.override_from_str("nonexistent_field=7.0")
34 | with pytest.raises(ValueError):
35 | hp.override_from_str("fun=?")
36 |
37 | @dataclass
38 | class Nested(hyperparams.HParams):
39 | first: bool = False
40 | simple_1: Simple = field(default_factory=Simple)
41 | simple_2: Optional[Simple] = None
42 |
43 | def test_nested():
44 | hp = Nested()
45 | hp.override_from_str("simple_1.mandatory_nodefault=8,simple_2=on,simple_2.mandatory_withdefault=HELLO")
46 | with pytest.raises(TypeError):
47 | hp.validate() # simple_2.mandatory_nodefault unset
48 | hp.override_from_dict({'simple_2/mandatory_nodefault': 7, 'simple_1/optional_nodefault': 55}, separator='/')
49 | hp.validate()
50 | assert hp.simple_1.mandatory_nodefault == 8
51 | assert hp.simple_1.mandatory_withdefault == "foo"
52 | assert hp.simple_1.optional_nodefault == 55
53 | assert hp.simple_2.mandatory_nodefault == 7
54 | assert hp.simple_2.mandatory_withdefault == "HELLO"
55 | assert hp.simple_2.optional_nodefault is None
56 |
57 | hp.override_from_str("simple_2=off")
58 | hp.validate()
59 | assert hp.simple_2 is None
60 |
61 | with pytest.raises((TypeError, AttributeError)):
62 | hp.override_from_str("simple_2.fun=True")
63 | with pytest.raises(ValueError):
64 | hp.override_from_str("simple_2=BADVAL")
65 |
66 | def test_nested_dict():
67 | hp = Nested()
68 | hp.override_from_nested_dict(
69 | {'simple_1': {'mandatory_nodefault': 8}, 'simple_2': {'mandatory_withdefault': "HELLO"}})
70 | with pytest.raises(TypeError):
71 | hp.validate() # simple_2.mandatory_nodefault unset
72 | hp.override_from_nested_dict(
73 | {'simple_2': {'mandatory_nodefault': 7}, 'simple_1': {'optional_nodefault': 55}, 'first': True})
74 | hp.validate()
75 | assert hp.to_nested_dict() == {
76 | 'first': True,
77 | 'simple_1': {
78 | 'mandatory_nodefault': 8,
79 | 'mandatory_withdefault': "foo",
80 | 'optional_nodefault': 55,
81 | 'fun': True,
82 | },
83 | 'simple_2': {
84 | 'mandatory_nodefault': 7,
85 | 'mandatory_withdefault': "HELLO",
86 | 'optional_nodefault': None,
87 | 'fun': True,
88 | },
89 | }
90 |
91 | def test_nested_order():
92 | hp = Nested()
93 | # Either order should work
94 | hp.override_from_str_dict(OrderedDict([('simple_2.fun', 'True'), ('simple_2', 'on')]))
95 | hp.override_from_str_dict(OrderedDict([('simple_2', 'on'), ('simple_2.fun', 'True')]))
96 |
97 | @dataclass
98 | class Deeply(hyperparams.HParams):
99 | nested: Nested = None
100 |
101 | def test_deeply_nested():
102 | hp = Deeply()
103 | hp.override_from_str("nested.simple_2=on")
104 | assert hp.nested is not None
105 | assert hp.nested.simple_2 is not None
106 |
107 | hp = Deeply()
108 | hp.override_from_dict({'nested.simple_2': 'on'})
109 | assert hp.nested is not None
110 | assert hp.nested.simple_2 is not None
111 |
112 | def test_set_order():
113 | hp = Deeply()
114 | hp.override_from_dict(OrderedDict([('nested.first', True), ('nested.simple_1', 'on')]))
115 | assert hp.nested.first is True
116 |
117 | hp = Deeply()
118 | hp.override_from_dict(OrderedDict([('nested.simple_1', 'on'), ('nested.first', True)]))
119 | assert hp.nested.first is True
120 |
--------------------------------------------------------------------------------
/sample.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import os
4 | from functools import partial
5 |
6 | from mpi4py import MPI
7 | import tensorflow as tf
8 |
9 | from lm_human_preferences.utils import launch, hyperparams
10 | from lm_human_preferences.utils import core as utils
11 | from lm_human_preferences.policy import Policy
12 | from lm_human_preferences.language import trained_models
13 | from lm_human_preferences import lm_tasks
14 | from lm_human_preferences import train_policy
15 |
16 | def sample_policy(save_dir=None, savescope='policy', temperature=1.0, seed=None, batch_size=4, nsamples=0):
17 | hparams = train_policy.HParams()
18 | hparams.override_from_json_file(os.path.join(save_dir, 'train_policy_hparams.json'))
19 | print('hparams', hparams)
20 | task = hparams.task
21 |
22 | comm = MPI.COMM_WORLD
23 | nsamples_per_rank = utils.exact_div(nsamples, comm.Get_size())
24 | with tf.Graph().as_default():
25 | m = trained_models.TrainedModel(name='sample', savedir=os.path.join(save_dir, 'policy'), scope='policy')
26 | encoder = m.encoding.get_encoder()
27 | hyperparams.dump(m.hparams(), name='model_hparams')
28 |
29 | utils.set_mpi_seed(seed)
30 |
31 | policy = Policy(
32 | m, scope='policy',
33 | is_root=True, # just init on every rank, simplifies code
34 | embed_queries=lm_tasks.query_formatter(task, encoder),
35 | temperature=temperature,
36 | )
37 |
38 | query_sampler = lm_tasks.make_query_sampler(
39 | hparams=task, encoder=encoder, comm=comm,
40 | batch_size=batch_size, mode='test'
41 | )
42 |
43 | init_ops = tf.group(
44 | tf.global_variables_initializer(),
45 | tf.local_variables_initializer(),
46 | )
47 |
48 | with utils.mpi_session() as sess:
49 | init_ops.run()
50 | @utils.graph_function()
51 | def sample_queries():
52 | return query_sampler()['tokens']
53 |
54 | tf.get_default_graph().finalize()
55 |
56 | generated = 0
57 | while nsamples_per_rank == 0 or generated < nsamples_per_rank:
58 | queries = sample_queries()
59 | rollouts = policy.respond(queries, length=task.response_length)
60 | assert len(queries.tolist()) == batch_size
61 | assert len(rollouts['responses'].tolist()) == batch_size
62 | for q, r in zip(queries.tolist(), rollouts['responses'].tolist()):
63 | print('=' * 80)
64 | print(encoder.decode(q).replace("\n", "⏎"))
65 | print(encoder.decode(r).replace("\n", "⏎"))
66 | generated += batch_size
67 |
68 | def launch_sample(mode='local', mpi=8, **kwargs):
69 | launch.launch('sample', partial(sample_policy, **kwargs), mode=mode, mpi=mpi)
70 |
71 | if __name__ == '__main__':
72 | launch.main(dict(
73 | sample=launch_sample,
74 | ))
75 |
76 | """
77 | ./sample.py sample --save_dir gs://jeffwu-rcall/results/safety/lmhf-sent-69c5170-1909161359/ --mpi 8
78 | """
79 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import pkg_resources
4 |
5 | from setuptools import setup, find_packages
6 |
7 | os.environ['CC'] = 'g++'
8 |
9 | setup(name='lm_human_preferences',
10 | version='0.0.1',
11 | packages=find_packages(include=['lm_human_preferences']),
12 | include_package_data=True,
13 | )
14 |
--------------------------------------------------------------------------------