├── .gitattributes
├── .gitignore
├── CONTRIBUTORS.md
├── DEVELOPERS.md
├── Dockerfile.cpu
├── Dockerfile.gpu
├── LICENSE
├── README.md
├── domains.txt
├── download_model.py
├── encode.py
├── model_card.md
├── requirements.txt
├── src
│   ├── accumulate.py
│   ├── encoder.py
│   ├── generate_unconditional_samples.py
│   ├── interactive_conditional_samples.py
│   ├── load_dataset.py
│   ├── memory_saving_gradients.py
│   ├── model.py
│   ├── sample.py
│   ├── tfremat.py
│   └── twremat.py
├── train.py
└── twremat
    ├── README.md
    ├── main
    │   ├── remat.hs
    │   └── test.hs
    ├── src
    │   ├── Balanced.hs
    │   ├── Dense.hs
    │   ├── Filter.hs
    │   ├── Graph.hs
    │   ├── TWRemat.hs
    │   ├── TreeWidth.hs
    │   ├── Tupfile
    │   └── Util.hs
    ├── test
    │   ├── TestBalanced.hs
    │   ├── TestGraph.hs
    │   ├── TestTreeWidth.hs
    │   └── Tupfile
    └── twremat.cabal
/.gitattributes:
--------------------------------------------------------------------------------
1 | # convert to OS line endings on checkout, back to LF on commit
2 | * text=auto
3 | 
4 | # ensure anything copied to the container has unix style line endings
5 | *.sh text eol=lf
6 | requirements.txt text eol=lf
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .mypy_cache/
3 | models/
4 | checkpoint
5 | samples
6 | dist-newstyle
7 | bin
--------------------------------------------------------------------------------
/CONTRIBUTORS.md:
--------------------------------------------------------------------------------
1 | # Contributors (alphabetically)
2 | 
3 | * **[madisonmay](https://github.com/madisonmay)**
4 | 
5 |   Added Dockerfiles
6 | 
7 | * **[Margaret Mitchell et al.](https://arxiv.org/abs/1810.03993)**
8 | 
9 |   Our [usage](./README.md#usage) writeup was loosely inspired by the paper
10 |   [Model Cards for Model Reporting](https://arxiv.org/abs/1810.03993)
11 |   and related conversations with some of the authors.
12 | 
13 | * **[webproduktion01](https://github.com/webproduktion01)**
14 | 
15 |   Ported the download script to Python.
16 | 
17 | **[Full code contributors list](https://github.com/openai/gpt-2/contributors).**
--------------------------------------------------------------------------------
/DEVELOPERS.md:
--------------------------------------------------------------------------------
1 | # Installation
2 | 
3 | Clone this repository, and `cd` into its directory for the remaining commands:
4 | ```
5 | git clone https://github.com/openai/gpt-2.git && cd gpt-2
6 | ```
7 | 
8 | Then, follow the instructions for either native or Docker installation.
9 | 
10 | ## Native Installation
11 | 
12 | All steps can optionally be done in a virtual environment using tools such as `virtualenv` or `conda`.
13 | 
14 | Install TensorFlow 1.12 (with GPU support, if you have a GPU and want everything to run faster):
15 | ```
16 | pip3 install tensorflow==1.12.0
17 | ```
18 | or
19 | ```
20 | pip3 install tensorflow-gpu==1.12.0
21 | ```
22 | 
23 | Install the other Python packages:
24 | ```
25 | pip3 install -r requirements.txt
26 | ```
27 | 
28 | Download the model data:
29 | ```
30 | python3 download_model.py 124M
31 | python3 download_model.py 355M
32 | python3 download_model.py 774M
33 | python3 download_model.py 1558M
34 | ```
35 | 
36 | ## Docker Installation
37 | 
38 | Build the Dockerfile and tag the created image as `gpt-2`:
39 | ```
40 | docker build --tag gpt-2 -f Dockerfile.gpu . # or Dockerfile.cpu
41 | ```
42 | 
43 | Start an interactive bash session from the `gpt-2` docker image.
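For example, the following starts a shell in the image built above (this plain invocation is sufficient for the CPU image):
```
docker run -it gpt-2 bash
```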
44 | 
45 | You can opt to use the `--runtime=nvidia` flag if you have access to an NVIDIA GPU
46 | and a valid install of [nvidia-docker 2.0](https://github.com/nvidia/nvidia-docker/wiki/Installation-(version-2.0)).
47 | ```
48 | docker run --runtime=nvidia -it gpt-2 bash
49 | ```
50 | 
51 | # Running
52 | 
53 | | WARNING: Samples are unfiltered and may contain offensive content. |
54 | | --- |
55 | 
56 | Some of the examples below may include Unicode text characters. Set the environment variable:
57 | ```
58 | export PYTHONIOENCODING=UTF-8
59 | ```
60 | so that Python's standard streams use UTF-8 encoding.
61 | 
62 | ## Unconditional sample generation
63 | 
64 | To generate unconditional samples from the small model:
65 | ```
66 | python3 src/generate_unconditional_samples.py | tee /tmp/samples
67 | ```
68 | There are various flags for controlling the samples:
69 | ```
70 | python3 src/generate_unconditional_samples.py --top_k 40 --temperature 0.7 | tee /tmp/samples
71 | ```
72 | 
73 | To check flag descriptions, use:
74 | ```
75 | python3 src/generate_unconditional_samples.py -- --help
76 | ```
77 | 
78 | ## Conditional sample generation
79 | 
80 | To give the model custom prompts, you can use:
81 | ```
82 | python3 src/interactive_conditional_samples.py --top_k 40
83 | ```
84 | 
85 | To check flag descriptions, use:
86 | ```
87 | python3 src/interactive_conditional_samples.py -- --help
88 | ```
89 | 
--------------------------------------------------------------------------------
/Dockerfile.cpu:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.12.0-py3
2 | 
3 | ENV LANG=C.UTF-8
4 | RUN mkdir /gpt-2
5 | WORKDIR /gpt-2
6 | ADD . /gpt-2
7 | RUN pip3 install -r requirements.txt
8 | RUN python3 download_model.py 124M
9 | RUN python3 download_model.py 355M
10 | RUN python3 download_model.py 774M
11 | RUN python3 download_model.py 1558M
--------------------------------------------------------------------------------
/Dockerfile.gpu:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.12.0-gpu-py3
2 | 
3 | # nvidia-docker 1.0
4 | LABEL com.nvidia.volumes.needed="nvidia_driver"
5 | LABEL com.nvidia.cuda.version="${CUDA_VERSION}"
6 | 
7 | # nvidia-container-runtime
8 | ENV NVIDIA_VISIBLE_DEVICES=all \
9 |     NVIDIA_DRIVER_CAPABILITIES=compute,utility \
10 |     NVIDIA_REQUIRE_CUDA="cuda>=8.0" \
11 |     LANG=C.UTF-8
12 | 
13 | RUN mkdir /gpt-2
14 | WORKDIR /gpt-2
15 | ADD . /gpt-2
16 | RUN pip3 install -r requirements.txt
17 | RUN python3 download_model.py 124M
18 | RUN python3 download_model.py 355M
19 | RUN python3 download_model.py 774M
20 | RUN python3 download_model.py 1558M
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Modified MIT License
2 | 
3 | Software Copyright (c) 2019 OpenAI
4 | 
5 | We don’t claim ownership of the content you create with GPT-2, so it is yours to do with as you please.
6 | We only ask that you use GPT-2 responsibly and clearly indicate your content was created using GPT-2.
7 | 
8 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
9 | associated documentation files (the "Software"), to deal in the Software without restriction,
10 | including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
11 | and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so,
12 | subject to the following conditions:
13 | 
14 | The above copyright notice and this permission notice shall be included
15 | in all copies or substantial portions of the Software.
16 | The above copyright notice and this permission notice need not be included
17 | with content created by the Software.
18 | 
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
20 | INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
22 | BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
24 | OR OTHER DEALINGS IN THE SOFTWARE.
25 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Fine tuning on custom datasets
2 | 
3 | Reference: ["Beginner’s Guide to Retrain GPT-2 (117M) to Generate Custom Text Content"](https://medium.com/@ngwaifoong92/beginners-guide-to-retrain-gpt-2-117m-to-generate-custom-text-content-8bb5363d8b7f)
4 | 
5 | To retrain the GPT-2 117M model on a custom text dataset:
6 | 
7 | ```
8 | PYTHONPATH=src ./train.py --dataset <file|directory|glob>
9 | ```
10 | 
11 | If you want to precompute the dataset's encoding for multiple runs, you can instead use:
12 | 
13 | ```
14 | PYTHONPATH=src ./encode.py <file|directory|glob> /path/to/encoded.npz
15 | PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz
16 | ```
17 | 
18 | Make sure `cudnn` is installed. [Some have
19 | reported](https://github.com/nshepperd/gpt-2/issues/8) that `train.py`
20 | runs without it but has worse memory usage and might OOM.
21 | 
22 | ### Tensor Rematerialization
23 | 
24 | Experimental: a rematerialization rewriter based on the paper
25 | "Efficient Rematerialization for Deep Networks",
26 | which unlike gradient checkpointing works in TensorFlow 2.0 and is
27 | able to automatically select checkpoints in arbitrary graphs. Using
28 | this I was able to finetune GPT-2 1.5B on a single graphics card
29 | using slightly less than 12 GB of video RAM with very little
30 | slowdown.
31 | 
32 | Using this is a little involved, because the graph optimization
33 | algorithm is offloaded to an optimized Haskell program. First, go into
34 | the `twremat` subdirectory, and build it by invoking:
35 | 
36 |     cabal v2-install --installdir=../bin
37 | 
38 | (You'll need to install cabal if you haven't already -- but setting up
39 | GHC and Haskell compilation is beyond the scope of this README.)
40 | 
41 | Then run `train.py` as normal, enabling `--twremat` and setting
42 | `--twremat_memlimit` to an appropriate value -- this sets the amount
43 | of memory assumed to be available for computation of gradients, so it
44 | should be roughly the memory size of your graphics card minus whatever
45 | is taken up by the gpt-2 weights, and any other bookkeeping
46 | variables. You may need to experiment with the memlimit until you find
47 | the largest value that doesn't OOM.
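As a sketch, a rematerialization-enabled finetuning run might look like the following (the dataset path and the memlimit value are illustrative only; check `./train.py --help` for the exact format `--twremat_memlimit` accepts):

```
PYTHONPATH=src ./train.py --dataset /path/to/encoded.npz \
    --optimizer sgd --twremat --twremat_memlimit 6000000000
```

Here 6000000000 bytes is a rough guess for a 12 GB card: the fp32 weights of the 1.5B model occupy about 6 GB on their own, leaving roughly the other half of the card for gradient computation.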
48 | 
49 | (You probably also want to use SGD as the optimizer instead of Adam,
50 | to minimize those bookkeeping variables, of which Adam uses a lot.)
51 | 
52 | ### Gradient Checkpointing
53 | 
54 | https://github.com/openai/gradient-checkpointing is included to reduce
55 | the memory requirements of the model, and can be enabled by
56 | `--memory_saving_gradients`. The checkpoints are currently chosen
57 | manually (poorly) by just adding layer 10 to the 'checkpoints'
58 | collection in model.py.
59 | 
60 | Gradient checkpointing doesn't work in TensorFlow 2.0 and later due
61 | to the removal of tf.contrib. You should use tensor rematerialization
62 | instead if possible.
63 | 
64 | ### Validation loss
65 | 
66 | Set `--val_every` to a number of steps `N > 0`, and "validation" loss
67 | against a fixed sample of the dataset will be calculated every N steps
68 | to get a better sense of training progress. A value of around 200 is
69 | suggested. You can set `--val_dataset` to choose a separate validation
70 | dataset; otherwise it defaults to a sample from the train dataset (so
71 | not a real cross-validation loss!).
72 | 
73 | ### Optimizer
74 | 
75 | You can use SGD instead of Adam with `--optimizer sgd`. This also
76 | helps conserve memory when training larger models. Note: the learning
77 | rate needs to be adjusted for SGD, since it lacks Adam's gradient
78 | normalization (0.0006 seems to be a good number from some
79 | experiments).
80 | 
81 | # Original README
82 | 
83 | **Status:** Archive (code is provided as-is, no updates expected)
84 | 
85 | # gpt-2
86 | 
87 | Code and models from the paper ["Language Models are Unsupervised Multitask Learners"](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf).
88 | 
89 | You can read about GPT-2 and its staged release in our [original blog post](https://blog.openai.com/better-language-models/), [6 month follow-up post](https://openai.com/blog/gpt-2-6-month-follow-up/), and [final post](https://www.openai.com/blog/gpt-2-1-5b-release/).
90 | 
91 | We have also [released a dataset](https://github.com/openai/gpt-2-output-dataset) for researchers to study the models' behaviors.
92 | 
93 | * *Note that our original parameter counts were wrong due to an error (in our previous blog posts and paper). Thus you may have seen small referred to as 117M and medium referred to as 345M.*
94 | 
95 | ## Usage
96 | 
97 | This repository is meant to be a starting point for researchers and engineers to experiment with GPT-2.
98 | 
99 | For basic information, see our [model card](./model_card.md).
100 | 
101 | ### Some caveats
102 | 
103 | - GPT-2 models' robustness and worst-case behaviors are not well understood. As with any machine-learned model, carefully evaluate GPT-2 for your use case, especially if used without fine-tuning or in safety-critical applications where reliability is important.
104 | - The dataset our GPT-2 models were trained on contains many texts with [biases](https://twitter.com/TomerUllman/status/1101485289720242177) and factual inaccuracies, and thus GPT-2 models are likely to be biased and inaccurate as well.
105 | - To avoid having samples mistaken as human-written, we recommend clearly labeling samples as synthetic before wide dissemination. Our models are often incoherent or inaccurate in subtle ways, which takes more than a quick read for a human to notice.
106 | 
107 | ### Work with us
108 | 
109 | Please [let us know](mailto:languagequestions@openai.com) if you’re doing interesting research with or working on applications of GPT-2!
We’re especially interested in hearing from and potentially working with those who are studying 110 | - Potential malicious use cases and defenses against them (e.g. the detectability of synthetic text) 111 | - The extent of problematic content (e.g. bias) being baked into the models and effective mitigations 112 | 113 | ## Development 114 | 115 | See [DEVELOPERS.md](./DEVELOPERS.md) 116 | 117 | ## Contributors 118 | 119 | See [CONTRIBUTORS.md](./CONTRIBUTORS.md) 120 | 121 | ## Citation 122 | 123 | Please use the following bibtex entry: 124 | ``` 125 | @article{radford2019language, 126 | title={Language Models are Unsupervised Multitask Learners}, 127 | author={Radford, Alec and Wu, Jeff and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya}, 128 | year={2019} 129 | } 130 | ``` 131 | 132 | ## Future work 133 | 134 | We may release code for evaluating the models on various benchmarks. 135 | 136 | We are still considering release of the larger models. 137 | 138 | ## License 139 | 140 | [Modified MIT](./LICENSE) 141 | -------------------------------------------------------------------------------- /domains.txt: -------------------------------------------------------------------------------- 1 | 1542261 google 2 | 596207 archive 3 | 456344 blogspot 4 | 414695 github 5 | 333160 nytimes 6 | 321622 wordpress 7 | 315368 washingtonpost 8 | 313137 wikia 9 | 311917 bbc 10 | 246303 theguardian 11 | 210714 ebay 12 | 209416 pastebin 13 | 199360 cnn 14 | 196124 yahoo 15 | 186668 huffingtonpost 16 | 186137 go 17 | 183592 reuters 18 | 183080 imdb 19 | 160553 goo 20 | 139965 nih 21 | 135562 cbc 22 | 128011 apple 23 | 125615 medium 24 | 118676 dailymail 25 | 108012 steampowered 26 | 106417 independent 27 | 105239 etsy 28 | 98941 craigslist 29 | 93048 businessinsider 30 | 92712 telegraph 31 | 90262 wizards 32 | 83266 usatoday 33 | 80384 thehill 34 | 79655 nhl 35 | 79494 foxnews 36 | 79167 taobao 37 | 78070 bloomberg 38 | 77515 npr 39 | 77407 mlb 40 | 77172 latimes 41 | 75676 megalodon 42 | 72525 espn 43 | 72523 kickstarter 44 | 71743 breitbart 45 | 69334 abc 46 | 68009 newegg 47 | 67008 wwe 48 | 66278 myanimelist 49 | 65520 microsoft 50 | 64723 buzzfeed 51 | 63162 vice 52 | 62911 indiatimes 53 | 61845 forbes 54 | 61772 tappedout 55 | 60889 wsj 56 | 60240 vid 57 | 60239 battle 58 | 59996 adf 59 | 58706 politico 60 | 58345 redditgifts 61 | 56769 nexusmods 62 | 56469 goodreads 63 | 54866 magiccards 64 | 53973 nbcnews 65 | 53060 gamepedia 66 | 52110 mediafire 67 | 50567 time 68 | 50144 cbsnews 69 | 49203 ppy 70 | 48442 gstatic 71 | 48042 nfl 72 | 47460 steamusercontent 73 | 47046 thestar 74 | 46603 bugguide 75 | 46340 fanfiction 76 | 45505 mturk 77 | 45458 cbslocal 78 | 44729 theglobeandmail 79 | 44134 nydailynews 80 | 42992 theatlantic 81 | 42941 netflix 82 | 42328 theverge 83 | 41952 smh 84 | 40694 nbcsports 85 | 40613 cnbc 86 | 40469 slate 87 | 40071 ign 88 | 39655 dotabuff 89 | 38968 wired 90 | 38779 chicagotribune 91 | 38590 urbandictionary 92 | 38575 rt 93 | 38092 wuxiaworld 94 | 38065 wowhead 95 | 37954 wolframalpha 96 | 37749 guardian 97 | 37594 xboxdvr 98 | 36841 nypost 99 | 36741 ravelry 100 | 36321 thedailybeast 101 | 36298 nba 102 | 36188 yelp 103 | 36008 arstechnica 104 | 35485 csgo 105 | 35365 flic 106 | 35269 stackexchange 107 | 35124 vidble 108 | 35024 googleusercontent 109 | 34311 msn 110 | 34121 gizmodo 111 | 34120 boardgamegeek 112 | 33867 aljazeera 113 | 33598 rawstory 114 | 33516 scryfall 115 | 33467 bleacherreport 116 | 33419 bit 117 | 33395 thinkprogress 118 | 33170 
dailycaller 119 | 32843 ap 120 | 32433 fangraphs 121 | 31742 salon 122 | 31728 mirror 123 | 31496 nintendo 124 | 31294 nationalpost 125 | 31278 nasa 126 | 31110 oddshot 127 | 31057 hltv 128 | 30952 amzn 129 | 30877 quora 130 | 30586 engadget 131 | 30397 stackoverflow 132 | 30201 aliexpress 133 | 29710 cnet 134 | 28850 leagueoflegends 135 | 28822 surveymonkey 136 | 28704 ctvnews 137 | 28650 walmart 138 | 28644 plays 139 | 28536 sfgate 140 | 28375 cbssports 141 | 28210 globo 142 | 27992 discogs 143 | 27630 wiktionary 144 | 27588 ibb 145 | 27544 stuff 146 | 27349 nature 147 | 27112 news 148 | 27020 biblegateway 149 | 26801 subtletv 150 | 26427 change 151 | 26355 zippyshare 152 | 26311 guildwars2 153 | 26231 vox 154 | 26205 zkillboard 155 | 26174 techcrunch 156 | 25993 economist 157 | 25964 globalnews 158 | 25621 washingtontimes 159 | 25610 hollywoodreporter 160 | 25351 archiveofourown 161 | 25336 ibtimes 162 | 25257 newsweek 163 | 25139 zerohedge 164 | 25074 fav 165 | 25050 sciencedirect 166 | 24894 bestbuy 167 | 24870 spiegel 168 | 24869 247sports 169 | 24866 smmry 170 | 24764 xda-developers 171 | 24726 tvtropes 172 | 24698 phys 173 | 24663 teamliquid 174 | 24619 state 175 | 23953 gleam 176 | 23676 sbnation 177 | 23644 asahi 178 | 23620 foxsports 179 | 23240 ndtv 180 | 23189 si 181 | 23183 alternet 182 | 23009 redbubble 183 | 22846 metro 184 | 22845 theonion 185 | 22835 playstation 186 | 22808 washingtonexaminer 187 | 22682 thehindu 188 | 22557 espncricinfo 189 | 22482 mozilla 190 | 22219 op 191 | 22038 t 192 | 21984 nj 193 | 21921 indianexpress 194 | 21707 apnews 195 | 21603 dw 196 | 21422 nationalgeographic 197 | 21399 pinterest 198 | 21368 ft 199 | 21319 wiley 200 | 21254 about 201 | 21074 skysports 202 | 21033 gamespot 203 | 21014 dailykos 204 | 21009 goal 205 | 20858 patheos 206 | 20842 irishtimes 207 | 20664 variety 208 | 20592 kotaku 209 | 20584 mashable 210 | 20575 scientificamerican 211 | 20448 basketball-reference 212 | 20262 yle 213 | 20218 theage 214 | 20176 usnews 215 | 20133 animenewsnetwork 216 | 20092 livejournal 217 | 20068 218 | 20024 pbs 219 | 19802 nhk 220 | 19741 newyorker 221 | 19727 seattletimes 222 | 19672 mlssoccer 223 | 19619 meetup 224 | 19543 nzherald 225 | 19509 philly 226 | 19496 uol 227 | 19470 patreon 228 | 19429 wikileaks 229 | 19400 gravitytales 230 | 19294 oregonlive 231 | 19267 xbox 232 | 19216 linkedin 233 | 19202 crunchyroll 234 | 19045 target 235 | 19021 ew 236 | 18922 redditpoll 237 | 18875 homedepot 238 | 18867 qz 239 | 18865 donmai 240 | 18653 baseball-reference 241 | 18646 talkingpointsmemo 242 | 18576 pathofexile 243 | 18536 makeameme 244 | 18489 postimg 245 | 18308 clyp 246 | 18175 scribd 247 | 18120 thegatewaypundit 248 | 18097 removeddit 249 | 18063 deadspin 250 | 18049 sciencedaily 251 | 18019 huffpost 252 | 17987 dallasnews 253 | 17956 europa 254 | 17878 merriam-webster 255 | 17816 haaretz 256 | 17746 deadline 257 | 17637 msnbc 258 | 17579 hindustantimes 259 | 17531 nymag 260 | 17429 gph 261 | 17208 typepad 262 | 17204 express 263 | 17098 naver 264 | 17085 bizjournals 265 | 17084 mlive 266 | 16834 rollingstone 267 | 16793 motherjones 268 | 16704 okcupid 269 | 16441 tinyurl 270 | 16410 espnfc 271 | 16397 bostonglobe 272 | 16374 thingiverse 273 | 16351 denverpost 274 | 16332 bitcointalk 275 | 16256 timesofisrael 276 | 16209 xnxx 277 | 16202 wikihow 278 | 16051 neopets 279 | 16043 indiegogo 280 | 16033 al 281 | 16032 chron 282 | 16004 avclub 283 | 15970 marketwatch 284 | 15933 mercurynews 285 | 15675 startribune 286 | 15646 pro-football-reference 
287 | 15568 d20pfsrd 288 | 15545 pcgamer 289 | 15451 reason 290 | 15422 uesp 291 | 15356 lds 292 | 15152 polygon 293 | 15132 humblebundle 294 | 14962 tradingview 295 | 14931 baltimoresun 296 | 14914 strava 297 | 14912 firstpost 298 | 14856 commondreams 299 | 14801 sky 300 | 14739 eventbrite 301 | 14722 nicovideo 302 | 14697 fortune 303 | 14693 knowyourmeme 304 | 14666 robertsspaceindustries 305 | 14471 pitchfork 306 | 14466 psychologytoday 307 | 14435 combodeck 308 | 14392 mixcloud 309 | 14372 lemonde 310 | 14290 sciencemag 311 | 14060 jpost 312 | 13926 miamiherald 313 | 13902 patch 314 | 13850 nationalreview 315 | 13849 gofundme 316 | 13798 thelocal 317 | 13763 derpibooru 318 | 13726 techdirt 319 | 13658 townhall 320 | 13596 mtg 321 | 13588 gettyimages 322 | 13530 mit 323 | 13436 challonge 324 | 13369 mediaite 325 | 13357 tsn 326 | 13350 pokemonshowdown 327 | 13176 neogaf 328 | 13130 publico 329 | 13126 snopes 330 | 13092 scmp 331 | 13082 cleveland 332 | 13044 thesun 333 | 13025 mtggoldfish 334 | 12994 freep 335 | 12984 grailed 336 | 12948 standard 337 | 12923 theconversation 338 | 12913 upi 339 | 12870 bing 340 | 12778 blockchain 341 | 12774 people 342 | 12771 arxiv 343 | 12760 hearthpwn 344 | 12668 reference 345 | 12626 edhrec 346 | 12611 sputniknews 347 | 12551 nordstrom 348 | 12550 lapresse 349 | 12496 metacritic 350 | 12447 last 351 | 12395 ajc 352 | 12355 mangadex 353 | 12349 ycombinator 354 | 12345 csmonitor 355 | 12240 sportsnet 356 | 12229 cornell 357 | 12205 smithsonianmag 358 | 12201 sephora 359 | 12194 bulbagarden 360 | 12181 japantimes 361 | 12171 zdnet 362 | 12152 comicbook 363 | 12139 whitehouse 364 | 12109 theregister 365 | 12089 libsyn 366 | 12052 asos 367 | 12016 neatclip 368 | 12001 imirhil 369 | 12000 boston 370 | 11973 behance 371 | 11966 eveonline 372 | 11954 androidpolice 373 | 11935 livescience 374 | 11843 instructables 375 | 11817 hs 376 | 11788 infowars 377 | 11712 ca 378 | 11704 runescape 379 | 11699 suntimes 380 | 11697 eurogamer 381 | 11654 roblox 382 | 11622 genius 383 | 11602 stltoday 384 | 11499 elpais 385 | 11494 motorsport 386 | 11461 ceddit 387 | 11426 france24 388 | 11373 bungie 389 | 11371 youtubedoubler 390 | 11362 openload 391 | 11348 jstor 392 | 11328 thefreedictionary 393 | 11307 inquisitr 394 | 11215 nhentai 395 | 11204 zeit 396 | 11198 ikea 397 | 11114 springer 398 | 11108 tripadvisor 399 | 11082 thescore 400 | 11036 kerbalspaceprogram 401 | 11007 cdc 402 | 10995 dailywire 403 | 10965 gawker 404 | 10953 a 405 | 10950 brooksbaseball 406 | 10940 dn 407 | 10927 sltrib 408 | 10867 brickset 409 | 10823 dictionary 410 | 10821 squarespace 411 | 10819 battlefield 412 | 10807 harvard 413 | 10786 afpbb 414 | 10734 steemit 415 | 10730 billboard 416 | 10707 tampabay 417 | 10654 nola 418 | 10621 stanford 419 | 10602 sbs 420 | 10524 cc 421 | 10520 dailydot 422 | 10510 straitstimes 423 | 10493 itch 424 | 10490 foreignpolicy 425 | 10465 vancouversun 426 | 10440 rottentomatoes 427 | 10419 dnainfo 428 | 10389 digi24 429 | 10348 dropboxusercontent 430 | 10332 complex 431 | 10330 scp-wiki 432 | 10327 prnt 433 | 10313 ottawacitizen 434 | 10304 anandtech 435 | 10269 thenation 436 | 10253 fivethirtyeight 437 | 10244 newscientist 438 | 10240 svt 439 | 10240 inquirer 440 | 10236 coindesk 441 | 10227 codepen 442 | 10208 lichess 443 | 10204 sankei 444 | 10189 ted 445 | 10181 roosterteeth 446 | 10170 livemint 447 | 10161 teamfortress 448 | 10141 sourceforge 449 | 10119 sapo 450 | 10113 countle 451 | 10086 mtv 452 | 10075 sacbee 453 | 10066 fimfiction 454 | 10057 
hentai-foundry 455 | 10054 gamesplanet 456 | 10044 io9 457 | 10032 lifehacker 458 | 10007 cracked 459 | 9991 mainichi 460 | 9984 itmedia 461 | 9966 warthunder 462 | 9936 nos 463 | 9935 boingboing 464 | 9925 vulture 465 | 9904 lanacion 466 | 9892 qualtrics 467 | 9884 muthead 468 | 9856 jcrew 469 | 9814 jsonline 470 | 9787 spacebattles 471 | 9748 worldstarhiphop 472 | 9734 jalopnik 473 | 9721 welt 474 | 9717 curbed 475 | 9708 dbr 476 | 9705 mmafighting 477 | 9697 bigcartel 478 | 9682 transfermarkt 479 | 9680 vlive 480 | 9659 vanityfair 481 | 9658 dawn 482 | 9621 dnaindia 483 | 9601 theblaze 484 | 9599 allrecipes 485 | 9576 thejournal 486 | 9572 dailystar 487 | 9521 minecraftforum 488 | 9505 theweek 489 | 9502 kansascity 490 | 9494 anilist 491 | 9443 gog 492 | 9420 bato 493 | 9401 oxforddictionaries 494 | 9400 soompi 495 | 9394 sagepub 496 | 9389 wikiwand 497 | 9382 lolking 498 | 9322 torontosun 499 | 9319 mangapanda 500 | 9316 politifact 501 | 9306 realclearpolitics 502 | 9278 tagpro 503 | 9261 webmd 504 | 9206 app 505 | 9202 hotnews 506 | 9184 9news 507 | 9174 bhphotovideo 508 | 9147 giantbomb 509 | 9132 gamestop 510 | 9073 azcentral 511 | 9053 noaa 512 | 9040 repubblica 513 | 9021 mangaupdates 514 | 8998 space 515 | 8998 researchgate 516 | 8971 bitcoin 517 | 8957 sueddeutsche 518 | 8898 rightwingwatch 519 | 8892 mediacru 520 | 8890 afl 521 | 8862 fasttech 522 | 8858 tmz 523 | 8841 orlandosentinel 524 | 8832 tomshardware 525 | 8828 altomfotball 526 | 8822 mtgprice 527 | 8821 haskell 528 | 8816 discovery 529 | 8810 destinytracker 530 | 8808 massdrop 531 | 8800 csgolounge 532 | 8791 weather 533 | 8778 daddyleagues 534 | 8720 govtrack 535 | 8678 mentalfloss 536 | 8678 justice 537 | 8663 frontier 538 | 8655 youporn 539 | 8641 paradoxplaza 540 | 8640 rockstargames 541 | 8632 derstandard 542 | 8622 pinknews 543 | 8619 macrumors 544 | 8598 gamefaqs 545 | 8587 thepiratebay 546 | 8586 4chan 547 | 8582 post-gazette 548 | 8573 faz 549 | 8563 e-hentai 550 | 8530 jiji 551 | 8525 quoracdn 552 | 8519 fullmatchesandshows 553 | 8516 sun-sentinel 554 | 8513 xboxclips 555 | 8488 financialpost 556 | 8476 audible 557 | 8439 investopedia 558 | 8425 loc 559 | 8418 venturebeat 560 | 8414 amazonaws 561 | 8368 ubi 562 | 8345 etymonline 563 | 8326 wsws 564 | 8316 jezebel 565 | 8300 americanthinker 566 | 8284 wikidot 567 | 8269 digitaltrends 568 | 8260 nrk 569 | 8232 weebly 570 | 8228 thenextweb 571 | 8225 snahp 572 | 8223 gematsu 573 | 8210 daum 574 | 8206 ea 575 | 8189 liverpoolecho 576 | 8186 freebeacon 577 | 8178 thetimes 578 | 8168 naturalcrit 579 | 8153 warframe 580 | 8150 1drv 581 | 8143 gap 582 | 8131 seriouseats 583 | 8119 myfigurecollection 584 | 8109 gov 585 | 8086 eporner 586 | 8080 hulu 587 | 8077 senate 588 | 8046 esquire 589 | 8015 gosugamers 590 | 8000 radionz 591 | 7997 eater 592 | 7982 politicususa 593 | 7978 rte 594 | 7956 marvel 595 | 7942 metronews 596 | 7917 starcitygames 597 | 7917 hotair 598 | 7914 marca 599 | 7872 eurekalert 600 | 7840 screenrant 601 | 7834 dota2 602 | 7797 truth-out 603 | 7784 dell 604 | 7783 eldiario 605 | 7782 pcworld 606 | 7782 doi 607 | 7780 comicbookresources 608 | 7765 dr 609 | 7729 howstuffworks 610 | 7727 gocomics 611 | 7715 worldoftanks 612 | 7707 tandfonline 613 | 7690 examiner 614 | 7688 newrepublic 615 | 7682 curseforge 616 | 7680 findlaw 617 | 7673 nikkei 618 | 7665 heraldsun 619 | 7652 podbean 620 | 7645 aftonbladet 621 | 7638 duckduckgo 622 | 7633 ynetnews 623 | 7629 timesofindia 624 | 7628 freshphase 625 | 7591 westeros 626 | 7576 youjizz 627 | 7574 spectator 
628 | 7548 justia 629 | 7537 antiwar 630 | 7536 mmajunkie 631 | 7516 yomiuri 632 | 7485 newstatesman 633 | 7481 greenmangaming 634 | 7475 joystiq 635 | 7444 jsfiddle 636 | 7424 anime-planet 637 | 7415 counterpunch 638 | 7410 autosport 639 | 7395 archlinux 640 | 7384 berkeley 641 | 7383 smbc-comics 642 | 7374 rockpapershotgun 643 | 7372 pjmedia 644 | 7367 estadao 645 | 7365 intoday 646 | 7361 newsmax 647 | 7346 newsbusters 648 | 7337 grantland 649 | 7329 voanews 650 | 7292 myshopify 651 | 7286 wnd 652 | 7265 9to5mac 653 | 7257 hurriyetdailynews 654 | 7229 bleedingcool 655 | 7225 indiewire 656 | 7222 radio-canada 657 | 7216 viewsync 658 | 7211 cambridge 659 | 7204 drsd 660 | 7197 house 661 | 7185 uproxx 662 | 7152 mlbtraderumors 663 | 7145 gamasutra 664 | 7134 bricklink 665 | 7122 foodnetwork 666 | 7122 presstv 667 | 7119 opensecrets 668 | 7118 canada 669 | 7116 bgr 670 | 7097 democracynow 671 | 7091 businessweek 672 | 7085 smash 673 | 7080 usda 674 | 7078 cloudfront 675 | 7044 psu 676 | 7028 detroitnews 677 | 7028 explosm 678 | 7013 woobox 679 | 7011 football-italia 680 | 7005 academia 681 | 6948 channelnewsasia 682 | 6927 siliconera 683 | 6923 rei 684 | 6917 deseretnews 685 | 6916 supload 686 | 6914 mises 687 | 6905 rotoworld 688 | 6886 gsmarena 689 | 6878 rappler 690 | 6876 kijiji 691 | 6866 metal-archives 692 | 6826 theaustralian 693 | 6823 mediamatters 694 | 6823 wa 695 | 6818 bodybuilding 696 | 6811 memedad 697 | 6803 ucsd 698 | 6802 barnesandnoble 699 | 6791 india 700 | 6780 readability 701 | 6777 today 702 | 6726 indystar 703 | 6720 scotsman 704 | 6694 impress 705 | 6689 torrentfreak 706 | 6675 heise 707 | 6668 sportingnews 708 | 6658 pnas 709 | 6650 chzbgr 710 | 6650 milb 711 | 6631 business-standard 712 | 6630 bustle 713 | 6623 square-enix 714 | 6622 madison 715 | 6615 moddb 716 | 6613 uniqlo 717 | 6599 zillow 718 | 6577 tribune 719 | 6556 airliners 720 | 6552 svd 721 | 6547 gameinformer 722 | 6536 brisbanetimes 723 | 6536 ocregister 724 | 6533 swtor 725 | 6526 calgaryherald 726 | 6521 c-span 727 | 6518 slashdot 728 | 6505 belfasttelegraph 729 | 6499 hiyo 730 | 6494 news24 731 | 6484 theintercept 732 | 6479 technologyreview 733 | 6455 gutenberg 734 | 6449 cinemablend 735 | 6438 dailytelegraph 736 | 6424 globalresearch 737 | 6411 lefigaro 738 | 6405 tenor 739 | 6381 redstate 740 | 6374 aclu 741 | 6361 bloodyelbow 742 | 6357 axios 743 | 6353 thewrap 744 | 6349 redditmetrics 745 | 6345 evike 746 | 6339 aol 747 | 6327 ulta 748 | 6326 plos 749 | 6324 periscope 750 | 6312 drivethrurpg 751 | 6308 infobae 752 | 6300 debian 753 | 6298 congress 754 | 6289 warcraftlogs 755 | 6284 gothamist 756 | 6281 mangastream 757 | 6276 newgrounds 758 | 6275 berniesanders 759 | 6263 lolesports 760 | 6262 mayoclinic 761 | 6242 sfchronicle 762 | 6235 edmontonjournal 763 | 6200 dhgate 764 | 6194 cincinnati 765 | 6180 history 766 | 6176 xtube 767 | 6169 nike 768 | 6160 kiji 769 | 6147 tube8 770 | 6140 vdare 771 | 6133 unity3d 772 | 6130 twincities 773 | 6127 escapistmagazine 774 | 6126 komonews 775 | 6104 openneo 776 | 6090 oup 777 | 6082 dispatch 778 | 6079 newsobserver 779 | 6060 ballotpedia 780 | 6058 indiegala 781 | 6054 index 782 | 6050 charlotteobserver 783 | 6048 androidcentral 784 | 6032 webtoons 785 | 6028 tcgplayer 786 | 6018 zappos 787 | 6004 intel 788 | 5998 seattlepi 789 | 5996 profootballfocus 790 | 5990 ksl 791 | 5989 macleans 792 | 5984 atlasobscura 793 | 5981 yugiohprices 794 | 5980 ubuntu 795 | 5964 gq 796 | 5952 myvidster 797 | 5941 tv2 798 | 5930 paizo 799 | 5926 montrealgazette 800 | 5919 
al-monitor 801 | 5919 herokuapp 802 | 5918 volarenovels 803 | 5909 usgs 804 | 5906 nme 805 | 5906 society6 806 | 5905 vg247 807 | 5902 popsci 808 | 5895 lowes 809 | 5893 thefederalist 810 | 5878 amiami 811 | 5862 nyti 812 | 5848 steamdb 813 | 5841 crooksandliars 814 | 5833 popularmechanics 815 | 5832 slashfilm 816 | 5826 woot 817 | 5818 ev 818 | 5807 illinois 819 | 5792 nps 820 | 5791 destructoid 821 | 5790 mysanantonio 822 | 5772 sbtl 823 | 5742 smashboards 824 | 5700 biblehub 825 | 5696 euronews 826 | 5694 urbanoutfitters 827 | 5687 itv 828 | 5685 fastcompany 829 | 5684 techpowerup 830 | 5674 hearthhead 831 | 5656 mic 832 | 5649 autoblog 833 | 5646 futbin 834 | 5638 voat 835 | 5636 statesman 836 | 5626 zap2it 837 | 5623 userbenchmark 838 | 5623 legaliq 839 | 5622 mspaintadventures 840 | 5622 familysearch 841 | 5616 themoscowtimes 842 | 5606 theprovince 843 | 5604 allkpop 844 | 5594 Omegle 845 | 5570 activistpost 846 | 5565 thefreethoughtproject 847 | 5565 in 848 | 5559 sandiegouniontribune 849 | 5556 consumerist 850 | 5554 eff 851 | 5532 lego 852 | 5520 translationnations 853 | 5515 clickhole 854 | 5498 etherscan 855 | 5491 live 856 | 5486 vndb 857 | 5484 poll-maker 858 | 5481 mtgsalvation 859 | 5481 computerworld 860 | 5475 comicvine 861 | 5470 python 862 | 5469 digitalspy 863 | 5468 citylab 864 | 5458 expressen 865 | 5455 oxfordjournals 866 | 5451 collider 867 | 5447 statista 868 | 5437 apa 869 | 5434 g 870 | 5430 thenational 871 | 5430 eslgaming 872 | 5425 politiken 873 | 5421 ktla 874 | 5420 webmshare 875 | 5408 bostonherald 876 | 5407 comixology 877 | 5400 ustream 878 | 5399 sony 879 | 5396 tennessean 880 | 5377 scout 881 | 5374 drop 882 | 5372 ieee 883 | 5359 sverigesradio 884 | 5356 sherdog 885 | 5353 viooz 886 | 5353 marxists 887 | 5353 adobe 888 | 5349 myfitnesspal 889 | 5342 seahawks 890 | 5339 rferl 891 | 5338 thediplomat 892 | 5335 storeparser 893 | 5332 prnewswire 894 | 5330 midwayusa 895 | 5327 liverpoolfc 896 | 5326 cisco 897 | 5326 windowsphone 898 | 5323 toysrus 899 | 5321 archivesofnethys 900 | 5317 eluniversal 901 | 5309 gmanetwork 902 | 5303 asus 903 | 5297 android 904 | 5297 finalfantasyxiv 905 | 5296 cyclingnews 906 | 5293 worldbank 907 | 5288 boxingscene 908 | 5285 ticketmaster 909 | 5279 grooveshark 910 | 5277 khl 911 | 5276 gallup 912 | 5268 britannica 913 | 5263 abc7 914 | 5260 penny-arcade 915 | 5257 hsreplay 916 | 5257 oculus 917 | 5256 bt 918 | 5250 theroot 919 | 5246 makeagif 920 | 5246 cnsnews 921 | 5243 nbc 922 | 5243 rbc 923 | 5243 fextralife 924 | 5234 legislation 925 | 5225 sendvid 926 | 5221 sciencealert 927 | 5214 wbur 928 | 5212 myfonts 929 | 5207 picsarus 930 | 5206 phoronix 931 | 5204 nerdist 932 | 5203 eonline 933 | 5195 advocate 934 | 5191 king5 935 | 5189 xkcd 936 | 5183 kitsu 937 | 5182 weibo 938 | 5181 mangareader 939 | 5178 palmbeachpost 940 | 5176 go1dfish 941 | 5175 livestrong 942 | 5174 truthdig 943 | 5173 lgbtqnation 944 | 5172 nikkansports 945 | 5167 slickdeals 946 | 5166 streamja 947 | 5164 irs 948 | 5158 readms 949 | 5152 microcenter 950 | 5137 telesurtv 951 | 5135 lastwordonsports 952 | 5129 alarabiya 953 | 5117 cointelegraph 954 | 5114 iltalehti 955 | 5112 fc2 956 | 5108 wral 957 | 5108 thinkgeek 958 | 5102 bitbucket 959 | 5101 letterboxd 960 | 5098 ehow 961 | 5092 abc13 962 | 5083 beeradvocate 963 | 5077 umich 964 | 5067 macys 965 | 5064 factorio 966 | 5063 comicbookmovie 967 | 5042 telegram 968 | 5039 scroll 969 | 5034 setlist 970 | 5028 dailyherald 971 | 5019 games-workshop 972 | 5015 irishexaminer 973 | 5008 fbi 974 | 5007 
heraldscotland 975 | 5001 jellyneo 976 | 4999 yale 977 | 4996 cbr 978 | 4994 masslive 979 | 4984 thestranger 980 | 4982 bundlestars 981 | 4981 alibaba 982 | 4977 filedropper 983 | 4974 monoprice 984 | 4968 forward 985 | 4964 parliament 986 | 4960 theringer 987 | 4950 hobbyking 988 | 4950 manchestereveningnews 989 | 4949 bmj 990 | 4948 thewire 991 | 4947 ff2ebook 992 | 4938 ashemaletube 993 | 4937 Twitch 994 | 4933 sketchtoy 995 | 4932 mcclatchydc 996 | 4931 memory-alpha 997 | 4925 newsok 998 | 4911 desmoinesregister 999 | 4901 puzzledragonx 1000 | 4889 memecrunch 1001 | -------------------------------------------------------------------------------- /download_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import requests 4 | from tqdm import tqdm 5 | 6 | if len(sys.argv) != 2: 7 | print('You must enter the model name as a parameter, e.g.: download_model.py 124M') 8 | sys.exit(1) 9 | 10 | model = sys.argv[1] 11 | 12 | subdir = os.path.join('models', model) 13 | if not os.path.exists(subdir): 14 | os.makedirs(subdir) 15 | subdir = subdir.replace('\\','/') # needed for Windows 16 | 17 | for filename in ['checkpoint','encoder.json','hparams.json','model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta', 'vocab.bpe']: 18 | 19 | r = requests.get("https://openaipublic.blob.core.windows.net/gpt-2/" + subdir + "/" + filename, stream=True) 20 | 21 | with open(os.path.join(subdir, filename), 'wb') as f: 22 | file_size = int(r.headers["content-length"]) 23 | chunk_size = 1000 24 | with tqdm(ncols=100, desc="Fetching " + filename, total=file_size, unit_scale=True) as pbar: 25 | # 1k for chunk_size, since Ethernet packet size is around 1500 bytes 26 | for chunk in r.iter_content(chunk_size=chunk_size): 27 | f.write(chunk) 28 | pbar.update(chunk_size) 29 | -------------------------------------------------------------------------------- /encode.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Usage: 3 | # PYTHONPATH=src ./encode.py /path/to/output.npz 4 | # PYTHONPATH=src ./train --dataset /path/to/output.npz 5 | 6 | import argparse 7 | import numpy as np 8 | 9 | import encoder 10 | from load_dataset import load_dataset 11 | 12 | parser = argparse.ArgumentParser( 13 | description='Pre-encode text files into tokenized training set.', 14 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 15 | parser.add_argument('--model_name', metavar='MODEL', type=str, default='117M', help='Pretrained model name') 16 | parser.add_argument('--models_dir', metavar='PATH', type=str, default='models', help='Path to models directory') 17 | parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate files with <|endoftext|> separator into chunks of this minimum size') 18 | parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.') 19 | parser.add_argument('in_text', metavar='PATH', type=str, help='Input file, directory, or glob pattern (utf-8 text).') 20 | parser.add_argument('out_npz', metavar='OUT.npz', type=str, help='Output file path') 21 | 22 | def main(): 23 | args = parser.parse_args() 24 | enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir) 25 | print('Reading files') 26 | chunks = load_dataset(enc, args.in_text, args.combine, encoding=args.encoding) 27 | print('Writing', args.out_npz) 28 | np.savez_compressed(args.out_npz, *chunks) 29 | 30 | 31 | 
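# Each token chunk is written to the .npz archive as a separate array;
# load_dataset() in src/load_dataset.py later iterates over npz.files to
# read these chunks back when training.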
if __name__ == '__main__':
32 |     main()
33 | 
--------------------------------------------------------------------------------
/model_card.md:
--------------------------------------------------------------------------------
1 | # GPT-2 model card
2 | 
3 | Last updated: November 2019
4 | 
5 | Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993), we’re providing some accompanying information about the GPT-2 family of models we're releasing.
6 | 
7 | ## Model Details
8 | 
9 | This model was developed by researchers at OpenAI to help us understand how the capabilities of language models scale as a function of the size of the models (by parameter count) combined with very large internet-scale datasets (WebText).
10 | 
11 | ### Model date
12 | 
13 | February 2019, trained on data that cuts off at the end of 2017.
14 | 
15 | ### Model type
16 | 
17 | Language model
18 | 
19 | ### Model version
20 | 
21 | 1.5 billion parameters: the fourth and largest GPT-2 version. We have also released 124 million, 355 million, and 774 million parameter models.
22 | 
23 | ### Paper or other resource for more information
24 | [Blog post](https://openai.com/blog/better-language-models/) and [paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf)
25 | 
26 | ### Where to send questions or comments about the model
27 | Please use this [Google Form](https://forms.gle/A7WBSbTY2EkKdroPA)
28 | 
29 | ## Intended Uses
30 | 
31 | ### Primary intended uses
32 | 
33 | The primary intended users of these models are *AI researchers and practitioners*.
34 | 
35 | We primarily imagine these language models will be used by researchers to better understand the behaviors, capabilities, biases, and constraints of large-scale generative language models.
36 | 
37 | ### Secondary uses
38 | 
39 | Here are some secondary use cases we believe are likely:
40 | 
41 | - **Writing assistance**: grammar assistance, autocompletion (for normal prose or code)
42 | - **Creative writing and art**: exploring the generation of creative, fictional texts; aiding creation of poetry and other literary art
43 | - **Entertainment**: creation of games, chat bots, and amusing generations
44 | 
45 | ### Out-of-scope use cases
46 | 
47 | Because large-scale language models like GPT-2 do not distinguish fact from fiction, we don’t support use cases that require the generated text to be true.
48 | 
49 | Additionally, language models like GPT-2 reflect the biases inherent to the systems they were trained on, so we do not recommend that they be deployed into systems that interact with humans unless the deployers first carry out a study of biases relevant to the intended use case. We found no statistically significant difference in gender, race, and religious bias probes between 774M and 1.5B, implying all versions of GPT-2 should be approached with similar levels of caution around use cases that are sensitive to biases around human attributes.
50 | 
51 | ## Evaluation Data
52 | 
53 | ### Datasets
54 | 
55 | This model was trained on (and evaluated against) WebText, a dataset consisting of the text contents of 45 million links posted by users of the ‘Reddit’ social network. WebText is made of data derived from outbound links from Reddit and does not consist of data taken directly from Reddit itself. Before generating the dataset we used a blocklist to ensure we didn’t sample from a variety of subreddits which contain sexually explicit or otherwise offensive content.
56 | 57 | To get a sense of the data that went into GPT-2, we’ve [published a list](domains.txt) of the top 1,000 domains present in WebText and their frequency. The top 15 domains by volume in WebText are: Google, Archive, Blogspot, GitHub, NYTimes, Wordpress, Washington Post, Wikia, BBC, The Guardian, eBay, Pastebin, CNN, Yahoo!, and the Huffington Post. 58 | 59 | ### Motivation 60 | 61 | The motivation behind WebText was to create an Internet-scale, heterogeneous dataset that we could use to test large-scale language models against. WebText was (and is) intended to be primarily for research purposes rather than production purposes. 62 | 63 | ### Caveats and Recommendations 64 | 65 | Because GPT-2 is an internet-scale language model, it’s currently difficult to know what disciplined testing procedures can be applied to it to fully understand its capabilities and how the data it is trained on influences its vast range of outputs. We recommend researchers investigate these aspects of the model and share their results. 66 | 67 | Additionally, as indicated in our discussion of issues relating to potential misuse of the model, it remains unclear what the long-term dynamics are of detecting outputs from these models. We conducted [in-house automated ML-based detection research](https://github.com/openai/gpt-2-output-dataset/tree/master/detector) using simple classifiers, zero shot, and fine-tuning methods. Our fine-tuned detector model reached accuracy levels of approximately 95%. However, no one detection method is a panacea; automated ML-based detection, human detection, human-machine teaming, and metadata-based detection are all methods that can be combined for more confident classification. Developing better approaches to detection today will give us greater intuitions when thinking about future models and could help us understand ahead of time if detection methods will eventually become ineffective. 
68 | 
69 | 
70 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fire>=0.1.3
2 | regex==2017.4.5
3 | requests==2.21.0
4 | tqdm==4.31.1
5 | toposort==1.5
--------------------------------------------------------------------------------
/src/accumulate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import numpy as np
5 | import tensorflow.compat.v1 as tf
6 | import time
7 | 
8 | 
9 | # Wraps an optimizer so that gradients can be summed across several
10 | # minibatches (reset, then N compute_gradients calls, then apply_gradients)
11 | # for a larger effective batch size; also reports the mean loss over the window.
12 | class AccumulatingOptimizer(object):
13 |     def __init__(self, opt, var_list):
14 |         self.opt = opt
15 |         self.var_list = var_list
16 |         self.accum_vars = {tv : tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False)
17 |                            for tv in var_list}
18 |         self.total_loss = tf.Variable(tf.zeros(shape=[], dtype=tf.float32))
19 |         self.count_loss = tf.Variable(tf.zeros(shape=[], dtype=tf.float32))
20 | 
21 |     def reset(self):
22 |         updates = [tv.assign(tf.zeros_like(tv)) for tv in self.accum_vars.values()]
23 |         updates.append(self.total_loss.assign(tf.zeros(shape=[], dtype=tf.float32)))
24 |         updates.append(self.count_loss.assign(tf.zeros(shape=[], dtype=tf.float32)))
25 |         with tf.control_dependencies(updates):
26 |             return tf.no_op()
27 | 
28 |     def compute_gradients(self, loss):
29 |         grads = self.opt.compute_gradients(loss, self.var_list)
30 |         updates = [self.accum_vars[v].assign_add(g) for (g,v) in grads]
31 |         updates.append(self.total_loss.assign_add(loss))
32 |         updates.append(self.count_loss.assign_add(1.0))
33 |         with tf.control_dependencies(updates):
34 |             return tf.no_op()
35 | 
36 |     def apply_gradients(self):
37 |         grads = [(g,v) for (v,g) in self.accum_vars.items()]
38 |         with tf.control_dependencies([self.opt.apply_gradients(grads)]):
39 |             return self.total_loss / self.count_loss
--------------------------------------------------------------------------------
/src/encoder.py:
--------------------------------------------------------------------------------
1 | """Byte pair encoding utilities"""
2 | 
3 | import os
4 | import json
5 | import regex as re
6 | from functools import lru_cache
7 | 
8 | @lru_cache()
9 | def bytes_to_unicode():
10 |     """
11 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
12 |     The reversible bpe codes work on unicode strings.
13 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
14 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
15 |     This is a significant percentage of your normal, say, 32K bpe vocab.
16 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
17 |     This also avoids mapping to whitespace/control characters that the bpe code barfs on.
18 |     """
19 |     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
20 |     cs = bs[:]
21 |     n = 0
22 |     for b in range(2**8):
23 |         if b not in bs:
24 |             bs.append(b)
25 |             cs.append(2**8+n)
26 |             n += 1
27 |     cs = [chr(n) for n in cs]
28 |     return dict(zip(bs, cs))
29 | 
30 | def get_pairs(word):
31 |     """Return the set of symbol pairs in a word.
32 | 
33 |     Word is represented as a tuple of symbols (symbols being variable-length strings).
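    For example, get_pairs(('l', 'o', 'w')) returns {('l', 'o'), ('o', 'w')}.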
34 | """ 35 | pairs = set() 36 | prev_char = word[0] 37 | for char in word[1:]: 38 | pairs.add((prev_char, char)) 39 | prev_char = char 40 | return pairs 41 | 42 | class Encoder: 43 | def __init__(self, encoder, bpe_merges, errors='replace'): 44 | self.encoder = encoder 45 | self.decoder = {v:k for k,v in self.encoder.items()} 46 | self.errors = errors # how to handle errors in decoding 47 | self.byte_encoder = bytes_to_unicode() 48 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 49 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 50 | self.cache = {} 51 | 52 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 53 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 54 | 55 | def bpe(self, token): 56 | if token in self.cache: 57 | return self.cache[token] 58 | word = tuple(token) 59 | pairs = get_pairs(word) 60 | 61 | if not pairs: 62 | return token 63 | 64 | while True: 65 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 66 | if bigram not in self.bpe_ranks: 67 | break 68 | first, second = bigram 69 | new_word = [] 70 | i = 0 71 | while i < len(word): 72 | try: 73 | j = word.index(first, i) 74 | new_word.extend(word[i:j]) 75 | i = j 76 | except: 77 | new_word.extend(word[i:]) 78 | break 79 | 80 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 81 | new_word.append(first+second) 82 | i += 2 83 | else: 84 | new_word.append(word[i]) 85 | i += 1 86 | new_word = tuple(new_word) 87 | word = new_word 88 | if len(word) == 1: 89 | break 90 | else: 91 | pairs = get_pairs(word) 92 | word = ' '.join(word) 93 | self.cache[token] = word 94 | return word 95 | 96 | def encode(self, text): 97 | bpe_tokens = [] 98 | for token in re.findall(self.pat, text): 99 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 100 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 101 | return bpe_tokens 102 | 103 | def decode(self, tokens): 104 | text = ''.join([self.decoder[token] for token in tokens]) 105 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 106 | return text 107 | 108 | def get_encoder(model_name, models_dir): 109 | with open(os.path.join(models_dir, model_name, 'encoder.json'), 'r') as f: 110 | encoder = json.load(f) 111 | with open(os.path.join(models_dir, model_name, 'vocab.bpe'), 'r', encoding="utf-8") as f: 112 | bpe_data = f.read() 113 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 114 | return Encoder( 115 | encoder=encoder, 116 | bpe_merges=bpe_merges, 117 | ) 118 | -------------------------------------------------------------------------------- /src/generate_unconditional_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import fire 4 | import json 5 | import os 6 | import numpy as np 7 | import tensorflow.compat.v1 as tf 8 | 9 | import model, sample, encoder 10 | 11 | def sample_model( 12 | model_name='124M', 13 | seed=None, 14 | nsamples=0, 15 | batch_size=1, 16 | length=None, 17 | temperature=1, 18 | top_k=0, 19 | top_p=1, 20 | models_dir='models', 21 | ): 22 | """ 23 | Run the sample_model 24 | :model_name=124M : String, which model to use 25 | :seed=None : Integer seed for random number generators, fix seed to 26 | reproduce results 27 | :nsamples=0 : Number of samples to return, if 0, continues 
to 28 | generate samples indefinately. 29 | :batch_size=1 : Number of batches (only affects speed/memory). 30 | :length=None : Number of tokens in generated text, if None (default), is 31 | determined by model hyperparameters 32 | :temperature=1 : Float value controlling randomness in boltzmann 33 | distribution. Lower temperature results in less random completions. As the 34 | temperature approaches zero, the model will become deterministic and 35 | repetitive. Higher temperature results in more random completions. 36 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is 37 | considered for each step (token), resulting in deterministic completions, 38 | while 40 means 40 words are considered at each step. 0 (default) is a 39 | special setting meaning no restrictions. 40 generally is a good value. 40 | :models_dir : path to parent folder containing model subfolders 41 | (i.e. contains the folder) 42 | """ 43 | models_dir = os.path.expanduser(os.path.expandvars(models_dir)) 44 | enc = encoder.get_encoder(model_name, models_dir) 45 | hparams = model.default_hparams() 46 | with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: 47 | hparams.override_from_dict(json.load(f)) 48 | 49 | if length is None: 50 | length = hparams.n_ctx 51 | elif length > hparams.n_ctx: 52 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 53 | 54 | with tf.Session(graph=tf.Graph()) as sess: 55 | np.random.seed(seed) 56 | tf.set_random_seed(seed) 57 | 58 | output = sample.sample_sequence( 59 | hparams=hparams, length=length, 60 | start_token=enc.encoder['<|endoftext|>'], 61 | batch_size=batch_size, 62 | temperature=temperature, top_k=top_k, top_p=top_p 63 | )[:, 1:] 64 | 65 | saver = tf.train.Saver() 66 | ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) 67 | saver.restore(sess, ckpt) 68 | 69 | generated = 0 70 | while nsamples == 0 or generated < nsamples: 71 | out = sess.run(output) 72 | for i in range(batch_size): 73 | generated += batch_size 74 | text = enc.decode(out[i]) 75 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 76 | print(text) 77 | 78 | if __name__ == '__main__': 79 | fire.Fire(sample_model) 80 | -------------------------------------------------------------------------------- /src/interactive_conditional_samples.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import fire 4 | import json 5 | import os 6 | import numpy as np 7 | import tensorflow.compat.v1 as tf 8 | 9 | import model, sample, encoder 10 | 11 | def interact_model( 12 | model_name='124M', 13 | seed=None, 14 | nsamples=1, 15 | batch_size=1, 16 | length=None, 17 | temperature=1, 18 | top_k=0, 19 | top_p=1, 20 | models_dir='models', 21 | ): 22 | """ 23 | Interactively run the model 24 | :model_name=124M : String, which model to use 25 | :seed=None : Integer seed for random number generators, fix seed to reproduce 26 | results 27 | :nsamples=1 : Number of samples to return total 28 | :batch_size=1 : Number of batches (only affects speed/memory). Must divide nsamples. 29 | :length=None : Number of tokens in generated text, if None (default), is 30 | determined by model hyperparameters 31 | :temperature=1 : Float value controlling randomness in boltzmann 32 | distribution. Lower temperature results in less random completions. As the 33 | temperature approaches zero, the model will become deterministic and 34 | repetitive. 
Higher temperature results in more random completions. 35 | :top_k=0 : Integer value controlling diversity. 1 means only 1 word is 36 | considered for each step (token), resulting in deterministic completions, 37 | while 40 means 40 words are considered at each step. 0 (default) is a 38 | special setting meaning no restrictions. 40 generally is a good value. 39 | :models_dir : path to parent folder containing model subfolders 40 | (i.e. contains the folder) 41 | """ 42 | models_dir = os.path.expanduser(os.path.expandvars(models_dir)) 43 | if batch_size is None: 44 | batch_size = 1 45 | assert nsamples % batch_size == 0 46 | 47 | enc = encoder.get_encoder(model_name, models_dir) 48 | hparams = model.default_hparams() 49 | with open(os.path.join(models_dir, model_name, 'hparams.json')) as f: 50 | hparams.override_from_dict(json.load(f)) 51 | 52 | if length is None: 53 | length = hparams.n_ctx // 2 54 | elif length > hparams.n_ctx: 55 | raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx) 56 | 57 | with tf.Session(graph=tf.Graph()) as sess: 58 | context = tf.placeholder(tf.int32, [batch_size, None]) 59 | np.random.seed(seed) 60 | tf.set_random_seed(seed) 61 | output = sample.sample_sequence( 62 | hparams=hparams, length=length, 63 | context=context, 64 | batch_size=batch_size, 65 | temperature=temperature, top_k=top_k, top_p=top_p 66 | ) 67 | 68 | saver = tf.train.Saver() 69 | ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name)) 70 | saver.restore(sess, ckpt) 71 | 72 | while True: 73 | raw_text = input("Model prompt >>> ") 74 | while not raw_text: 75 | print('Prompt should not be empty!') 76 | raw_text = input("Model prompt >>> ") 77 | context_tokens = enc.encode(raw_text) 78 | generated = 0 79 | for _ in range(nsamples // batch_size): 80 | out = sess.run(output, feed_dict={ 81 | context: [context_tokens for _ in range(batch_size)] 82 | })[:, len(context_tokens):] 83 | for i in range(batch_size): 84 | generated += 1 85 | text = enc.decode(out[i]) 86 | print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40) 87 | print(text) 88 | print("=" * 80) 89 | 90 | if __name__ == '__main__': 91 | fire.Fire(interact_model) 92 | -------------------------------------------------------------------------------- /src/load_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import numpy as np 3 | import os 4 | import tensorflow.compat.v1 as tf 5 | import tqdm 6 | 7 | 8 | def load_dataset(enc, path, combine, encoding=None): 9 | paths = [] 10 | if os.path.isfile(path): 11 | # Simple file 12 | paths.append(path) 13 | elif os.path.isdir(path): 14 | # Directory 15 | for (dirpath, _, fnames) in os.walk(path): 16 | for fname in fnames: 17 | paths.append(os.path.join(dirpath, fname)) 18 | else: 19 | # Assume glob 20 | paths = glob.glob(path) 21 | 22 | token_chunks = [] 23 | raw_text = '' 24 | for path in tqdm.tqdm(paths): 25 | if path.endswith('.npz'): 26 | # Pre-encoded 27 | with np.load(path) as npz: 28 | for item in npz.files: 29 | token_chunks.append(npz[item]) 30 | else: 31 | # Plain text 32 | with open(path, 'r', encoding=encoding) as fp: 33 | raw_text += fp.read() 34 | if len(raw_text) >= combine: 35 | tokens = np.stack(enc.encode(raw_text)) 36 | token_chunks.append(tokens) 37 | raw_text = '' 38 | else: 39 | raw_text += '<|endoftext|>' 40 | if raw_text: 41 | tokens = np.stack(enc.encode(raw_text)) 42 | token_chunks.append(tokens) 43 | return token_chunks 44 | 45 | 46 | def binary_search(f, lo, hi): 
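    # Assumes f is monotone over [lo, hi]: False up to some boundary index,
    # then True from there on. Returns the smallest index in (lo, hi] where
    # f is True, or None if f(lo) is already True or f(hi) is still False.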
47 | if f(lo) or not f(hi): 48 | return None 49 | while hi > lo + 1: 50 | mid = (lo + hi) // 2 51 | if f(mid): 52 | hi = mid 53 | else: 54 | lo = mid 55 | return hi 56 | 57 | 58 | class Sampler(object): 59 | """Fairly samples a slice from a set of variable sized chunks. 60 | 61 | 'Fairly' means that the distribution is the same as sampling from one concatenated chunk, 62 | but without crossing chunk boundaries.""" 63 | 64 | def __init__(self, chunks, seed=None): 65 | self.chunks = chunks 66 | self.total_size = sum(chunk.shape[0] for chunk in chunks) 67 | self.boundaries = [0] 68 | for i in range(len(chunks)): 69 | self.boundaries.append(self.boundaries[-1] + chunks[i].shape[0]) 70 | self.rs = np.random.RandomState(seed=seed) 71 | 72 | def sample(self, length): 73 | assert length < self.total_size // len( 74 | self.chunks 75 | ), "Dataset files are too small to sample {} tokens at a time".format( 76 | length) 77 | while True: 78 | index = self.rs.randint(0, self.total_size - length - 1) 79 | i = binary_search(lambda j: self.boundaries[j] > index, 0, 80 | len(self.boundaries) - 1) - 1 81 | if self.boundaries[i + 1] > index + length: 82 | within_chunk = index - self.boundaries[i] 83 | return self.chunks[i][within_chunk:within_chunk + length] 84 | -------------------------------------------------------------------------------- /src/memory_saving_gradients.py: -------------------------------------------------------------------------------- 1 | from toposort import toposort 2 | import contextlib 3 | import numpy as np 4 | import tensorflow as tf 5 | import tensorflow.contrib.graph_editor as ge 6 | import time 7 | import sys 8 | sys.setrecursionlimit(10000) 9 | # refers back to current module if we decide to split helpers out 10 | util = sys.modules[__name__] 11 | 12 | # getting rid of "WARNING:tensorflow:VARIABLES collection name is deprecated" 13 | setattr(tf.GraphKeys, "VARIABLES", "variables") 14 | 15 | # save original gradients since tf.gradient could be monkey-patched to point 16 | # to our version 17 | from tensorflow.python.ops import gradients as tf_gradients_lib 18 | tf_gradients = tf_gradients_lib.gradients 19 | 20 | MIN_CHECKPOINT_NODE_SIZE=1024 # use lower value during testing 21 | 22 | # specific versions we can use to do process-wide replacement of tf.gradients 23 | def gradients_speed(ys, xs, grad_ys=None, **kwargs): 24 | return gradients(ys, xs, grad_ys, checkpoints='speed', **kwargs) 25 | 26 | def gradients_memory(ys, xs, grad_ys=None, **kwargs): 27 | return gradients(ys, xs, grad_ys, checkpoints='memory', **kwargs) 28 | 29 | def gradients_collection(ys, xs, grad_ys=None, **kwargs): 30 | return gradients(ys, xs, grad_ys, checkpoints='collection', **kwargs) 31 | 32 | def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs): 33 | ''' 34 | Authors: Tim Salimans & Yaroslav Bulatov 35 | 36 | memory efficient gradient implementation inspired by "Training Deep Nets with Sublinear Memory Cost" 37 | by Chen et al. 2016 (https://arxiv.org/abs/1604.06174) 38 | 39 | ys,xs,grad_ys,kwargs are the arguments to standard tensorflow tf.gradients 40 | (https://www.tensorflow.org/versions/r0.12/api_docs/python/train.html#gradients) 41 | 42 | 'checkpoints' can either be 43 | - a list consisting of tensors from the forward pass of the neural net 44 | that we should re-use when calculating the gradients in the backward pass 45 | all other tensors that do not appear in this list will be re-computed 46 | - a string specifying how this list should be determined. 
currently we support 47 | - 'speed': checkpoint all outputs of convolutions and matmuls. these ops are usually the most expensive, 48 | so checkpointing them maximizes the running speed 49 | (this is a good option if nonlinearities, concats, batchnorms, etc are taking up a lot of memory) 50 | - 'memory': try to minimize the memory usage 51 | (currently using a very simple strategy that identifies a number of bottleneck tensors in the graph to checkpoint) 52 | - 'collection': look for a tensorflow collection named 'checkpoints', which holds the tensors to checkpoint 53 | ''' 54 | 55 | # print("Calling memsaving gradients with", checkpoints) 56 | if not isinstance(ys,list): 57 | ys = [ys] 58 | if not isinstance(xs,list): 59 | xs = [xs] 60 | 61 | bwd_ops = ge.get_backward_walk_ops([y.op for y in ys], 62 | inclusive=True) 63 | 64 | debug_print("bwd_ops: %s", bwd_ops) 65 | 66 | # forward ops are all ops that are candidates for recomputation 67 | fwd_ops = ge.get_forward_walk_ops([x.op for x in xs], 68 | inclusive=True, 69 | within_ops=bwd_ops) 70 | debug_print("fwd_ops: %s", fwd_ops) 71 | 72 | # exclude ops with no inputs 73 | fwd_ops = [op for op in fwd_ops if op.inputs] 74 | 75 | # don't recompute xs, remove variables 76 | xs_ops = _to_ops(xs) 77 | fwd_ops = [op for op in fwd_ops if not op in xs_ops] 78 | fwd_ops = [op for op in fwd_ops if not '/assign' in op.name] 79 | fwd_ops = [op for op in fwd_ops if not '/Assign' in op.name] 80 | fwd_ops = [op for op in fwd_ops if not '/read' in op.name] 81 | ts_all = ge.filter_ts(fwd_ops, True) # get the tensors 82 | ts_all = [t for t in ts_all if '/read' not in t.name] 83 | ts_all = set(ts_all) - set(xs) - set(ys) 84 | 85 | # construct list of tensors to checkpoint during forward pass, if not 86 | # given as input 87 | if type(checkpoints) is not list: 88 | if checkpoints == 'collection': 89 | checkpoints = tf.get_collection('checkpoints') 90 | 91 | elif checkpoints == 'speed': 92 | # checkpoint all expensive ops to maximize running speed 93 | checkpoints = ge.filter_ts_from_regex(fwd_ops, 'conv2d|Conv|MatMul') 94 | 95 | elif checkpoints == 'memory': 96 | 97 | # remove very small tensors and some weird ops 98 | def fixdims(t): # tf.Dimension values are not compatible with int, convert manually 99 | try: 100 | return [int(e if e.value is not None else 64) for e in t] 101 | except: 102 | return [0] # unknown shape 103 | ts_all = [t for t in ts_all if np.prod(fixdims(t.shape)) > MIN_CHECKPOINT_NODE_SIZE] 104 | ts_all = [t for t in ts_all if 'L2Loss' not in t.name] 105 | ts_all = [t for t in ts_all if 'entropy' not in t.name] 106 | ts_all = [t for t in ts_all if 'FusedBatchNorm' not in t.name] 107 | ts_all = [t for t in ts_all if 'Switch' not in t.name] 108 | ts_all = [t for t in ts_all if 'dropout' not in t.name] 109 | # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16 110 | ts_all = [t for t in ts_all if 'Cast' not in t.name] 111 | 112 | # filter out all tensors that are inputs of the backward graph 113 | with util.capture_ops() as bwd_ops: 114 | tf_gradients(ys, xs, grad_ys, **kwargs) 115 | 116 | bwd_inputs = [t for op in bwd_ops for t in op.inputs] 117 | # list of tensors in forward graph that is in input to bwd graph 118 | ts_filtered = list(set(bwd_inputs).intersection(ts_all)) 119 | debug_print("Using tensors %s", ts_filtered) 120 | 121 | # try two slightly different ways of getting bottlenecks tensors 122 | # to checkpoint 123 | for ts in [ts_filtered, ts_all]: 124 | 125 | # get all bottlenecks in the graph 126 | bottleneck_ts 
= [] 127 | for t in ts: 128 | b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops)) 129 | f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops)) 130 | # check that there are not shortcuts 131 | b_inp = set([inp for op in b for inp in op.inputs]).intersection(ts_all) 132 | f_inp = set([inp for op in f for inp in op.inputs]).intersection(ts_all) 133 | if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all): 134 | bottleneck_ts.append(t) # we have a bottleneck! 135 | else: 136 | debug_print("Rejected bottleneck candidate and ops %s", [t] + list(set(ts_all) - set(b_inp) - set(f_inp))) 137 | 138 | # success? or try again without filtering? 139 | if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found! 140 | break 141 | 142 | if not bottleneck_ts: 143 | raise Exception('unable to find bottleneck tensors! please provide checkpoint nodes manually, or use checkpoints="speed".') 144 | 145 | # sort the bottlenecks 146 | bottlenecks_sorted_lists = tf_toposort(bottleneck_ts, within_ops=fwd_ops) 147 | sorted_bottlenecks = [t for ts in bottlenecks_sorted_lists for t in ts] 148 | 149 | # save an approximately optimal number ~ sqrt(N) 150 | N = len(ts_filtered) 151 | if len(bottleneck_ts) <= np.ceil(np.sqrt(N)): 152 | checkpoints = sorted_bottlenecks 153 | else: 154 | step = int(np.ceil(len(bottleneck_ts) / np.sqrt(N))) 155 | checkpoints = sorted_bottlenecks[step::step] 156 | 157 | else: 158 | raise Exception('%s is unsupported input for "checkpoints"' % (checkpoints,)) 159 | 160 | checkpoints = list(set(checkpoints).intersection(ts_all)) 161 | 162 | # at this point automatic selection happened and checkpoints is list of nodes 163 | assert isinstance(checkpoints, list) 164 | 165 | debug_print("Checkpoint nodes used: %s", checkpoints) 166 | # better error handling of special cases 167 | # xs are already handled as checkpoint nodes, so no need to include them 168 | xs_intersect_checkpoints = set(xs).intersection(set(checkpoints)) 169 | if xs_intersect_checkpoints: 170 | debug_print("Warning, some input nodes are also checkpoint nodes: %s", 171 | xs_intersect_checkpoints) 172 | ys_intersect_checkpoints = set(ys).intersection(set(checkpoints)) 173 | debug_print("ys: %s, checkpoints: %s, intersect: %s", ys, checkpoints, 174 | ys_intersect_checkpoints) 175 | # saving an output node (ys) gives no benefit in memory while creating 176 | # new edge cases, exclude them 177 | if ys_intersect_checkpoints: 178 | debug_print("Warning, some output nodes are also checkpoints nodes: %s", 179 | format_ops(ys_intersect_checkpoints)) 180 | 181 | # remove initial and terminal nodes from checkpoints list if present 182 | checkpoints = list(set(checkpoints) - set(ys) - set(xs)) 183 | 184 | # check that we have some nodes to checkpoint 185 | # if not checkpoints: 186 | # raise Exception('no checkpoints nodes found or given as input! 
') 187 | 188 | # disconnect dependencies between checkpointed tensors 189 | checkpoints_disconnected = {} 190 | for x in checkpoints: 191 | if x.op and x.op.name is not None: 192 | grad_node = tf.stop_gradient(x, name=x.op.name+"_sg") 193 | else: 194 | grad_node = tf.stop_gradient(x) 195 | checkpoints_disconnected[x] = grad_node 196 | 197 | # partial derivatives to the checkpointed tensors and xs 198 | ops_to_copy = fast_backward_ops(seed_ops=[y.op for y in ys], 199 | stop_at_ts=checkpoints, within_ops=fwd_ops) 200 | debug_print("Found %s ops to copy within fwd_ops %s, seed %s, stop_at %s", 201 | len(ops_to_copy), fwd_ops, [r.op for r in ys], checkpoints) 202 | debug_print("ops_to_copy = %s", ops_to_copy) 203 | debug_print("Processing list %s", ys) 204 | copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) 205 | for origin_op, op in info._transformed_ops.items(): 206 | op._set_device(origin_op.node_def.device) 207 | copied_ops = info._transformed_ops.values() 208 | debug_print("Copied %s to %s", ops_to_copy, copied_ops) 209 | ge.reroute_ts(checkpoints_disconnected.values(), checkpoints_disconnected.keys(), can_modify=copied_ops) 210 | debug_print("Rewired %s in place of %s restricted to %s", 211 | checkpoints_disconnected.values(), checkpoints_disconnected.keys(), copied_ops) 212 | 213 | # get gradients with respect to current boundary + original x's 214 | copied_ys = [info._transformed_ops[y.op]._outputs[0] for y in ys] 215 | boundary = list(checkpoints_disconnected.values()) 216 | dv = tf_gradients(ys=copied_ys, xs=boundary+xs, grad_ys=grad_ys, **kwargs) 217 | debug_print("Got gradients %s", dv) 218 | debug_print("for %s", copied_ys) 219 | debug_print("with respect to %s", boundary+xs) 220 | 221 | inputs_to_do_before = [y.op for y in ys] 222 | if grad_ys is not None: 223 | inputs_to_do_before += grad_ys 224 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] 225 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) 226 | 227 | # partial derivatives to the checkpointed nodes 228 | # dictionary of "node: backprop" for nodes in the boundary 229 | d_checkpoints = {r: dr for r,dr in zip(checkpoints_disconnected.keys(), 230 | dv[:len(checkpoints_disconnected)])} 231 | # partial derivatives to xs (usually the params of the neural net) 232 | d_xs = dv[len(checkpoints_disconnected):] 233 | 234 | # incorporate derivatives flowing through the checkpointed nodes 235 | checkpoints_sorted_lists = tf_toposort(checkpoints, within_ops=fwd_ops) 236 | for ts in checkpoints_sorted_lists[::-1]: 237 | debug_print("Processing list %s", ts) 238 | checkpoints_other = [r for r in checkpoints if r not in ts] 239 | checkpoints_disconnected_other = [checkpoints_disconnected[r] for r in checkpoints_other] 240 | 241 | # copy part of the graph below current checkpoint node, stopping at 242 | # other checkpoints nodes 243 | ops_to_copy = fast_backward_ops(within_ops=fwd_ops, seed_ops=[r.op for r in ts], stop_at_ts=checkpoints_other) 244 | debug_print("Found %s ops to copy within %s, seed %s, stop_at %s", 245 | len(ops_to_copy), fwd_ops, [r.op for r in ts], 246 | checkpoints_other) 247 | debug_print("ops_to_copy = %s", ops_to_copy) 248 | if not ops_to_copy: # we're done! 
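# (nothing between this checkpoint layer and the ones below it remains to be recomputed, so the backward sweep can stop early)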
249 | break 250 | copied_sgv, info = ge.copy_with_input_replacements(ge.sgv(ops_to_copy), {}) 251 | for origin_op, op in info._transformed_ops.items(): 252 | op._set_device(origin_op.node_def.device) 253 | copied_ops = info._transformed_ops.values() 254 | debug_print("Copied %s to %s", ops_to_copy, copied_ops) 255 | ge.reroute_ts(checkpoints_disconnected_other, checkpoints_other, can_modify=copied_ops) 256 | debug_print("Rewired %s in place of %s restricted to %s", 257 | checkpoints_disconnected_other, checkpoints_other, copied_ops) 258 | 259 | # gradient flowing through the checkpointed node 260 | boundary = [info._transformed_ops[r.op]._outputs[0] for r in ts] 261 | substitute_backprops = [d_checkpoints[r] for r in ts] 262 | dv = tf_gradients(boundary, 263 | checkpoints_disconnected_other+xs, 264 | grad_ys=substitute_backprops, **kwargs) 265 | debug_print("Got gradients %s", dv) 266 | debug_print("for %s", boundary) 267 | debug_print("with respect to %s", checkpoints_disconnected_other+xs) 268 | debug_print("with boundary backprop substitutions %s", substitute_backprops) 269 | 270 | inputs_to_do_before = [d_checkpoints[r].op for r in ts] 271 | wait_to_do_ops = list(copied_ops) + [g.op for g in dv if g is not None] 272 | my_add_control_inputs(wait_to_do_ops, inputs_to_do_before) 273 | 274 | # partial derivatives to the checkpointed nodes 275 | for r, dr in zip(checkpoints_other, dv[:len(checkpoints_other)]): 276 | if dr is not None: 277 | if d_checkpoints[r] is None: 278 | d_checkpoints[r] = dr 279 | else: 280 | d_checkpoints[r] += dr 281 | def _unsparsify(x): 282 | if not isinstance(x, tf.IndexedSlices): 283 | return x 284 | assert x.dense_shape is not None, "memory_saving_gradients encountered sparse gradients of unknown shape" 285 | indices = x.indices 286 | while indices.shape.ndims < x.values.shape.ndims: 287 | indices = tf.expand_dims(indices, -1) 288 | return tf.scatter_nd(indices, x.values, x.dense_shape) 289 | 290 | # partial derivatives to xs (usually the params of the neural net) 291 | d_xs_new = dv[len(checkpoints_other):] 292 | for j in range(len(xs)): 293 | if d_xs_new[j] is not None: 294 | if d_xs[j] is None: 295 | d_xs[j] = _unsparsify(d_xs_new[j]) 296 | else: 297 | d_xs[j] += _unsparsify(d_xs_new[j]) 298 | 299 | 300 | return d_xs 301 | 302 | def tf_toposort(ts, within_ops=None): 303 | all_ops = ge.get_forward_walk_ops([x.op for x in ts], within_ops=within_ops) 304 | 305 | deps = {} 306 | for op in all_ops: 307 | for o in op.outputs: 308 | deps[o] = set(op.inputs) 309 | sorted_ts = toposort(deps) 310 | 311 | # only keep the tensors from our original list 312 | ts_sorted_lists = [] 313 | for l in sorted_ts: 314 | keep = list(set(l).intersection(ts)) 315 | if keep: 316 | ts_sorted_lists.append(keep) 317 | 318 | return ts_sorted_lists 319 | 320 | def fast_backward_ops(within_ops, seed_ops, stop_at_ts): 321 | bwd_ops = set(ge.get_backward_walk_ops(seed_ops, stop_at_ts=stop_at_ts)) 322 | ops = bwd_ops.intersection(within_ops).difference([t.op for t in stop_at_ts]) 323 | return list(ops) 324 | 325 | @contextlib.contextmanager 326 | def capture_ops(): 327 | """Decorator to capture ops created in the block. 328 | with capture_ops() as ops: 329 | # create some ops 330 | print(ops) # => prints ops created. 
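Internally the block is run inside a uniquely named (timestamp-based) name scope, and the captured ops are recovered afterward by selecting everything in that scope.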
331 | """ 332 | 333 | micros = int(time.time()*10**6) 334 | scope_name = str(micros) 335 | op_list = [] 336 | with tf.name_scope(scope_name): 337 | yield op_list 338 | 339 | g = tf.get_default_graph() 340 | op_list.extend(ge.select_ops(scope_name+"/.*", graph=g)) 341 | 342 | def _to_op(tensor_or_op): 343 | if hasattr(tensor_or_op, "op"): 344 | return tensor_or_op.op 345 | return tensor_or_op 346 | 347 | def _to_ops(iterable): 348 | if not _is_iterable(iterable): 349 | return iterable 350 | return [_to_op(i) for i in iterable] 351 | 352 | def _is_iterable(o): 353 | try: 354 | _ = iter(o) 355 | except Exception: 356 | return False 357 | return True 358 | 359 | DEBUG_LOGGING=False 360 | def debug_print(s, *args): 361 | """Like logger.log, but also replaces all TensorFlow ops/tensors with their 362 | names. Sensitive to value of DEBUG_LOGGING, see enable_debug/disable_debug 363 | 364 | Usage: 365 | debug_print("see tensors %s for %s", tensorlist, [1,2,3]) 366 | """ 367 | 368 | if DEBUG_LOGGING: 369 | formatted_args = [format_ops(arg) for arg in args] 370 | print("DEBUG "+s % tuple(formatted_args)) 371 | 372 | def format_ops(ops, sort_outputs=True): 373 | """Helper method for printing ops. Converts Tensor/Operation op to op.name, 374 | rest to str(op).""" 375 | 376 | if hasattr(ops, '__iter__') and not isinstance(ops, str): 377 | l = [(op.name if hasattr(op, "name") else str(op)) for op in ops] 378 | if sort_outputs: 379 | return sorted(l) 380 | return l 381 | else: 382 | return ops.name if hasattr(ops, "name") else str(ops) 383 | 384 | def my_add_control_inputs(wait_to_do_ops, inputs_to_do_before): 385 | for op in wait_to_do_ops: 386 | ci = [i for i in inputs_to_do_before if op.control_inputs is None or i not in op.control_inputs] 387 | ge.add_control_inputs(op, ci) 388 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow.compat.v1 as tf 3 | 4 | class HParams(object): 5 | def __init__(self, **kwargs): 6 | for (k, v) in kwargs.items(): 7 | setattr(self, k, v) 8 | 9 | def override_from_dict(self, kwargs): 10 | for (k, v) in kwargs.items(): 11 | setattr(self, k, v) 12 | 13 | 14 | def default_hparams(): 15 | return HParams( 16 | n_vocab=0, 17 | n_ctx=1024, 18 | n_embd=768, 19 | n_head=12, 20 | n_layer=12, 21 | ) 22 | 23 | def shape_list(x): 24 | """Deal with dynamic shape in tensorflow cleanly.""" 25 | static = x.shape.as_list() 26 | dynamic = tf.shape(x) 27 | return [dynamic[i] if s is None else s for i, s in enumerate(static)] 28 | 29 | def softmax(x, axis=-1): 30 | x = x - tf.reduce_max(x, axis=axis, keepdims=True) 31 | ex = tf.exp(x) 32 | return ex / tf.reduce_sum(ex, axis=axis, keepdims=True) 33 | 34 | def gelu(x): 35 | return 0.5*x*(1+tf.tanh(np.sqrt(2/np.pi)*(x+0.044715*tf.pow(x, 3)))) 36 | 37 | def norm(x, scope, *, axis=-1, epsilon=1e-5): 38 | """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" 39 | with tf.variable_scope(scope): 40 | n_state = shape_list(x)[-1] 41 | g = tf.get_variable('g', [n_state], initializer=tf.constant_initializer(1)) 42 | b = tf.get_variable('b', [n_state], initializer=tf.constant_initializer(0)) 43 | u = tf.reduce_mean(x, axis=axis, keepdims=True) 44 | s = tf.reduce_mean(tf.square(x-u), axis=axis, keepdims=True) 45 | x = (x - u) * tf.rsqrt(s + epsilon) 46 | x = x*g + b 47 | return x 48 | 49 | def split_states(x, n): 50 | """Reshape the last dimension of x into [n, 
x.shape[-1]/n].""" 51 | *start, m = shape_list(x) 52 | return tf.reshape(x, start + [n, m//n]) 53 | 54 | def merge_states(x): 55 | """Smash the last two dimensions of x into a single dimension.""" 56 | *start, a, b = shape_list(x) 57 | return tf.reshape(x, start + [a*b]) 58 | 59 | def conv1d(x, scope, nf, *, w_init_stdev=0.02): 60 | with tf.variable_scope(scope): 61 | *start, nx = shape_list(x) 62 | w = tf.get_variable('w', [1, nx, nf], initializer=tf.random_normal_initializer(stddev=w_init_stdev)) 63 | b = tf.get_variable('b', [nf], initializer=tf.constant_initializer(0)) 64 | c = tf.reshape(tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf]))+b, start+[nf]) 65 | return c 66 | 67 | def attention_mask(nd, ns, *, dtype): 68 | """1's in the lower triangle, counting from the lower right corner. 69 | 70 | Same as tf.matrix_band_part(tf.ones([nd, ns]), -1, ns-nd), but doesn't produce garbage on TPUs. 71 | """ 72 | i = tf.range(nd)[:,None] 73 | j = tf.range(ns) 74 | m = i >= j - ns + nd 75 | return tf.cast(m, dtype) 76 | 77 | 78 | def attn(x, scope, n_state, *, past, hparams): 79 | assert x.shape.ndims == 3 # Should be [batch, sequence, features] 80 | assert n_state % hparams.n_head == 0 81 | if past is not None: 82 | assert past.shape.ndims == 5 # Should be [batch, 2, heads, sequence, features], where 2 is [k, v] 83 | 84 | def split_heads(x): 85 | # From [batch, sequence, features] to [batch, heads, sequence, features] 86 | return tf.transpose(split_states(x, hparams.n_head), [0, 2, 1, 3]) 87 | 88 | def merge_heads(x): 89 | # Reverse of split_heads 90 | return merge_states(tf.transpose(x, [0, 2, 1, 3])) 91 | 92 | def mask_attn_weights(w): 93 | # w has shape [batch, heads, dst_sequence, src_sequence], where information flows from src to dst. 94 | _, _, nd, ns = shape_list(w) 95 | b = attention_mask(nd, ns, dtype=w.dtype) 96 | b = tf.reshape(b, [1, 1, nd, ns]) 97 | w = w*b - tf.cast(1e10, w.dtype)*(1-b) 98 | return w 99 | 100 | def multihead_attn(q, k, v): 101 | # q, k, v have shape [batch, heads, sequence, features] 102 | w = tf.matmul(q, k, transpose_b=True) 103 | w = w * tf.rsqrt(tf.cast(shape_list(v)[-1], w.dtype)) 104 | 105 | w = mask_attn_weights(w) 106 | w = softmax(w) 107 | a = tf.matmul(w, v) 108 | return a 109 | 110 | with tf.variable_scope(scope): 111 | c = conv1d(x, 'c_attn', n_state*3) 112 | q, k, v = map(split_heads, tf.split(c, 3, axis=2)) 113 | present = tf.stack([k, v], axis=1) 114 | if past is not None: 115 | pk, pv = tf.unstack(past, axis=1) 116 | k = tf.concat([pk, k], axis=-2) 117 | v = tf.concat([pv, v], axis=-2) 118 | a = multihead_attn(q, k, v) 119 | a = merge_heads(a) 120 | a = conv1d(a, 'c_proj', n_state) 121 | return a, present 122 | 123 | 124 | def mlp(x, scope, n_state, *, hparams): 125 | with tf.variable_scope(scope): 126 | nx = shape_list(x)[-1] 127 | h = gelu(conv1d(x, 'c_fc', n_state)) 128 | h2 = conv1d(h, 'c_proj', nx) 129 | return h2 130 | 131 | 132 | def block(x, scope, *, past, hparams): 133 | with tf.variable_scope(scope): 134 | nx = shape_list(x)[-1] 135 | a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams) 136 | x = x + a 137 | m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams) 138 | x = x + m 139 | return x, present 140 | 141 | def past_shape(*, hparams, batch_size=None, sequence=None): 142 | return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head] 143 | 144 | def expand_tile(value, size): 145 | """Add a new axis of given size.""" 146 | value = tf.convert_to_tensor(value, 
name='value') 147 | ndims = value.shape.ndims 148 | return tf.tile(tf.expand_dims(value, axis=0), [size] + [1]*ndims) 149 | 150 | def positions_for(tokens, past_length): 151 | batch_size = tf.shape(tokens)[0] 152 | nsteps = tf.shape(tokens)[1] 153 | return expand_tile(past_length + tf.range(nsteps), batch_size) 154 | 155 | 156 | def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE): 157 | with tf.variable_scope(scope, reuse=reuse): 158 | results = {} 159 | batch, sequence = shape_list(X) 160 | 161 | wpe = tf.get_variable('wpe', [hparams.n_ctx, hparams.n_embd], 162 | initializer=tf.random_normal_initializer(stddev=0.01)) 163 | wte = tf.get_variable('wte', [hparams.n_vocab, hparams.n_embd], 164 | initializer=tf.random_normal_initializer(stddev=0.02)) 165 | past_length = 0 if past is None else tf.shape(past)[-2] 166 | h = tf.gather(wte, X) + tf.gather(wpe, positions_for(X, past_length)) 167 | 168 | # Transformer 169 | presents = [] 170 | pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer 171 | assert len(pasts) == hparams.n_layer 172 | for layer, past in enumerate(pasts): 173 | h, present = block(h, 'h%d' % layer, past=past, hparams=hparams) 174 | if layer == 10: 175 | tf.add_to_collection('checkpoints', h) 176 | presents.append(present) 177 | results['present'] = tf.stack(presents, axis=1) 178 | h = norm(h, 'ln_f') 179 | 180 | # Language model loss. Do tokens <n predict token n? 181 | h_flat = tf.reshape(h, [batch*sequence, hparams.n_embd]) 182 | logits = tf.matmul(h_flat, wte, transpose_b=True) 183 | logits = tf.reshape(logits, [batch, sequence, hparams.n_vocab]) 184 | results['logits'] = logits 185 | return results 186 | -------------------------------------------------------------------------------- /src/sample.py: -------------------------------------------------------------------------------- 61 | if top_p > 0.0: 62 | logits = top_p_logits(logits, p=top_p) 63 | else: 64 | logits = top_k_logits(logits, k=top_k) 65 | samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32) 66 | return [ 67 | next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2), 68 | samples, 69 | tf.concat([output, samples], axis=1) 70 | ] 71 | 72 | past, prev, output = body(None, context, context) 73 | 74 | def cond(*args): 75 | return True 76 | 77 | _, _, tokens = tf.while_loop( 78 | cond=cond, body=body, 79 | maximum_iterations=length - 1, 80 | loop_vars=[ 81 | past, 82 | prev, 83 | output 84 | ], 85 | shape_invariants=[ 86 | tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)), 87 | tf.TensorShape([batch_size, None]), 88 | tf.TensorShape([batch_size, None]), 89 | ], 90 | back_prop=False, 91 | ) 92 | 93 | return tokens 94 | -------------------------------------------------------------------------------- /src/tfremat.py: -------------------------------------------------------------------------------- 1 | import random 2 | import os 3 | import tensorflow.compat.v1 as tf 4 | import tempfile 5 | 6 | import twremat 7 | 8 | def splice_op(op, input_map, control_inputs=None): 9 | g = op.graph 10 | node_def = tf.NodeDef() 11 | node_def.CopyFrom(op.node_def) 12 | node_def.name = g.unique_name(op.name + '_copy') 13 | inputs = [input_map.get(x, x) for x in op.inputs] 14 | new_control_inputs = [input_map.get(x, x) for x in op.control_inputs] 15 | if control_inputs: 16 | new_control_inputs.extend([x for x in control_inputs if x is not None]) 17 | # new_control_inputs = control_inputs 18 | output_types = [o.dtype for o in op.outputs] 19 | op_def = op.op_def 20 | return tf.Operation(node_def, g, inputs=inputs, output_types=output_types, op_def=op_def, control_inputs=new_control_inputs) 21 | 22 | def splice_tensor(ten, new_op): 23 | i = ten.op.outputs.index(ten) 24 | return new_op.outputs[i] 25 | 26 | def splice(obj, input_map, control_inputs=None): 27 | if type(obj) is tf.Operation: 28 | return splice_op(obj, input_map, control_inputs=control_inputs) 29 | elif
type(obj) is tf.Tensor: 30 | return splice_tensor(obj, input_map.get(obj.op, obj.op)) 31 | elif type(obj) is tf.IndexedSlices: 32 | return tf.IndexedSlices(values=input_map.get(obj.values, obj.values), 33 | indices=input_map.get(obj.indices, obj.indices), 34 | dense_shape=input_map.get(obj.dense_shape, obj.dense_shape)) 35 | else: 36 | raise AssertionError(f'Could not get deps from{repr(type(obj))} {repr(obj)}') 37 | 38 | def product(xs): 39 | r = 1 40 | for x in xs: 41 | r *= x 42 | return r 43 | 44 | def shape_size(shape): 45 | if shape.rank is None: 46 | return 16 47 | shape = shape.as_list() 48 | for i in range(len(shape)): 49 | if shape[i] is None and i == 0: 50 | shape[i] = 1 51 | elif shape[i] is None: 52 | shape[i] = 1024 53 | return product(shape) 54 | 55 | def graph_from_dfs(deps, starts): 56 | visited = set() 57 | frontier = starts 58 | while frontier: 59 | x = frontier.pop() 60 | if x in visited: 61 | continue 62 | visited.add(x) 63 | frontier.extend(list(deps(x))) 64 | return {x : list(deps(x)) for x in visited} 65 | 66 | def get_deps(obj): 67 | if type(obj) is tf.Operation: 68 | return list(obj.inputs) + list(obj.control_inputs) 69 | elif type(obj) is tf.Tensor: 70 | return [obj.op] 71 | elif type(obj) is tf.IndexedSlices: 72 | return [obj.indices, obj.values, obj.dense_shape] 73 | else: 74 | raise AssertionError(f'Could not get deps from{repr(type(obj))} {repr(obj)}') 75 | 76 | 77 | def tensor_graph(compute): 78 | return graph_from_dfs(get_deps, list(compute)) 79 | 80 | def blacklist(obj): 81 | if type(obj) is tf.Operation: 82 | if 'Assign' in obj.type or 'Variable' in obj.type or 'Placeholder' in obj.type: 83 | # TODO: Should we do special accounting for 84 | # ReadVariableOp? Currently we forbid cloning altogether, 85 | # but it's actually ok to clone this op as long as it 86 | # doesn't float across an effectful op (Assign). Also 87 | # currently we don't account for the memory used by 88 | # ReadVariableOp (is it copy-on-write?). 89 | # https://www.tensorflow.org/api_docs/python/tf/raw_ops/ReadVariableOp?hl=uk 90 | return True 91 | elif type(obj) is tf.Tensor: 92 | return blacklist(obj.op) 93 | return False 94 | 95 | def estimate_cpu(op): 96 | return sum(4 * shape_size(t.shape) for t in op.inputs if type(t) is tf.Tensor) + sum(4 * shape_size(t.shape) for t in op.outputs) 97 | 98 | def estimate_mem(op): 99 | return sum(4 * shape_size(t.shape) for t in op.outputs) 100 | 101 | def info(op): 102 | if blacklist(op): 103 | return {'type': 'effectful'} 104 | elif type(op) is tf.Operation: 105 | if 'Reshape' in op.type: 106 | return {'type': 'pointer'} 107 | return {'type': 'normal', 108 | 'cpu': estimate_cpu(op), 109 | 'mem': estimate_mem(op)} 110 | elif type(op) is tf.Tensor: 111 | return {'type': 'pointer'} 112 | elif type(op) is tf.IndexedSlices: 113 | return {'type': 'pointer'} 114 | else: 115 | raise AssertionError(repr((type(op), op))) 116 | 117 | 118 | # Helper functions to flatten and unflatten nested structures of 119 | # tensors and ops so that tf_remat can be applied to structures 120 | # without fiddly marshalling. 
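# For example (hypothetical structure, not from the original source): given compute = {'loss': loss_t, 'grads': [g0, g1]}, get_ops(compute) returns the flat list of tensors [loss_t, g0, g1] (in traversal order), and replace_ops(compute, live) rebuilds the same dict/list shape with each tensor replaced by live[tensor].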
121 | def get_ops(compute): 122 | output = [] 123 | stack = [compute] 124 | while stack: 125 | top = stack.pop() 126 | if type(top) is dict: 127 | for v in top.values(): 128 | stack.append(v) 129 | elif type(top) in (list, tuple): 130 | stack.extend(top) 131 | elif type(top) in (tf.Operation, tf.Tensor, tf.IndexedSlices): 132 | output.append(top) 133 | return output 134 | 135 | def replace_ops(top, live): 136 | if type(top) in (tf.Operation, tf.Tensor, tf.IndexedSlices): 137 | return live[top] 138 | elif type(top) is dict: 139 | return {k : replace_ops(v, live) for (k,v) in top.items()} 140 | elif type(top) is list: 141 | return [replace_ops(v, live) for v in top] 142 | elif type(top) is tuple: 143 | return tuple(replace_ops(v, live) for v in top) 144 | else: 145 | return top 146 | 147 | 148 | def tf_remat(compute, memlimit): 149 | compute_ops = get_ops(compute) 150 | tf_deps = tensor_graph(compute_ops) 151 | 152 | # Relabel with integers 153 | from_op = {op : i for (i, op) in enumerate(tf_deps.keys())} 154 | from_node = {i : op for (op, i) in from_op.items()} 155 | nodes = set(from_node.keys()) 156 | node_deps = {n : [from_op[d] for d in tf_deps[from_node[n]]] for n in nodes} 157 | 158 | node_info = {} 159 | for n in nodes: 160 | node_info[n] = info(from_node[n]) 161 | node_info[n]['deps'] = [from_op[d] for d in tf_deps[from_node[n]]] 162 | 163 | steps = twremat.runtwremat(node_info, memlimit, {from_op[c] for c in compute_ops}) 164 | 165 | print('Constructing tensorflow graph...') 166 | live = {} 167 | last_op = None 168 | for (action, n) in steps: 169 | base = from_node[n] 170 | if action == 'compute': 171 | input_map = {d : live[d] for d in tf_deps[base] if live[d] != d} 172 | if blacklist(base) and not input_map: 173 | live[base] = base 174 | else: 175 | live[base] = splice(base, input_map, control_inputs=[last_op]) 176 | if type(base) is tf.Operation: 177 | last_op = live[base] 178 | elif action == 'free': 179 | del live[base] 180 | 181 | return replace_ops(compute, live) 182 | -------------------------------------------------------------------------------- /src/twremat.py: -------------------------------------------------------------------------------- 1 | from subprocess import Popen, PIPE 2 | import random 3 | import os 4 | import sys 5 | import tempfile 6 | from tqdm import tqdm 7 | 8 | BINDIR=os.path.join(os.path.dirname(sys.argv[0]), 'bin') 9 | TWREMAT=os.path.join(BINDIR, 'twremat') 10 | 11 | # Allow users to pass 'humanized' memlimit values as strings. 
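# e.g. parse_memlimit('500K') == 500000 and parse_memlimit('12G') == 12000000000; note the suffixes are decimal (K = 1000), not binary (1024).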
12 | def parse_memlimit(memlimit): 13 | if memlimit[-1] == 'K': 14 | return int(memlimit[:-1]) * 1000 15 | elif memlimit[-1] == 'M': 16 | return int(memlimit[:-1]) * 1000000 17 | elif memlimit[-1] == 'G': 18 | return int(memlimit[:-1]) * 1000000000 19 | else: 20 | return int(memlimit) 21 | 22 | def runtwremat(gr, memlimit, target): 23 | if type(memlimit) is str: 24 | memlimit = parse_memlimit(memlimit) 25 | 26 | fname = tempfile.mktemp() 27 | outname = tempfile.mktemp() 28 | with open(fname, 'w') as fp: 29 | print('p remat2', file=fp) 30 | print(f'memlimit {memlimit}', file=fp) 31 | for (n, info) in gr.items(): 32 | deps = ' '.join(str(d) for d in info['deps']) 33 | if info['type'] == 'normal': 34 | cpu = info['cpu'] 35 | mem = info['mem'] 36 | weight = f'cpu {cpu} mem {mem}' 37 | elif info['type'] == 'effectful': 38 | weight = 'effectful' 39 | elif info['type'] == 'pointer': 40 | weight = 'pointer' 41 | if n in target: 42 | tstr = 'target' 43 | else: 44 | tstr = '' 45 | print(f'node {n} deps {deps} {weight} {tstr}', file=fp) 46 | print(' '.join([TWREMAT, fname, outname])) 47 | proc = Popen([TWREMAT, fname, outname]) 48 | assert proc.wait() == 0 49 | out = [] 50 | with open(outname, 'r') as fp: 51 | for line in fp: 52 | line = line.split() 53 | if line and line[0] == 'c': 54 | out.append(('compute', int(line[1]))) 55 | elif line and line[0] == 'f': 56 | out.append(('free', int(line[1]))) 57 | elif line: 58 | print(line) 59 | exit() 60 | return out 61 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Usage: 3 | # PYTHONPATH=src ./train --dataset <file|directory|glob> 4 | 5 | import argparse 6 | import json 7 | import os, sys 8 | import numpy as np 9 | import tensorflow.compat.v1 as tf 10 | import tensorflow as tf2 11 | import time 12 | import tqdm 13 | 14 | if tf.VERSION >= '2': 15 | tf.disable_eager_execution() 16 | tf.config.experimental.enable_tensor_float_32_execution(False) 17 | tf.config.optimizer.set_experimental_options({'layout_optimizer': False, 18 | 'constant_folding': False, 19 | 'shape_optimization': False, 20 | 'remapping': False, 21 | 'arithmetic_optimization': False, 22 | 'dependency_optimization': False, 23 | 'loop_optimization': False, 24 | 'disable_meta_optimizer': True 25 | }) 26 | 27 | 28 | import model, sample, encoder 29 | from load_dataset import load_dataset, Sampler 30 | 31 | CHECKPOINT_DIR = 'checkpoint' 32 | SAMPLE_DIR = 'samples' 33 | 34 | 35 | parser = argparse.ArgumentParser( 36 | description='Fine-tune GPT-2 on your custom dataset.', 37 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 38 | 39 | parser.add_argument('--dataset', metavar='PATH', type=str, required=True, help='Input file, directory, or glob pattern (utf-8 text, or preencoded .npz files).') 40 | parser.add_argument('--model_name', metavar='MODEL', type=str, default='124M', help='Pretrained model name') 41 | parser.add_argument('--models_dir', metavar='PATH', type=str, default='models', help='Path to models directory') 42 | parser.add_argument('--combine', metavar='CHARS', type=int, default=50000, help='Concatenate input files with <|endoftext|> separator into chunks of this minimum size') 43 | parser.add_argument('--encoding', type=str, default='utf-8', help='Set the encoding for reading and writing files.') 44 | 45 | parser.add_argument('--batch_size', metavar='SIZE', type=int, default=1, help='Batch size') 46 | parser.add_argument('--learning_rate',
metavar='LR', type=float, default=0.00002, help='Learning rate for Adam') 47 | parser.add_argument('--accumulate_gradients', metavar='N', type=int, default=1, help='Accumulate gradients across N minibatches.') 48 | parser.add_argument('--memory_saving_gradients', default=False, action='store_true', help='Use gradient checkpointing to reduce vram usage.') 49 | parser.add_argument('--twremat', default=False, action='store_true', help='Use tensor rematerialization (better than memory_saving_gradients and works with tensorflow 2.0).') 50 | parser.add_argument('--twremat_memlimit', type=str, default='12G', help='Memory usage limit/target for twremat. Can be an integer, or an integer suffixed with K/M/G for kilo/mega/giga-bytes.') 51 | parser.add_argument('--only_train_transformer_layers', default=False, action='store_true', help='Restrict training to the transformer blocks.') 52 | parser.add_argument('--optimizer', type=str, default='adam', help='Optimizer. <adam|sgd>.') 53 | parser.add_argument('--noise', type=float, default=0.0, help='Add noise to input training data to regularize against typos.') 54 | 55 | parser.add_argument('--top_k', type=int, default=40, help='K for top-k sampling.') 56 | parser.add_argument('--top_p', type=float, default=0.0, help='P for top-p sampling. Overrides top_k if set > 0.') 57 | 58 | parser.add_argument('--restore_from', type=str, default='latest', help='Either "latest", "fresh", or a path to a checkpoint file') 59 | parser.add_argument('--run_name', type=str, default='run1', help='Run id. Name of subdirectory in checkpoint/ and samples/') 60 | parser.add_argument('--sample_every', metavar='N', type=int, default=100, help='Generate samples every N steps') 61 | parser.add_argument('--sample_length', metavar='TOKENS', type=int, default=1023, help='Sample this many tokens') 62 | parser.add_argument('--sample_num', metavar='N', type=int, default=1, help='Generate this many samples') 63 | parser.add_argument('--save_every', metavar='N', type=int, default=1000, help='Write a checkpoint every N steps') 64 | 65 | parser.add_argument('--val_dataset', metavar='PATH', type=str, default=None, help='Dataset for validation loss, defaults to --dataset.') 66 | parser.add_argument('--val_batch_size', metavar='SIZE', type=int, default=2, help='Batch size for validation.') 67 | parser.add_argument('--val_batch_count', metavar='N', type=int, default=40, help='Number of batches for validation.') 68 | parser.add_argument('--val_every', metavar='STEPS', type=int, default=0, help='Calculate validation loss every STEPS steps.') 69 | 70 | 71 | def maketree(path): 72 | try: 73 | os.makedirs(path) 74 | except: 75 | pass 76 | 77 | 78 | def randomize(context, hparams, p): 79 | if p > 0: 80 | mask = tf.random.uniform(shape=tf.shape(context)) < p 81 | noise = tf.random.uniform(shape=tf.shape(context), minval=0, maxval=hparams.n_vocab, dtype=tf.int32) 82 | return tf.where(mask, noise, context) 83 | else: 84 | return context 85 | 86 | 87 | def main(): 88 | args = parser.parse_args() 89 | enc = encoder.get_encoder(args.model_name, models_dir=args.models_dir) 90 | hparams = model.default_hparams() 91 | with open(os.path.join('models', args.model_name, 'hparams.json')) as f: 92 | hparams.override_from_dict(json.load(f)) 93 | 94 | if args.sample_length > hparams.n_ctx: 95 | raise ValueError( 96 | "Can't get samples longer than window size: %s" % hparams.n_ctx) 97 | 98 | with tf.Session() as sess: 99 | # Fully static shape required to make memory accounting in 100 | # twremat accurate.
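# (tfremat's shape_size has to guess when a dimension is unknown: 1 for a leading/batch dim, 1024 otherwise, so a dynamic shape here would make the memlimit accounting only approximate.)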
101 | train_context = tf.placeholder(tf.int32, [args.batch_size, 1024]) 102 | train_context_in = randomize(train_context, hparams, args.noise) 103 | train_output = model.model(hparams=hparams, X=train_context_in) 104 | train_loss = tf.reduce_mean( 105 | tf.nn.sparse_softmax_cross_entropy_with_logits( 106 | labels=train_context[:, 1:], logits=train_output['logits'][:, :-1])) 107 | 108 | if args.val_every > 0: 109 | val_context = tf.placeholder(tf.int32, [args.val_batch_size, None]) 110 | val_output = model.model(hparams=hparams, X=val_context) 111 | val_loss = tf.reduce_mean( 112 | tf.nn.sparse_softmax_cross_entropy_with_logits( 113 | labels=val_context[:, 1:], logits=val_output['logits'][:, :-1])) 114 | val_loss_summary = tf.summary.scalar('val_loss', val_loss) 115 | 116 | sample_context = tf.placeholder(tf.int32, [args.batch_size, None]) 117 | tf_sample = sample.sample_sequence( 118 | hparams=hparams, 119 | length=args.sample_length, 120 | context=sample_context, 121 | batch_size=args.batch_size, 122 | temperature=1.0, 123 | top_k=args.top_k, 124 | top_p=args.top_p) 125 | 126 | all_vars = [v for v in tf.trainable_variables() if 'model' in v.name] 127 | train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars 128 | 129 | if args.optimizer == 'adam': 130 | print('Using Adam optimizer', file=sys.stderr) 131 | opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate) 132 | elif args.optimizer == 'sgd': 133 | print('Using SGD optimizer', file=sys.stderr) 134 | opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate) 135 | else: 136 | exit('Bad optimizer:', args.optimizer) 137 | 138 | if args.memory_saving_gradients: 139 | if tf.VERSION >= '2': 140 | exit('Memory saving gradients are not supported in tensorflow 2.x') 141 | import memory_saving_gradients 142 | opt_grads = memory_saving_gradients.gradients(train_loss, train_vars) 143 | elif args.twremat: 144 | import tfremat 145 | opt_grads = tf.gradients(train_loss, train_vars) 146 | (train_loss, opt_grads) = tfremat.tf_remat((train_loss, opt_grads), memlimit=args.twremat_memlimit) 147 | else: 148 | opt_grads = tf.gradients(train_loss, train_vars) 149 | opt_grads = list(zip(opt_grads, train_vars)) 150 | opt_apply = opt.apply_gradients(opt_grads) 151 | summary_loss = tf.summary.scalar('loss', train_loss) 152 | 153 | # if args.twremat: 154 | # import tfremat 155 | # # Applying tfremat to opt_apply has more accurate 156 | # # accounting but is a bit iffier since side effecting ops 157 | # # have more restrictions for correctness. If in doubt 158 | # # revert back to version using opt_grads above. 159 | # (opt_apply, train_loss, summary_loss) = ( 160 | # tfremat.tf_remat((opt_apply, train_loss, summary_loss), memlimit=args.twremat_memlimit)) 161 | 162 | 163 | summary_lr = tf.summary.scalar('learning_rate', args.learning_rate) 164 | summaries = tf.summary.merge([summary_lr, summary_loss]) 165 | 166 | summary_log = tf.summary.FileWriter( 167 | os.path.join(CHECKPOINT_DIR, args.run_name)) 168 | 169 | saver = tf.train.Saver( 170 | var_list=all_vars, 171 | max_to_keep=5, 172 | keep_checkpoint_every_n_hours=2) 173 | sess.run(tf.global_variables_initializer()) 174 | 175 | if args.restore_from == 'latest': 176 | ckpt = tf.train.latest_checkpoint( 177 | os.path.join(CHECKPOINT_DIR, args.run_name)) 178 | if ckpt is None: 179 | # Get fresh GPT weights if new run. 
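# (i.e. fall back to the original pretrained checkpoint under models/<model_name>)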
180 | ckpt = tf.train.latest_checkpoint( 181 | os.path.join('models', args.model_name)) 182 | elif args.restore_from == 'fresh': 183 | ckpt = tf.train.latest_checkpoint( 184 | os.path.join('models', args.model_name)) 185 | else: 186 | ckpt = tf.train.latest_checkpoint(args.restore_from) 187 | print('Loading checkpoint', ckpt) 188 | saver.restore(sess, ckpt) 189 | 190 | print('Loading dataset...') 191 | chunks = load_dataset(enc, args.dataset, args.combine, encoding=args.encoding) 192 | data_sampler = Sampler(chunks) 193 | if args.val_every > 0: 194 | if args.val_dataset: 195 | val_chunks = load_dataset(enc, args.val_dataset, args.combine, encoding=args.encoding) 196 | else: 197 | val_chunks = chunks 198 | print('dataset has', data_sampler.total_size, 'tokens') 199 | print('Training...') 200 | 201 | if args.val_every > 0: 202 | # Sample from validation set once with fixed seed to make 203 | # it deterministic during training as well as across runs. 204 | val_data_sampler = Sampler(val_chunks, seed=1) 205 | val_batches = [[val_data_sampler.sample(1024) for _ in range(args.val_batch_size)] 206 | for _ in range(args.val_batch_count)] 207 | 208 | counter = 1 209 | counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter') 210 | if os.path.exists(counter_path): 211 | # Load the step number if we're resuming a run 212 | # Add 1 so we don't immediately try to save again 213 | with open(counter_path, 'r') as fp: 214 | counter = int(fp.read()) + 1 215 | 216 | def save(): 217 | maketree(os.path.join(CHECKPOINT_DIR, args.run_name)) 218 | print( 219 | 'Saving', 220 | os.path.join(CHECKPOINT_DIR, args.run_name, 221 | 'model-{}').format(counter)) 222 | saver.save( 223 | sess, 224 | os.path.join(CHECKPOINT_DIR, args.run_name, 'model'), 225 | global_step=counter) 226 | with open(counter_path, 'w') as fp: 227 | fp.write(str(counter) + '\n') 228 | 229 | def generate_samples(): 230 | print('Generating samples...') 231 | context_tokens = data_sampler.sample(1) 232 | all_text = [] 233 | index = 0 234 | while index < args.sample_num: 235 | out = sess.run( 236 | tf_sample, 237 | feed_dict={sample_context: args.batch_size * [context_tokens]}) 238 | for i in range(min(args.sample_num - index, args.batch_size)): 239 | text = enc.decode(out[i]) 240 | text = '======== SAMPLE {} ========\n{}\n'.format( 241 | index + 1, text) 242 | all_text.append(text) 243 | index += 1 244 | print(text) 245 | maketree(os.path.join(SAMPLE_DIR, args.run_name)) 246 | with open( 247 | os.path.join(SAMPLE_DIR, args.run_name, 248 | 'samples-{}').format(counter), 'w', encoding=args.encoding) as fp: 249 | fp.write('\n'.join(all_text)) 250 | 251 | def validation(): 252 | print('Calculating validation loss...') 253 | losses = [] 254 | for batch in tqdm.tqdm(val_batches): 255 | losses.append(sess.run(val_loss, feed_dict={val_context: batch})) 256 | v_val_loss = np.mean(losses) 257 | v_summary = sess.run(val_loss_summary, feed_dict={val_loss: v_val_loss}) 258 | summary_log.add_summary(v_summary, counter) 259 | summary_log.flush() 260 | print( 261 | '[{counter} | {time:2.2f}] validation loss = {loss:2.2f}' 262 | .format( 263 | counter=counter, 264 | time=time.time() - start_time, 265 | loss=v_val_loss)) 266 | 267 | def sample_batch(): 268 | return [data_sampler.sample(1024) for _ in range(args.batch_size)] 269 | 270 | 271 | avg_loss = (0.0, 0.0) 272 | start_time = time.time() 273 | 274 | # print('Evaluating grads..') 275 | # tf2.profiler.experimental.start('logdir') 276 | # sess.run((opt_apply, train_loss, summaries), 
feed_dict={train_context: sample_batch()}) 277 | # tf2.profiler.experimental.stop() 278 | # print('Succeeded') 279 | # exit() 280 | 281 | try: 282 | while True: 283 | if counter % args.save_every == 0: 284 | save() 285 | if counter % args.sample_every == 0: 286 | generate_samples() 287 | if args.val_every > 0 and (counter % args.val_every == 0 or counter == 1): 288 | validation() 289 | 290 | (_, v_loss, v_summary) = sess.run( 291 | (opt_apply, train_loss, summaries), 292 | feed_dict={train_context: sample_batch()}) 293 | 294 | summary_log.add_summary(v_summary, counter) 295 | 296 | avg_loss = (avg_loss[0] * 0.99 + v_loss, 297 | avg_loss[1] * 0.99 + 1.0) 298 | 299 | print( 300 | '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}' 301 | .format( 302 | counter=counter, 303 | time=time.time() - start_time, 304 | loss=v_loss, 305 | avg=avg_loss[0] / avg_loss[1])) 306 | 307 | counter += 1 308 | except KeyboardInterrupt: 309 | print('interrupted') 310 | save() 311 | 312 | 313 | if __name__ == '__main__': 314 | main() 315 | -------------------------------------------------------------------------------- /twremat/README.md: -------------------------------------------------------------------------------- 1 | Fast implementation of `Efficient Rematerialization for Deep Networks` . 2 | -------------------------------------------------------------------------------- /twremat/main/remat.hs: -------------------------------------------------------------------------------- 1 | {-# Language NamedFieldPuns #-} 2 | {-# Language OverloadedStrings #-} 3 | module Main where 4 | 5 | import Control.Applicative 6 | import Control.Monad 7 | import Data.Foldable 8 | import Data.IntSet (IntSet) 9 | import qualified Data.IntSet as IS 10 | import Data.List 11 | import Data.Map (Map) 12 | import qualified Data.Map.Strict as Map 13 | import Data.Set (Set) 14 | import qualified Data.Set as Set 15 | import Data.Text (Text) 16 | import qualified Data.Text as T 17 | import qualified Data.Text.IO as T 18 | import Debug.Trace 19 | import System.Environment 20 | import System.IO 21 | import Text.Parser.Char 22 | import Text.Parser.Combinators 23 | import Text.Trifecta (Parser) 24 | import qualified Text.Trifecta as Trifecta 25 | 26 | import Balanced 27 | import Filter 28 | import Graph (Gr, Node) 29 | import qualified Graph as G 30 | import TWRemat 31 | import TreeWidth 32 | import Util 33 | 34 | parse :: Parser a -> Text -> a 35 | parse p txt = case Trifecta.parseString p mempty (T.unpack txt) of 36 | Trifecta.Success a -> a 37 | Trifecta.Failure e -> error (show (Trifecta._errDoc e)) 38 | 39 | p_nat :: (Read a, Integral a) => Parser a 40 | p_nat = read <$> some digit 41 | 42 | isValidSchedule :: Gr a -> [Step] -> [Node] -> Bool 43 | isValidSchedule gr steps ts = go steps IS.empty 44 | where 45 | go (Compute n : steps) live = all (`IS.member` live) (G.preList gr n) && go steps (IS.insert n live) 46 | go (Free n : steps) live = IS.member n live && go steps (IS.delete n live) 47 | go [] live = all (\n -> IS.member n live) ts 48 | 49 | -- Modify the graph to insert direct dependencies through 'pointer' 50 | -- type ops to the op that generated their underlying 51 | -- storage. Example, given op1 -> id -> op2, op2 will now directly 52 | -- keep op1 alive. Assuming that 'pointer' type ops are ~0 cost the 53 | -- 'id' can now be immediately freed after use, and all memory usage 54 | -- charged to op1, simplifying memory analysis. 
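-- (On the Python side, tfremat.info marks Reshape ops, plain tensors, and IndexedSlices as 'pointer' nodes, so e.g. a reshape of a matmul output keeps the matmul's buffer alive and charges the memory to the matmul.)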
55 | mergePointers :: Gr a -> (Node -> Weight) -> Gr a 56 | mergePointers gr info = merged 57 | where 58 | merged = G.insEdges [(p, n) | n <- G.nodes gr, p <- pdeps n] gr 59 | pdeps n = let pparents n = [p | p <- G.preList gr n, info p == Pointer] 60 | go [] visited = Set.toList visited 61 | go (p:ps) visited | Set.member p visited = go ps visited 62 | | otherwise = go (pparents p ++ ps) (Set.insert p visited) 63 | in go (pparents n) Set.empty 64 | 65 | 66 | outputSchedule :: [Step] -> IO () 67 | outputSchedule schedule = do 68 | args <- getArgs 69 | let printStep (Compute n) = "c " <> T.pack (show n) 70 | printStep (Free n) = "f " <> T.pack (show n) 71 | output = T.unlines (map printStep schedule) 72 | case args of 73 | [path, outpath] -> T.writeFile outpath output 74 | [path] -> T.putStr output 75 | 76 | 77 | main :: IO () 78 | main = do 79 | args <- getArgs 80 | let path = head args 81 | txt <- T.readFile path 82 | let p = do 83 | string "p remat2" *> spaces 84 | memlimit <- optional (text "memlimit" *> spaces *> p_nat <* spaces) 85 | nodes <- some $ do 86 | text "node" <* spaces 87 | node_id <- p_nat <* spaces 88 | deps <- fold <$> optional (text "deps" *> spaces *> many (p_nat <* spaces)) 89 | let p_weight = do 90 | cpu <- text "cpu" *> spaces *> p_nat <* spaces 91 | mem <- text "mem" *> spaces *> p_nat <* spaces 92 | return (Normal{cpu,mem}) 93 | p_effectful = const Effectful <$> text "effectful" <* spaces 94 | p_pointer = const Pointer <$> text "pointer" <* spaces 95 | weight <- optional (p_weight <|> p_effectful <|> p_pointer) 96 | target <- (const True <$> text "target" <* spaces) <|> pure False 97 | optional (char '\n') 98 | return (node_id, deps, weight, target) 99 | eof 100 | return (memlimit, nodes) 101 | (memlimit, node_data) = parse p txt 102 | ns = [n | (n, _, _, _) <- node_data] 103 | es = [(d, n) | (n, ds, _, _) <- node_data, d <- ds] 104 | ts = [n | (n, _, _, True) <- node_data] 105 | 106 | case memlimit of 107 | Just memlimit -> let 108 | weights = Map.fromList [(n,w) | (n, _, Just w, _) <- node_data] 109 | weight n = Map.findWithDefault (Normal 1 1) n weights 110 | graph = mergePointers (G.mkUGraph ns es) weight 111 | schedule = remat graph (IS.fromList ts) 112 | schedule' = optimize graph weight memlimit schedule 113 | in do outputSchedule schedule' 114 | hPutStrLn stderr ("isValid = " ++ show (isValidSchedule graph schedule' ts)) 115 | hPutStrLn stderr ("length = " ++ show (length schedule')) 116 | evalSched weight (initSched graph schedule') 117 | Nothing -> let 118 | graph = G.mkUGraph ns es 119 | schedule = remat graph (IS.fromList ts) 120 | in outputSchedule schedule 121 | -- -- G.plotLab "tree.dot" (IS.toList <$> treeWidth graph) 122 | -- print (length ns) 123 | -- print (length schedule, isValidSchedule graph schedule ts) 124 | -- print (length schedule', isValidSchedule graph schedule' ts) 125 | 126 | -- let sched_1 = initSched graph schedule 127 | -- evalSched weight sched_1 128 | 129 | -- let sched_2 = optSched sched_1 130 | -- evalSched weight sched_2 131 | 132 | -- let go sched = do 133 | -- let sched_1 = greedy weight 1000 sched 134 | -- evalSched weight sched_1 135 | -- let sched_2 = optSched sched_1 136 | -- evalSched weight sched_2 137 | -- go sched_2 138 | 139 | -- go sched_2 140 | 141 | -- let 142 | -- output = T.unlines $ map T.pack $ do 143 | -- step <- schedule 144 | -- return $ case step of 145 | -- Compute n -> "c " ++ show n 146 | -- Free n -> "f " ++ show n 147 | -- case args of 148 | -- [path, outpath] -> T.writeFile outpath output 149 | -- 
[path] -> T.putStr output 150 | -- hPutStrLn stderr ("isValid = " ++ show (isValidSchedule graph schedule)) 151 | -------------------------------------------------------------------------------- /twremat/main/test.hs: -------------------------------------------------------------------------------- 1 | {-# Language ScopedTypeVariables #-} 2 | module Main where 3 | 4 | import Data.IntSet (IntSet) 5 | import qualified Data.IntSet as IS 6 | import Test.QuickCheck 7 | import Test.Tasty 8 | import Test.Tasty.QuickCheck 9 | 10 | import Graph 11 | import TWRemat 12 | import TestBalanced 13 | import TestGraph 14 | import TestTreeWidth 15 | 16 | isValidSchedule :: Gr a -> [Step] -> Bool 17 | isValidSchedule gr steps = go steps IS.empty 18 | where 19 | go (Compute n : steps) live = all (`IS.member` live) (preList gr n) && go steps (IS.insert n live) 20 | go (Free n : steps) live = IS.member n live && go steps (IS.delete n live) 21 | go [] live = True 22 | 23 | main = defaultMain $ testGroup "Tests" [ 24 | testGraph, 25 | testBalanced, 26 | testTreeWidth, 27 | testGroup "TWRemat" [ 28 | testProperty "produces valid schedule" $ \(DagOf (gr :: Gr ())) -> 29 | let t = last (nodes gr) 30 | in isValidSchedule gr (remat gr (IS.fromList [t])), 31 | testProperty "produces valid schedule x2" $ \(DagOf (gr :: Gr ())) -> 32 | let t = take 2 (nodes gr) 33 | in isValidSchedule gr (remat gr (IS.fromList t)) 34 | ] 35 | ] 36 | -------------------------------------------------------------------------------- /twremat/src/Balanced.hs: -------------------------------------------------------------------------------- 1 | {-# Language BangPatterns #-} 2 | {-# Language DeriveTraversable #-} 3 | module Balanced where 4 | 5 | import Data.Bifunctor 6 | import Data.Foldable 7 | import Data.IntMap (IntMap) 8 | import qualified Data.IntMap as IM 9 | import Data.IntSet (IntSet) 10 | import qualified Data.IntSet as IS 11 | import Data.List 12 | import Data.Map (Map) 13 | import qualified Data.Map.Strict as Map 14 | import Data.Ord 15 | import Data.Set (Set) 16 | import qualified Data.Set as Set 17 | import Data.Tuple 18 | import Debug.Trace 19 | 20 | import TreeWidth 21 | 22 | import Graph (Gr, Node) 23 | import qualified Graph as G 24 | import Util 25 | 26 | balancedSeparator :: Gr a -> Node 27 | balancedSeparator gr = minimumOn (\n -> (weight n, n)) (G.nodes gr) 28 | where 29 | cutWeight = memo (G.edges gr) $ \(a,b) -> 30 | 1 + sum [cutWeight (b,c) | c <- G.sucList gr b, c /= a] :: Int 31 | weight = \a -> 32 | maximum [cutWeight (a,b) | b <- G.sucList gr a] 33 | 34 | -- Rose tree with weight annotations. 35 | data Tree a = Tree Int a [Tree a] 36 | deriving (Show, Functor, Foldable, Traversable) 37 | 38 | treeWeight :: Tree a -> Int 39 | treeWeight (Tree w _ _) = w 40 | 41 | treeVal :: Tree a -> a 42 | treeVal (Tree _ a _) = a 43 | 44 | tree :: a -> [Tree a] -> Tree a 45 | tree a subs = Tree (1 + sum (map treeWeight subs)) a subs 46 | 47 | -- Create a Tree from treelike graph. Assumes gr is undirected and 48 | -- simple and has at least one node. 49 | mkTree :: Gr a -> Tree Node 50 | mkTree gr = tree top [go top v | v <- G.sucList gr top, v /= top] 51 | where 52 | go u v = tree v [go v w | w <- G.sucList gr v, w /= u] 53 | top = head (G.nodes gr) 54 | 55 | -- Choose one element from a list. 56 | choose1 :: [a] -> [(a, [a])] 57 | choose1 xs = do i <- [0..length xs-1] 58 | return (xs!!i, take i xs ++ drop (i+1) xs) 59 | 60 | -- Balance a tree by recursively rotating each node until the heaviest 61 | -- subtree has minimal weight. 
62 | -- properties:
63 | --
64 | -- 1. For every node v in the tree, the subtrees rooted at children of
65 | -- v are disjoint connected components of the original tree with v
66 | -- removed.
67 | --
68 | -- 2. The tree is balanced, in that for every node v, the heaviest
69 | -- child of v has weight at most weight[v]/2.
70 |
71 | balance :: Tree a -> Tree a
72 | balance root@(Tree x a []) = root
73 | balance root@(Tree x a children)
74 |   -- If we can improve balance by rotating, do so and check again.
75 |   | bestscore < score root = balance best
76 |   -- Current level is balanced, now balance all children.
77 |   | otherwise = tree a (map balance children)
78 |   where
79 |     rotate (Tree _ new_a new_children) other_children =
80 |       let old_root = tree a other_children
81 |       in tree new_a (old_root : new_children)
82 |     options = [rotate choice rest | (choice, rest) <- choose1 children]
83 |     score (Tree x a children)
84 |       | null children = 0
85 |       | otherwise = maximum (map treeWeight children)
86 |     (bestscore, best) = minimumOn fst [(score t, t) | t <- options]
87 |
88 | -- Convert a treelike graph into a balanced tree of its node labels.
89 | sepTree :: Gr a -> Tree a
90 | sepTree gr = fmap (G.lab gr) $ balance $ mkTree gr
91 |
92 | -- sepTreeSlow :: Gr a -> Tree a
93 | -- sepTreeSlow gr
94 | --   | G.order gr == 1 = tree (snd $ head $ G.labNodes gr) []
95 | --   | otherwise = tree (G.lab gr top) [sepTreeSlow sub | sub <- G.splitComponents (G.delNode top gr)]
96 | --   where
97 | --     top = balancedSeparator gr
98 |
--------------------------------------------------------------------------------
/twremat/src/Dense.hs:
--------------------------------------------------------------------------------
1 | module Dense where
2 |
3 | type Dense = [Int]
4 |
5 | between :: Dense -> Dense -> Dense
6 | between (a:as) (b:bs)
7 |   | a + 1 < b = [div (a + b) 2]
8 |   | a < b = a : after as
9 |   | a == b = a : between as bs
10 |
11 | before :: Dense -> Dense
12 | before [] = error "before []"
13 | before (a:as) = [a - 1]
14 |
15 | after :: Dense -> Dense
16 | after [] = [2^20]
17 | after (a:as) = [a + 1]
18 |
--------------------------------------------------------------------------------
/twremat/src/Filter.hs:
--------------------------------------------------------------------------------
1 | {-# LANGUAGE BangPatterns #-}
2 | {-# LANGUAGE NamedFieldPuns #-}
3 | {-# LANGUAGE RecordWildCards #-}
4 | {-# LANGUAGE PatternSynonyms #-}
5 | {-# LANGUAGE ViewPatterns #-}
6 | {-# LANGUAGE MultiParamTypeClasses #-}
7 | module Filter where
8 |
9 | import Control.Monad.State.Lazy
10 | import Data.Foldable
11 | import Data.List
12 | import Data.Map (Map)
13 | import qualified Data.Map.Strict as Map
14 | import Data.Monoid
15 | import Data.OrdPSQ (OrdPSQ)
16 | import qualified Data.OrdPSQ as PSQ
17 | import Data.Relation (Relation)
18 | import qualified Data.Relation as R
19 | import Data.Semigroup
20 | import Data.Set (Set)
21 | import qualified Data.Set as Set
22 | import Debug.Trace
23 |
24 | import Graph (Gr, Node)
25 | import qualified Graph as G
26 | import TWRemat
27 |
28 | newtype CID = CID Int
29 |   deriving (Show, Eq, Ord)
30 |
31 |
32 | -- Indexed data structure for rematerialization schedule.
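-- For illustration (a hypothetical example, not taken from the original
-- comments): given the chain graph 1 -> 2 -> 3 and the step list
-- [Compute 1, Compute 2, Free 1, Compute 3], initSched (defined later in
-- this module) numbers the steps 1..4 and builds
--
--   computes  = {(CID 1, 1), (CID 2, 2), (CID 4, 3)}
--   c_free    = {(CID 2, 1)}    -- node 1 is freed just after step CID 2
--   c_require = {(CID 2, 1), (CID 4, 2)}
--
-- Free steps get no CID of their own; each attaches to the nearest
-- preceding compute step, which is why CID 3 does not appear.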
33 | data Sched = Sched { 34 | computes :: Relation CID Node, -- compute step -> node id (many-1) 35 | c_free :: Relation CID Node, -- compute step -> nodes freed afterward (many-many) 36 | c_require :: Relation CID Node -- compute step -> nodes required as input (many-many) 37 | } 38 | deriving (Show) 39 | 40 | pattern One :: a -> Set a 41 | pattern One a <- (Set.toList -> [a]) where 42 | One a = Set.singleton a 43 | 44 | pattern None :: Set a 45 | pattern None <- (Set.null -> True) where 46 | None = Set.empty 47 | 48 | deleteL :: (Ord a, Ord b) => a -> Relation a b -> Relation a b 49 | deleteL a rel = foldr go rel (Set.toList $ R.lookupDom a rel) 50 | where 51 | go b = R.delete a b 52 | 53 | -- Info for each node relevant to evaluating and optimizing cpu/memory 54 | -- consumption. 55 | data Weight = 56 | -- Normal node that reads its inputs and produces some output. 57 | Normal{cpu::Int, mem::Int} 58 | -- Effectful node that must not be duplicated (eg. assigning to a 59 | -- variable). Assumed to produce no relevant output. 60 | | Effectful 61 | -- Pointer nodes return a view that shares memory with dependencies, 62 | -- keeping them from being GCd (example: tf.identity). 63 | | Pointer 64 | deriving (Eq, Ord) 65 | 66 | -- Reduce cpu usage of a schedule under the constraint that peak mem usage must be less than `memLimit`. 67 | greedy :: (Node -> Weight) -> Int -> Sched -> Sched 68 | greedy info memLimit sched@Sched{computes,c_free,c_require} = go (R.toList computes) Set.empty 0 Set.empty PSQ.empty 69 | where 70 | memOf n = case info n of 71 | Normal{mem} -> mem 72 | Effectful -> 0 73 | Pointer -> 0 74 | 75 | priority :: CID -> Node -> Maybe Double 76 | priority c n 77 | -- Anything we don't need anymore should be freed immediately. 78 | | Nothing == Set.lookupGT c (R.lookupRan n c_require) = Just 0 79 | | otherwise = case info n of 80 | -- Otherwise, prioritise the ops which use most memory and least cpu to recompute. 81 | Normal{mem, cpu} -> Just (fromIntegral cpu / fromIntegral mem) 82 | -- Effectful ops are assumed to use no memory and should never be freed. 83 | Effectful -> Nothing 84 | -- Free pointer ops immediately. 85 | Pointer -> Just 0 86 | 87 | go [] keepcid memUsage live freeList = 88 | let finish c Sched{computes,c_free,c_require} = case R.lookupDom c computes of 89 | One n -> Sched{computes = R.delete c n computes, 90 | c_free = case (Set.lookupLT c (R.lookupRan n computes), 91 | Set.lookupLT c (R.lookupRan n c_free)) of 92 | (Just cc, Just fc) | cc <= fc -> R.delete fc n c_free, 93 | c_require = deleteL c c_require} 94 | in foldr (.) id [finish c | c <- toList (R.dom computes), not (Set.member c keepcid)] sched 95 | go ((c,n):cs) keepcid memUsage live freeList 96 | | Set.member n live = go_free c cs keepcid memUsage live (PSQ.delete n freeList) 97 | | memUsage + memOf n > memLimit && not (PSQ.null freeList) = case PSQ.findMin freeList of 98 | Just (f,_,_) -> 99 | go ((c,n):cs) keepcid (memUsage - memOf f) (Set.delete f live) (PSQ.deleteMin freeList) 100 | | otherwise = go_next (c,n) cs keepcid memUsage live freeList 101 | 102 | go_next (c,n) cs keepcid memUsage live freeList = 103 | go_free c cs (Set.insert c keepcid) (memUsage + memOf n) (Set.insert n live) (PSQ.delete n freeList) 104 | 105 | go_free c cs keepcid memUsage live freeList = 106 | let freeList' = foldr (.) 
id [PSQ.insert n v () 107 | | n <- Set.toList (R.lookupDom c c_free), 108 | Just v <- [priority c n]] freeList 109 | in go cs keepcid memUsage live freeList' 110 | 111 | evalSched :: (Node -> Weight) -> Sched -> IO () 112 | evalSched info Sched{computes,c_free,c_require} = do 113 | putStrLn (unwords ["steps=" ++ show (Set.size (R.dom computes)), 114 | "cpu=" ++ show (sum [cpu | (c, n) <- R.toList computes, Normal{cpu} <- [info n]]), 115 | "peak=" ++ show peak]) 116 | where 117 | memOf n = case info n of 118 | Normal{mem} -> mem 119 | _ -> 0 120 | peak = go (Set.toList (R.dom computes <> R.dom c_free)) 0 0 121 | go [] maxMem curMem = maxMem 122 | go (c:cs) maxMem curMem = case R.lookupDom c computes of 123 | One n -> go cs (max maxMem (curMem + memOf n)) (curMem + memOf n - sum [memOf f | f <- toList (R.lookupDom c c_free)]) 124 | None -> go cs maxMem (curMem - sum [memOf f | f <- toList (R.lookupDom c c_free)]) 125 | 126 | -- Basic optimizations: move Free actions to be as early as possible, 127 | -- and eliminate Compute actions that are immediately Freed without 128 | -- being used. 129 | optSched :: Sched -> Sched 130 | optSched sched@Sched{computes, c_free} = foldr go sched (toList (R.dom computes <> R.dom c_free)) 131 | where 132 | checkAnnihilate c sched@Sched{..} = 133 | case R.lookupDom c computes of 134 | One n | R.member c n c_free -> 135 | Sched{computes = R.delete c n computes, 136 | c_free = R.delete c n c_free, 137 | c_require = deleteL c c_require} 138 | _ -> sched 139 | checkMove c sched@Sched{..} = 140 | case R.lookupDom c c_free of 141 | ns | Set.size ns > 0 -> 142 | let target n = getMax <$> fold [Max <$> Set.lookupLE c (R.lookupRan n computes), 143 | Max <$> Set.lookupLE c (R.lookupRan n c_require), 144 | Max <$> Set.lookupLT c (R.lookupRan n c_free)] 145 | process n sched@Sched{..} = case target n of 146 | Just c' | c' < c -> sched { c_free = R.insert c' n $ R.delete c n $ c_free } 147 | Just c' | c'== c -> sched 148 | Nothing -> sched { c_free = R.delete c n $ c_free } 149 | in foldr process sched (Set.toList ns) 150 | _ -> sched 151 | 152 | go c sched = checkMove c $ checkAnnihilate c sched 153 | 154 | 155 | initSched :: Gr a -> [Step] -> Sched 156 | initSched gr sched = Sched{computes, c_free, c_require} 157 | where 158 | steps = Map.fromList (zip [1..] sched) :: Map Int Step 159 | computes = R.fromList [(CID k, n) | (k, Compute n) <- Map.toList steps] 160 | cdom = R.dom computes 161 | c_free = R.fromList [let Just c = Set.lookupLT (CID k) cdom 162 | in (c, n) | (k, Free n) <- Map.toList steps] 163 | c_require = R.fromList [(c, p) | (c, n) <- R.toList computes, p <- G.preList gr n] 164 | 165 | runSched :: Sched -> [Step] 166 | runSched Sched{computes, c_free} = fold [[Compute n] ++ (Free <$> Set.toList (R.lookupDom c c_free)) 167 | | (c, n) <- R.toList computes] 168 | 169 | -- 6 cycles of forward-backward optimization seems to generally be enough for a good schedule. 170 | optimize :: Gr a -> (Node -> Weight) -> Int -> [Step] -> [Step] 171 | optimize gr info memLimit steps = runSched (foldl' step startSched [1..maxSteps]) 172 | where 173 | step !sched i = trace ("Optimizing schedule... 
" ++ show i ++ "/" ++ show maxSteps) $ optSched (greedy info memLimit sched) 174 | startSched = optSched (initSched gr steps) 175 | maxSteps = 6 176 | -------------------------------------------------------------------------------- /twremat/src/Graph.hs: -------------------------------------------------------------------------------- 1 | module Graph where 2 | 3 | import Control.Monad.State.Strict 4 | import Data.Foldable 5 | import Data.IntMap (IntMap) 6 | import qualified Data.IntMap as IM 7 | import Data.IntSet (IntSet) 8 | import qualified Data.IntSet as IS 9 | import Data.Map (Map) 10 | import qualified Data.Map.Strict as Map 11 | import Data.Set (Set) 12 | import qualified Data.Set as Set 13 | import Text.Printf 14 | 15 | type Node = Int 16 | type Context a = (IntSet, a, IntSet) 17 | newtype Gr a = Gr (IntMap (Context a)) 18 | 19 | instance Functor Gr where 20 | fmap = nmap 21 | 22 | instance Show a => Show (Gr a) where 23 | showsPrec d g = showParen (d > 10) $ showString "mkGraph " . showsPrec 11 (labNodes g) . showString " " . showsPrec 11 (edges g) 24 | 25 | mkGraph :: [(Node, a)] -> [(Node, Node)] -> Gr a 26 | mkGraph nodes edges = Gr (IM.fromList [(v, ctx v a) | (v, a) <- nodes]) 27 | where 28 | ctx v a = (IM.findWithDefault IS.empty v bwd, a, IM.findWithDefault IS.empty v fwd) 29 | fwd = IM.fromListWith (<>) [(a, IS.singleton b) | (a, b) <- edges] 30 | bwd = IM.fromListWith (<>) [(b, IS.singleton a) | (a, b) <- edges] 31 | 32 | mkUGraph :: [Node] -> [(Node, Node)] -> Gr () 33 | mkUGraph nodes edges = mkGraph (zip nodes (repeat ())) edges 34 | 35 | labNodes :: Gr a -> [(Node, a)] 36 | labNodes (Gr m) = l <$> IM.toList m 37 | where 38 | l (v, (p, a, s)) = (v, a) 39 | 40 | nodes :: Gr a -> [Node] 41 | nodes (Gr m) = IM.keys m 42 | 43 | edges :: Gr a -> [(Node, Node)] 44 | edges (Gr m) = foldMap go (IM.toList m) 45 | where 46 | go (v, (p, a, s)) = map ((,) v) (IS.toList s) 47 | 48 | suc :: Gr a -> Node -> IntSet 49 | suc (Gr m) v = case m IM.! v of 50 | (p, a, s) -> s 51 | 52 | pre :: Gr a -> Node -> IntSet 53 | pre (Gr m) v = case m IM.! v of 54 | (p, a, s) -> p 55 | 56 | lab :: Gr a -> Node -> a 57 | lab (Gr m) v = case m IM.! v of 58 | (p, a, s) -> a 59 | 60 | labMaybe :: Gr a -> Node -> Maybe a 61 | labMaybe (Gr m) v = case IM.lookup v m of 62 | Just (p, a, s) -> Just a 63 | Nothing -> Nothing 64 | 65 | sucList :: Gr a -> Node -> [Node] 66 | sucList g v = IS.toList (suc g v) 67 | 68 | preList :: Gr a -> Node -> [Node] 69 | preList g v = IS.toList (pre g v) 70 | 71 | indeg :: Gr a -> Node -> Int 72 | indeg gr v = IS.size (pre gr v) 73 | 74 | outdeg :: Gr a -> Node -> Int 75 | outdeg gr v = IS.size (suc gr v) 76 | 77 | hasEdge :: Gr a -> (Node, Node) -> Bool 78 | hasEdge (Gr m) (a,b) = case IM.lookup a m of 79 | Just (p, a, s) -> IS.member b s 80 | Nothing -> False 81 | 82 | hasNode :: Gr a -> Node -> Bool 83 | hasNode (Gr m) v = case IM.lookup v m of 84 | Just _ -> True 85 | Nothing -> False 86 | 87 | delNode :: Node -> Gr a -> Gr a 88 | delNode v (Gr m) = case IM.lookup v m of 89 | Just (p, a, s) -> Gr . foldr (.) id (clearSucc v <$> IS.toList p) . foldr (.) id (clearPred v <$> IS.toList s) . IM.delete v $ m 90 | Nothing -> Gr m 91 | where 92 | clearSucc v k m = IM.adjust (\(p, a, s) -> (p, a, IS.delete v s)) k m 93 | clearPred v k m = IM.adjust (\(p, a, s) -> (IS.delete v p, a, s)) k m 94 | 95 | insEdges :: [(Node, Node)] -> Gr a -> Gr a 96 | insEdges es (Gr m) = Gr . part1 . 
part2 $ m 97 | where 98 | adjs = IM.fromListWith (<>) [(a, IS.singleton b) | (a, b) <- es] 99 | adjp = IM.fromListWith (<>) [(b, IS.singleton a) | (a, b) <- es] 100 | part1 = foldr (.) id [IM.adjust (\(p, a, s) -> (p, a, s <> js)) i | (i, js) <- IM.toList adjs] 101 | part2 = foldr (.) id [IM.adjust (\(p, a, s) -> (p <> js, a, s)) i | (i, js) <- IM.toList adjp] 102 | 103 | insNode :: (Node, a) -> Gr a -> Gr a 104 | insNode (i, a) (Gr m) = Gr (IM.alter go i m) 105 | where 106 | go (Just (p1, a1, s1)) = Just (p1, a, s1) 107 | go Nothing = Just (IS.empty, a, IS.empty) 108 | 109 | (&) :: (IntSet, Node, a, IntSet) -> Gr a -> Gr a 110 | (&) (p, i, a, s) = insEdges (ein ++ eout) . insNode (i, a) 111 | where 112 | ein = [(j, i) | j <- IS.toList p] 113 | eout = [(i, j) | j <- IS.toList s] 114 | 115 | newNodes :: Int -> Gr a -> [Node] 116 | newNodes n (Gr m) = case IM.findMax m of 117 | (x, _) -> [x+1..x+n] 118 | 119 | -- XXX: Really slow. 120 | order :: Gr a -> Int 121 | order (Gr m) = IM.size m 122 | 123 | gmap :: (Node -> Context a -> Context b) -> Gr a -> Gr b 124 | gmap f (Gr m) = Gr (IM.mapWithKey f m) 125 | 126 | nmap :: (a -> b) -> Gr a -> Gr b 127 | nmap f (Gr m) = Gr (IM.map go m) 128 | where go (p, a, s) = (p, f a, s) 129 | 130 | gfilter :: (Node -> Context a -> Bool) -> Gr a -> Gr a 131 | gfilter f (Gr m) = Gr (IM.mapMaybeWithKey go m) 132 | where 133 | go i (p, a, s) 134 | | IS.member i keep = Just (IS.intersection p keep, a, IS.intersection s keep) 135 | | otherwise = Nothing 136 | keep = IS.fromList [i | (i, ctx) <- IM.toList m, f i ctx] 137 | 138 | labfilter :: (a -> Bool) -> Gr a -> Gr a 139 | labfilter f g = gfilter (\i (p, a, s) -> f a) g 140 | 141 | dfs :: Gr a -> [Node] -> [Node] 142 | dfs g start = go start IS.empty 143 | where 144 | go [] visited = [] 145 | go (x:xs) visited 146 | | IS.member x visited = go xs visited 147 | | otherwise = x : go (sucList g x ++ xs) (IS.insert x visited) 148 | 149 | udfs :: Gr a -> [Node] -> [Node] 150 | udfs g start = go start IS.empty 151 | where 152 | go [] visited = [] 153 | go (x:xs) visited 154 | | IS.member x visited = go xs visited 155 | | otherwise = x : go (sucList g x ++ preList g x ++ xs) (IS.insert x visited) 156 | 157 | topsort :: Gr a -> [Node] 158 | topsort g = (foldr (.) id $ evalState (traverse go (nodes g)) IS.empty) [] 159 | where 160 | go = \x -> do 161 | visited <- get 162 | case IS.member x visited of 163 | True -> pure id 164 | False -> do 165 | put (IS.insert x visited) 166 | before <- foldr (.) id <$> traverse go (preList g x) 167 | return (before . (x:)) 168 | 169 | subgraph :: Gr a -> [Node] -> Gr a 170 | subgraph gr ns = (mkGraph 171 | (filter (\(i,a) -> IS.member i nset) (labNodes gr)) 172 | (filter (\(i,j) -> IS.member i nset && IS.member j nset) (edges gr))) 173 | where 174 | nset = IS.fromList ns 175 | 176 | 177 | -- Assumes that g is undirected, but does not check. 178 | components :: Gr a -> [[Node]] 179 | components g = filter (not . 
null) $ evalState (traverse go (nodes g)) IS.empty 180 | where 181 | go = \x -> go1 [x] [] 182 | go1 :: [Node] -> [Node] -> State IntSet [Node] 183 | go1 [] os = pure os 184 | go1 (x:xs) os = do 185 | visited <- get 186 | case IS.member x visited of 187 | True -> go1 xs os 188 | False -> do 189 | put (IS.insert x visited) 190 | go1 (sucList g x ++ xs) (x:os) 191 | 192 | splitComponents :: Gr a -> [Gr a] 193 | splitComponents (Gr m) = [Gr (IM.restrictKeys m (IS.fromList c)) | c <- components (Gr m)] 194 | 195 | isEmpty :: Gr a -> Bool 196 | isEmpty (Gr m) = IM.null m 197 | 198 | isConnected :: Gr a -> Bool 199 | isConnected g = isEmpty g || IS.fromList (udfs g (take 1 (nodes g))) == IS.fromList (nodes g) 200 | 201 | isUndirected :: Gr a -> Bool 202 | isUndirected (Gr m) = all ok (toList m) 203 | where 204 | ok (p, a, s) = p == s 205 | 206 | -- Make simple and undirected, remove labels 207 | simplify :: Gr a -> Gr () 208 | simplify gr = const () <$> simplify' gr 209 | 210 | -- Make simple and undirected 211 | simplify' :: Gr a -> Gr a 212 | simplify' gr = gmap dedup gr 213 | where 214 | dedup node (p, a, s) = 215 | let adj = IS.delete node $ (p <> s) 216 | in (adj, a, adj) 217 | 218 | plot :: String -> Gr () -> IO () 219 | plot fname gr = writeFile fname txt 220 | where 221 | txt = unlines [ 222 | "digraph {", 223 | unlines [show n | n <- nodes gr], 224 | unlines [show a ++ " -- " ++ show b | (a,b) <- edges gr, hasEdge gr (b,a), a < b], 225 | unlines [show a ++ " -> " ++ show b | (a,b) <- edges gr, not (hasEdge gr (b,a))], 226 | "}"] 227 | 228 | plotLab :: Show a => String -> Gr a -> IO () 229 | plotLab fname gr = writeFile fname txt 230 | where 231 | txt = unlines [ 232 | if isUndirected gr then "graph {" else "digraph {", 233 | unlines [printf "%i [label=\"%s\"]" n (show a) | (n,a) <- labNodes gr], 234 | unlines [show a ++ " -- " ++ show b | (a,b) <- edges gr, hasEdge gr (b,a), a < b], 235 | unlines [show a ++ " -> " ++ show b | (a,b) <- edges gr, not (hasEdge gr (b,a))], 236 | "}"] 237 | -------------------------------------------------------------------------------- /twremat/src/TWRemat.hs: -------------------------------------------------------------------------------- 1 | {-# Language BangPatterns #-} 2 | {-# Language DeriveTraversable #-} 3 | {-# Language NamedFieldPuns #-} 4 | module TWRemat where 5 | 6 | import Data.DList (DList) 7 | import qualified Data.DList as DL 8 | import Data.Foldable 9 | import Data.IntMap (IntMap) 10 | import qualified Data.IntMap as IM 11 | import Data.Map (Map) 12 | import qualified Data.Map.Lazy as Map 13 | import Data.IntSet (IntSet) 14 | import qualified Data.IntSet as IS 15 | import Data.List 16 | import Data.Ord 17 | 18 | import Balanced 19 | import Graph (Gr, Node) 20 | import qualified Graph as G 21 | import TreeWidth 22 | import Util 23 | 24 | data Step = Compute Node | Free Node 25 | deriving (Show, Eq, Ord) 26 | 27 | toposort :: Gr a -> IntSet -> [Node] 28 | toposort gr = go 29 | where 30 | go xs = sortOn (\x -> score IM.! x) (IS.toList xs) 31 | score = IM.fromList (zip (G.topsort gr) [0..]) 32 | 33 | ancestors :: Gr a -> (Node -> IntSet) 34 | ancestors gr = tab 35 | where 36 | tab = memo (G.nodes gr) (\n -> IS.singleton n <> foldMap tab (G.preList gr n)) 37 | 38 | -- Recursively remove elements of root node from all subtrees. 
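-- For illustration (hypothetical bags, not from the original comments):
-- if the root bag is {1,2}, its child's bag is {2,3}, and the
-- grandchild's bag is {1,3,4}, then subtracting the root bag from the
-- whole subtree gives {3} and {3,4}, and recursing at the child leaves
-- the grandchild as {4}. After preFilter every graph node survives in
-- exactly one bag (the bag nearest the root that contained it), so the
-- tree decomposition becomes a partition of the graph's nodes.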
39 | preFilter :: Tree Bag -> Tree Bag 40 | preFilter (Tree _ x subs) = tree x [preFilter $ fmap (`IS.difference` x) c | c <- subs] 41 | 42 | data Comp = Comp{x :: Bag, xall :: Bag} 43 | 44 | -- Annotate each node with the union of all nodes in its subtree. 45 | preFold :: Tree Bag -> Tree Comp 46 | preFold t@(Tree _ x subs) = tree Comp{x,xall=total} subs' 47 | where subs' = preFold <$> subs 48 | total = x <> fold (xall . treeVal <$> subs') 49 | 50 | -- Computes a rematerialization schedule for the given DAG, which ends 51 | -- with the nodes of 'compute' computed and in memory. 52 | remat :: Gr a -> IntSet -> [Step] 53 | remat gr compute = DL.toList (twremat (preFold . preFilter . sepTree $ treeWidth gr) compute) 54 | where 55 | topo = toposort gr 56 | antab = ancestors gr 57 | 58 | twremat :: Tree Comp -> IntSet -> DList Step 59 | twremat (Tree _ Comp{x} components) compute 60 | | IS.null compute = mempty 61 | | otherwise = case components of 62 | [] -> 63 | -- Base case: simply execute the remaining nodes in order, then 64 | -- free the ones the caller doesn't need. 65 | DL.fromList (Compute <$> topo target) <> DL.fromList (Free <$> topo (target `IS.difference` compute)) 66 | components -> 67 | -- Recursion case: select a balanced separator X of the tree decomposition. 68 | -- 1. for each node v of X that we need to compute, in topological order 69 | -- a. Recursively compute the direct dependencies of v in each subtree, 70 | -- excluding any which are in X itself (those are already computed 71 | -- and in memory, since we are traversing X in topological order). 72 | -- b. Compute v. 73 | -- c. Free the dependencies computed in #1a. 74 | -- 2. Recursively compute the needed nodes which are not in X 75 | -- 3. Free the computed nodes of X that the caller doesn't need. 76 | let compsets = map (xall . treeVal) components :: [Bag] 77 | part1 v = let deps = G.pre gr v 78 | new_computes = [deps `IS.intersection` chi_nodes | chi_nodes <- compsets] 79 | in fold [twremat chi new_compute | (chi, new_compute) <- zip components new_computes] 80 | <> (DL.singleton (Compute v)) 81 | <> DL.fromList (Free <$> (IS.toList $ fold new_computes)) 82 | part2 = fold [twremat chi (outside `IS.intersection` chi_nodes) | (chi, chi_nodes) <- zip components compsets] 83 | part3 = DL.fromList (Free <$> topo (target `IS.difference` compute)) 84 | in foldMap part1 (topo target) <> part2 <> part3 85 | where 86 | ancestor_set = foldMap antab (IS.toList compute) :: IntSet 87 | -- Nodes of X which are needed, directly or indirectly. 88 | target = IS.filter (\i -> IS.member i ancestor_set) x 89 | -- Nodes the caller needs which are not in X. 
90 | outside = compute `IS.difference` x
91 |
--------------------------------------------------------------------------------
/twremat/src/TreeWidth.hs:
--------------------------------------------------------------------------------
1 | {-# Language BangPatterns #-}
2 | module TreeWidth where
3 |
4 | import Data.Foldable
5 | import Data.IntMap (IntMap)
6 | import qualified Data.IntMap as IM
7 | import Data.OrdPSQ (OrdPSQ)
8 | import qualified Data.OrdPSQ as PQ
9 | import Data.IntSet (IntSet)
10 | import qualified Data.IntSet as IS
11 | import Data.Map (Map)
12 | import qualified Data.Map.Strict as Map
13 | import Data.Set (Set)
14 | import qualified Data.Set as Set
15 | import Data.Tuple
16 | import Debug.Trace
17 |
18 | import Graph (Gr, Node)
19 | import qualified Graph as G
20 |
21 | type Bag = IntSet
22 |
23 | -- O(n^2 d^2), where d is the average degree
24 | slowTreeWidth :: Gr a -> Gr Bag
25 | slowTreeWidth ga = go (G.simplify ga) []
26 |   where
27 |     go gr ns = case min_fill_in gr of
28 |       Just node -> let gr' = G.insEdges [(a, b) |
29 |                          a <- G.sucList gr node,
30 |                          b <- G.sucList gr node,
31 |                          a /= b,
32 |                          not (G.hasEdge gr (a,b))] $ G.delNode node gr
33 |                    in go gr' ((node, G.suc gr node) : ns)
34 |       Nothing -> finish ns (G.mkGraph [(0, IS.fromList (G.nodes gr))] [])
35 |     finish [] tree = tree
36 |     finish ((node,neighbors):ns) tree = finish ns tree'
37 |       where
38 |         target = head ([i | (i, bag) <- G.labNodes tree, IS.isSubsetOf neighbors bag] ++ [0]) -- inefficient
39 |         [new_id] = G.newNodes 1 tree
40 |         adj = IS.singleton target
41 |         tree' = (adj, new_id, IS.insert node neighbors, adj) G.& tree
42 |
43 | -- O(n d^2 log n), where d is the average degree
44 | treeWidth :: Gr a -> Gr Bag
45 | treeWidth ga = let gr = G.simplify ga in go gr [] (initCache gr)
46 |   where
47 |     go !gr !ns !cache =
48 |       case minCache cache of
49 |         Just node -> let neighbors = G.sucList gr node
50 |                          newEdges = [(a, b) |
51 |                            a <- neighbors,
52 |                            b <- neighbors,
53 |                            a /= b,
54 |                            not (G.hasEdge gr (a,b))]
55 |                          gr' = G.insEdges newEdges $ G.delNode node gr
56 |                          dirty = IS.fromList $ [node] ++ neighbors ++ (neighbors >>= G.sucList gr)
57 |                      in go gr' ((node, IS.fromList $ neighbors) : ns) (updateCache gr' dirty cache)
58 |         Nothing -> finish ns (IS.fromList (G.nodes gr))
59 |     finish ns initBag = gofinish ns (G.mkGraph [(0, initBag)] []) (IM.fromSet (const (IS.singleton 0)) initBag)
60 |     -- 'bags' indexes the current bags of the tree by their contents
61 |     -- bags : vertex v -> set of bags containing v
62 |     gofinish [] tree bags = tree
63 |     gofinish ((node,neighbors):ns) tree bags = gofinish ns tree' bags'
64 |       where
65 |         -- Either connect to some bag that contains all our neighbors, or connect to the first bag
66 |         target = case [b | n <- IS.toList neighbors, Just b <- [IM.lookup n bags]] of
67 |           [] -> 0
68 |           bs -> head (IS.toList (foldr1 IS.intersection bs) ++ [0])
69 |         [new_id] = G.newNodes 1 tree
70 |         new_bag = IS.insert node neighbors
71 |         adj = IS.singleton target
72 |         tree' = (adj, new_id, new_bag, adj) G.& tree
73 |         bags' = IM.unionWith (<>) bags (IM.fromSet (const (IS.singleton new_id)) new_bag)
74 |
75 | -- The cache maintains a priority queue of nodes according to their
76 | -- fill number, such that the node with minimum fill number can be
77 | -- obtained in O(1), and the cache can be updated to accommodate
78 | -- removal or addition of nodes in O(log n).
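-- A worked example (illustrative, not part of the original comments): in
-- the path graph 1 - 2 - 3, vertex 2 has fill number 1, because its
-- neighbors 1 and 3 are not adjacent and one edge is missing from the
-- neighborhood clique, while the two leaves have fill number 0. The
-- cache priority (fill, degree, node id) therefore selects vertex 1
-- first; eliminating it records the bag {1,2}, and the remaining graph
-- 2 - 3 is a clique, so minCache returns Nothing and {2,3} becomes the
-- root bag of the decomposition.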
79 | type MinFillCache = OrdPSQ Node (Int, Int, Node) () 80 | 81 | -- O(n d^2), where d is the average degree of nodes 82 | initCache :: Gr a -> MinFillCache 83 | initCache gr = PQ.fromList [(node, (fill gr node, G.outdeg gr node, node), ()) | node <- G.nodes gr] 84 | 85 | -- O(m log n), where m is the number of nodes to refresh 86 | updateCache :: Gr a -> IntSet -> MinFillCache -> MinFillCache 87 | updateCache gr nodes cache = foldr update cache (IS.toList nodes) 88 | where 89 | update node cache 90 | | G.hasNode gr node = let f = fill gr node 91 | d = G.outdeg gr node 92 | in PQ.insert node (f, d, node) () cache 93 | | otherwise = PQ.delete node cache 94 | 95 | -- O(log n) 96 | minCache :: MinFillCache -> Maybe Node 97 | minCache cache 98 | | order == 0 = Nothing 99 | | degree == order - 1 = Nothing 100 | | otherwise = Just n 101 | where 102 | order = PQ.size cache 103 | Just (n, (_, degree, _), _) = PQ.findMin cache 104 | 105 | -- Find fill number of a node, defined as the minimum number of edges 106 | -- that must be added to the graph to make the neighborhood of `node` 107 | -- into a clique. O(log n + e^2), where e is the number of neighbors 108 | -- of node 109 | fill :: Gr a -> Node -> Int 110 | fill gr node = sum (map subfill (IS.toList neighbors)) `div` 2 111 | where neighbors = G.suc gr node 112 | subfill n = IS.size (neighbors `IS.difference` G.suc gr n) - 1 113 | 114 | -- Find node with minimum fill number. 115 | -- O(n d^2) 116 | min_fill_in :: Gr a -> Maybe G.Node 117 | min_fill_in gr 118 | | G.isEmpty gr = Nothing 119 | | degree == G.order gr - 1 = Nothing 120 | | otherwise = case minimum [(fill gr n, G.outdeg gr n, n) | n <- nodes] of 121 | (_, _, node) -> Just node 122 | where 123 | nodes = G.nodes gr 124 | degree = minimum (map (G.outdeg gr) nodes) 125 | -------------------------------------------------------------------------------- /twremat/src/Tupfile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nshepperd/gpt-2/89cb310a244fa179b5c55dbc5098803bfdbd85dc/twremat/src/Tupfile -------------------------------------------------------------------------------- /twremat/src/Util.hs: -------------------------------------------------------------------------------- 1 | module Util where 2 | 3 | import Data.Foldable 4 | import Data.List 5 | import Data.Map (Map) 6 | import qualified Data.Map.Strict as Map 7 | import Data.Ord 8 | import Data.Set (Set) 9 | import qualified Data.Set as Set 10 | import Data.Tuple 11 | import Debug.Trace 12 | 13 | reflex :: Ord a => [a] -> Map a a 14 | reflex xs = Map.fromList [(x, x) | x <- xs] 15 | 16 | memo :: Ord a => [a] -> (a -> b) -> (a -> b) 17 | memo xs f = \x -> tab Map.! 
x 18 | where 19 | tab = f <$> reflex xs 20 | 21 | minimumOn :: (Foldable t, Ord a) => (b -> a) -> t b -> b 22 | minimumOn f xs = minimumBy (comparing f) xs 23 | 24 | maximumOn :: (Foldable t, Ord a) => (b -> a) -> t b -> b 25 | maximumOn f xs = maximumBy (comparing f) xs 26 | -------------------------------------------------------------------------------- /twremat/test/TestBalanced.hs: -------------------------------------------------------------------------------- 1 | module TestBalanced where 2 | 3 | import Control.Monad 4 | import Data.Foldable 5 | import Test.QuickCheck 6 | import Test.Tasty 7 | import Test.Tasty.QuickCheck 8 | 9 | import Balanced 10 | import Graph (Gr, Node) 11 | import qualified Graph as G 12 | import TestGraph 13 | 14 | subTrees :: Tree a -> [Tree a] 15 | subTrees t@(Tree w a []) = [t] 16 | subTrees t@(Tree w a cs) = t : foldMap subTrees cs 17 | 18 | testBalanced :: TestTree 19 | testBalanced = testGroup "Balanced" [ 20 | testProperty "Subtrees are connected" $ \(TreeOf gr) -> 21 | let t = mkTree (gr :: Gr ()) 22 | go tree = G.isConnected (G.subgraph gr (toList tree)) 23 | in all go (subTrees t), 24 | testProperty "Subtrees are connected after balance" $ \(TreeOf gr) -> 25 | let t = balance $ mkTree (gr :: Gr ()) 26 | go tree = G.isConnected (G.subgraph gr (toList tree)) 27 | in all go (subTrees t), 28 | testProperty "Subtrees are balanced" $ \(TreeOf gr) -> 29 | let t = balance $ mkTree (gr :: Gr ()) 30 | go t@(Tree w a []) = True 31 | go t@(Tree w a cs) = maximum (map treeWeight cs) <= div w 2 32 | in all go (subTrees t) 33 | ] 34 | -------------------------------------------------------------------------------- /twremat/test/TestGraph.hs: -------------------------------------------------------------------------------- 1 | {-# Language ScopedTypeVariables #-} 2 | module TestGraph where 3 | 4 | import Control.Monad 5 | import qualified Data.Graph.Inductive.Arbitrary as FGL 6 | import qualified Data.Graph.Inductive.Graph as FGL 7 | import qualified Data.Graph.Inductive.PatriciaTree as FGL 8 | import Data.List 9 | import Test.QuickCheck 10 | import Test.Tasty 11 | import Test.Tasty.QuickCheck 12 | 13 | import Graph 14 | 15 | instance Arbitrary a => Arbitrary (Gr a) where 16 | arbitrary = do let t = id :: Gen (FGL.Gr a ()) -> Gen (FGL.Gr a ()) 17 | g <- t arbitrary 18 | return (mkGraph (FGL.labNodes g) (FGL.edges g)) 19 | 20 | newtype TreeOf a = TreeOf { getTreeOf :: Gr a } 21 | deriving Show 22 | 23 | instance Arbitrary a => Arbitrary (TreeOf a) where 24 | arbitrary = do n <- chooseInt (1, 20) 25 | ids <- shuffle [1..n] 26 | vals <- replicateM n arbitrary 27 | 28 | let go tree xs [] = pure tree 29 | go tree xs (y:ys) = do 30 | x <- elements xs 31 | go ((x,y):tree) (y : xs) ys 32 | edges <- go [] [head ids] (tail ids) 33 | 34 | return (TreeOf $ simplify' $ mkGraph (zip ids vals) edges) 35 | 36 | newtype DagOf a = DagOf { getDagOf :: Gr a } 37 | deriving Show 38 | 39 | instance Arbitrary a => Arbitrary (DagOf a) where 40 | arbitrary = do n <- chooseInt (1, 20) 41 | ids <- shuffle [1..n] 42 | vals <- replicateM n arbitrary 43 | 44 | let go edges xs [] = pure edges 45 | go edges xs (y:ys) = do 46 | -- choose an existing node, make the new node a dependency 47 | x <- elements xs 48 | go ((y,x):edges) (y : xs) ys 49 | edges <- go [] [head ids] (tail ids) 50 | 51 | extra <- case n of 52 | 1 -> pure [] 53 | _ -> do 54 | n_extra <- chooseInt (0,20) 55 | replicateM n_extra $ do 56 | sub <- elements (filter ((>1) . 
length) (tails ids))
57 |                    let (a:b:cs) = sub
58 |                    b <- elements (b:cs)
59 |                    return (b, a)
60 |
61 |                return (DagOf $ mkGraph (zip ids vals) (edges ++ extra))
62 |
63 | testGraph :: TestTree
64 | testGraph = testGroup "Graph" [
65 |   testProperty "subgraph nodes" $ \(gr :: Gr ()) ->
66 |     let sub = take (length (nodes gr) `div` 2) (nodes gr)
67 |         subgr = subgraph gr sub
68 |     in nodes (subgr) == sub,
69 |   testProperty "topsort nodes" $ \(gr :: Gr ()) ->
70 |     sort (nodes gr) == sort (topsort gr)
71 |   ]
72 |
--------------------------------------------------------------------------------
/twremat/test/TestTreeWidth.hs:
--------------------------------------------------------------------------------
1 | module TestTreeWidth where
2 |
3 | import Data.Foldable
4 | import Data.IntMap (IntMap)
5 | import qualified Data.IntMap as IM
6 | import Data.IntSet (IntSet)
7 | import qualified Data.IntSet as IS
8 | import Data.Map (Map)
9 | import qualified Data.Map.Strict as Map
10 | import Data.Set (Set)
11 | import qualified Data.Set as Set
12 | import Data.Tuple
13 | import Debug.Trace
14 |
15 | import Test.QuickCheck
16 | import Test.Tasty
17 | import Test.Tasty.QuickCheck
18 |
19 | import Graph (Gr, Node)
20 | import qualified Graph as G
21 | import TestGraph
22 | import TreeWidth
23 |
24 | -- Verify the three properties of a tree decomposition:
25 | -- 1. The union of all bags = the set of nodes
26 | check1 :: Gr () -> Bool
27 | check1 gr = IS.unions (map snd (G.labNodes (treeWidth gr))) == IS.fromList (G.nodes gr)
28 |
29 | -- 2. For every edge (a,b), there is a bag which includes both vertices.
30 | check2 :: Gr () -> Bool
31 | check2 gr = let tree = treeWidth gr
32 |                 sets = map snd (G.labNodes tree)
33 |             in and [any (\s -> IS.member a s && IS.member b s) sets | (a, b) <- G.edges gr]
34 |
35 | -- 3. For a given vertex v, bags containing v are connected.
36 | check3 :: Gr () -> Bool
37 | check3 gr = let tree = treeWidth gr
38 |             in and [G.isConnected (G.labfilter (IS.member v) tree) | v <- G.nodes gr]
39 |
40 | -- 4. Additionally, the result should itself be a tree: each undirected edge is stored in both directions, so a tree has exactly 2 * (|V| - 1) directed edges.
41 | check4 :: Gr () -> Bool 42 | check4 gr = let tree = treeWidth gr 43 | in length (G.edges tree) == 2 * (G.order tree - 1) 44 | 45 | testTreeWidth :: TestTree 46 | testTreeWidth = testGroup "TreeWidth" [ 47 | testProperty "vertices" check1, 48 | testProperty "edges" check2, 49 | testProperty "connected" check3, 50 | testProperty "tree" check4 51 | ] 52 | -------------------------------------------------------------------------------- /twremat/test/Tupfile: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nshepperd/gpt-2/89cb310a244fa179b5c55dbc5098803bfdbd85dc/twremat/test/Tupfile -------------------------------------------------------------------------------- /twremat/twremat.cabal: -------------------------------------------------------------------------------- 1 | cabal-version: 2.2 2 | name: twremat 3 | version: 0.1.0.0 4 | synopsis: Fast implementation of `Efficient Rematerialization for Deep Networks` 5 | -- description: 6 | -- bug-reports: 7 | -- license: GPL-3.0-or-later 8 | -- license-file: LICENSE 9 | author: nshepperd 10 | maintainer: nshepperd@gmail.com 11 | -- copyright: 12 | -- category: Distribution 13 | extra-source-files: README.md 14 | 15 | library lib 16 | hs-source-dirs: src 17 | exposed-modules: 18 | Balanced Dense Filter Graph TreeWidth TWRemat Util 19 | build-depends: base >= 4.12.0.0 && < 4.16.0.0, containers, mtl, psqueues, dlist, relation 20 | default-language: Haskell2010 21 | 22 | executable twremat 23 | main-is: remat.hs 24 | -- other-modules: Cabbage.Config, Cabbage.Cabal, Cabbage.Parser 25 | -- other-extensions: 26 | build-depends: base >= 4.12.0.0 && < 4.16.0.0, lib, 27 | containers, mtl, psqueues, dlist, relation, parsers, trifecta, text 28 | -- , Glob 29 | -- , containers 30 | -- , directory 31 | -- , filepath 32 | -- , optparse-applicative 33 | -- , pretty-simple 34 | -- , process 35 | -- , temporary 36 | -- , text 37 | -- , xdg-basedir 38 | -- , parsers 39 | -- , trifecta 40 | 41 | hs-source-dirs: main 42 | default-language: Haskell2010 43 | default-extensions: LambdaCase, OverloadedStrings, RecordWildCards 44 | 45 | 46 | Test-Suite testmain 47 | type: exitcode-stdio-1.0 48 | main-is: test.hs 49 | hs-source-dirs: main test 50 | other-modules: TestBalanced TestGraph TestTreeWidth 51 | build-depends: base >= 4.12.0.0 && < 4.16.0.0, lib, 52 | containers, mtl, psqueues, dlist, relation, parsers, trifecta, text, 53 | QuickCheck, tasty, tasty-quickcheck, fgl, fgl-arbitrary 54 | default-language: Haskell2010 55 | --------------------------------------------------------------------------------
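Usage sketch (illustrative; the file names below are hypothetical, and the format is the one accepted by the parser in twremat/main/remat.hs): an input file starts with a `p remat2` header, optionally followed by `memlimit N`, and then one `node` line per operation, e.g.

    p remat2 memlimit 2000
    node 1 cpu 1 mem 100
    node 2 deps 1 cpu 10 mem 500
    node 3 deps 1 2 cpu 10 mem 500 target

Here nodes 2 and 3 depend on node 1 (edges run from dependency to dependent), each node carries cpu and mem weights (nodes without a weight annotation default to cpu 1 mem 1), and node 3 is a target that must be live when the schedule ends. A node may instead be marked `effectful` (must never be duplicated) or `pointer` (shares memory with its dependencies). Running `twremat input.graph output.sched` writes the schedule to the second path; with a single argument the schedule goes to stdout, one step per line: `c N` computes node N and `f N` frees it. When a memlimit is present, the optimized schedule's validity and length are also reported on stderr.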