├── .gitignore ├── DOWNLOAD_DEMIX_DATA.md ├── LICENSE ├── README.md ├── domain_loader ├── __init__.py ├── constants.py ├── count_words.py ├── domain_loader.py ├── make_splits.py ├── scan_filenames.py ├── shard_dataset.py └── utils.py ├── domains ├── legal │ └── split_legal.py ├── openwebtext │ ├── convert_filenames.py │ └── unpack_openwebtext.py ├── realnews │ └── split_realnews.py ├── reddit │ └── download.sh ├── reviews │ ├── download.sh │ └── split_reviews.py └── s2orc │ └── extract_papers.py ├── environment.yml ├── example_domains ├── acl_papers │ └── acl_papers.jsonl ├── gpt2_bpe │ ├── dict.txt │ ├── encoder.json │ └── vocab.bpe └── legal_contracts │ └── legal_contracts.jsonl ├── requirements.txt └── scripts ├── __init__.py ├── anonymize_file.py ├── download_example_domains.sh ├── fetch_articles.py ├── multiprocessing_bpe_encoder.py ├── prepare.sh ├── preprocess.sh ├── preprocess_example_domains.sh ├── pretokenize.sh └── vocab_overlap.py /.gitignore: -------------------------------------------------------------------------------- 1 | metadata/ 2 | splits/ 3 | shards/ 4 | data-bin/ 5 | example_domains/ag/ 6 | example_domains/chemprot/ 7 | example_domains/imdb/ 8 | example_domains/amazon/ 9 | example_domains/rct-20k/ 10 | example_domains/hyperpartisan_news/ 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | *.py,cover 61 | .hypothesis/ 62 | .pytest_cache/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | -------------------------------------------------------------------------------- /DOWNLOAD_DEMIX_DATA.md: -------------------------------------------------------------------------------- 1 | # Download Instructions for DEMix Data 2 | 3 | Here we provide instructions for downloading the data used in the DEMix paper. Note that downloading most of these datasets involves getting approval from the dataset hosts. 4 | 5 | ### 1B Words 6 | 7 | Download the 1B Words corpus from here: https://opensource.google/projects/lm-benchmark 8 | 9 | ### Legal 10 | 11 | Create an account and download the data here: https://case.law/ 12 | 13 | ### S2ORC (e.g., Med, CS) 14 | 15 | Follow the instructions here to download papers: https://github.com/allenai/s2orc 16 | 17 | Once the papers are downloaded, you can extract them using the script at `domains/s2orc/extract_papers.py`. 18 | 19 | ### Openwebtext 20 | 21 | Download Openwebtext from here: https://skylion007.github.io/OpenWebTextCorpus/. 22 | 23 | Use the script at `domains/openwebtext/unpack_openwebtext.py` to unpack the data. 24 | 25 | ### RealNews 26 | 27 | Download the dataset from here: https://docs.google.com/forms/d/1LMAUeUtHNPXO9koyAIlDpvyKsLSYlrBj3rYhC30a7Ak/viewform?edit_requested=true 28 | 29 | ### Reviews 30 | 31 | Download the raw review data from here: http://deepyeti.ucsd.edu/jianmo/amazon/index.html 32 | 33 | ### Gutenberg 34 | 35 | Follow the instructions here to download the data: https://github.com/aparrish/gutenberg-dammit 36 | 37 | ### Github 38 | 39 | Download the data here: https://console.cloud.google.com/marketplace/product/github/github-repos, under the `contents` table. 40 | 41 | ### ACL Papers 42 | 43 | Download the data here: https://allenai.org/data/qasper 44 | 45 | ### Legal contracts 46 | 47 | Download the data here: https://www.atticusprojectai.org/cuad 48 | 49 | ### CORD-19 50 | 51 | Download the dataset here: https://www.semanticscholar.org/cord19/download 52 | 53 | ### Tweets 54 | 55 | Sign up for the [Twitter Academic API](https://developer.twitter.com/en/products/twitter-api/academic-research), and download tweets in jsonl format. 56 | 57 | 58 | ### Breaking News 59 | 60 | Use `domain/scripts/fetch_articles.py` to crawl breaking news articles. We use the URLs associated with `high factuality` in `https://github.com/ramybaly/News-Media-Reliability/blob/master/data/acl2020/corpus.tsv`.
61 | 62 | ```bash 63 | python -m domain.scripts.fetch_articles --num-articles-per-source 100 --path-to-output news.jsonl 64 | ``` 65 | 66 | 67 | ### Yelp Reviews 68 | 69 | Download dataset here https://www.yelp.com/dataset 70 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. 
For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. 
Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. 
For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multidomain Language Modeling Data Utilities 2 | 3 | This repository contains data utilities for "DEMix Layers: Disentangling Domains for Modular Language Modeling" (Gururangan et. al, 2021). 
4 | 5 | This code is generic; it can be used to build any multi-domain, metadata-tagged dataset in a format compatible with Fairseq for language modeling. We also provide download links to the data necessary to reproduce the results in the DEMix paper. 6 | 7 | 8 | ## General Overview 9 | 10 | In the DEMix paper, we assume a sharded dataset structure across domains: the dataset is split among many folders, and each folder contains many files, each holding a single document. We found this format to be particularly amenable to efficient PyTorch data loading, and it follows the Openwebtext dataset format. 11 | 12 | The processing steps below generally build the following files: 13 | 14 | * A `shards/` folder, which contains a sharded version of the dataset for efficient PyTorch data loading. 15 | * A `data-bin/` folder, which contains data binaries for training and evaluating language models in Fairseq. 16 | * A `metadata/` folder, which contains `filenames.txt`, an index of the paths to all files in your dataset, and `metadata.jsonl`, a json-lines file with per-document metadata. The former is used for faster data loading, and the latter is used for finer-grained filtering of documents based on certain metadata. 17 | 18 | In this tutorial, we use the example datasets in the `example_domains/` directory to build these necessary folders and files. You can use the same process on data of any size, provided that the original input data is in `.jsonl` format. 19 | 20 | ## Installation 21 | 22 | ```bash 23 | conda env create --name demix -f environment.yml 24 | conda activate demix 25 | ``` 26 | 27 | First, set your `DATA_DIR` to the root directory where you will be housing the domain directories. 28 | 29 | ```bash 30 | export DATA_DIR=$(pwd)/example_domains 31 | ``` 32 | 33 | ## Download data 34 | 35 | You can download the example domains for this tutorial here: 36 | 37 | ```bash 38 | bash scripts/download_example_domains.sh 39 | ``` 40 | 41 | We already include the legal contracts and ACL papers domains in the `example_domains` directory. 42 | 43 | Check this [file](DOWNLOAD_DEMIX_DATA.md) for more information on how to download the data used in the DEMix paper. 44 | 45 | ## Preprocess data 46 | 47 | We next want to preprocess all the datasets into fairseq data-bins. We've made this easy with a script: 48 | 49 | ```bash 50 | bash scripts/preprocess_example_domains.sh 51 | ``` 52 | 53 | Otherwise, you can follow along below to understand each preprocessing step. 54 | 55 | We will first preprocess the `imdb` domain. 56 | 57 | ```bash 58 | export DOMAIN=imdb 59 | ``` 60 | 61 | ## Shard Data 62 | 63 | ```bash 64 | python -m domain_loader.shard_dataset \ 65 | --domain $DOMAIN \ 66 | --input-file example_domains/$DOMAIN/$DOMAIN.jsonl \ 67 | --batch-size 512 \ 68 | --text-field text 69 | ``` 70 | 71 | 72 | ## Build metadata/filenames.txt 73 | 74 | To make data loading faster, we first gather a list of filenames in a separate file, `${DOMAIN}/metadata/filenames.txt`. To build this file, use `domain_loader/scan_filenames.py`: 75 | 76 | ```bash 77 | python -m domain_loader.scan_filenames --domain $DOMAIN 78 | ``` 79 | 80 | ## Split data into train, dev, and test files 81 | 82 | First, count the total whitespace tokens in a domain: 83 | 84 | ```bash 85 | python -m domain_loader.count_words --domain $DOMAIN 86 | ``` 87 | 88 | Then use these word counts to set the total number of tokens for the train, dev, and test splits by editing `domain_loader/constants.py`, as in the example below.
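Each entry in the `TOKEN_COUNTS` dictionary in `domain_loader/constants.py` maps a domain name to the number of whitespace tokens to allocate to each split. The `imdb` domain used in this tutorial already has an entry, which looks like this:

```python
# domain_loader/constants.py (excerpt)
TOKEN_COUNTS = {
    # ...
    "imdb": {"num_train_tokens": 800_000,
             "num_dev_tokens": 100_000,
             "num_test_tokens": 100_000},
    # ...
}
```

To add a new domain, add an entry with the same three keys, keeping the totals at or below the word count reported by `count_words`.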
89 | 90 | Then make the data splits: 91 | 92 | ```bash 93 | python -m domain_loader.make_splits \ 94 | --domain $DOMAIN \ 95 | --num-workers 0 \ 96 | --batch-size 1 \ 97 | --output-dir $DATA_DIR/$DOMAIN/splits 98 | ``` 99 | 100 | 101 | ## Build fairseq data-bin 102 | 103 | 104 | Download the GPT-2 vocabulary: 105 | 106 | ```bash 107 | mkdir ${DATA_DIR}/gpt2_bpe 108 | curl -Lo ${DATA_DIR}/gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt 109 | curl -Lo ${DATA_DIR}/gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json 110 | curl -Lo ${DATA_DIR}/gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe 111 | ``` 112 | 113 | ```bash 114 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 115 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 116 | ``` 117 | 118 | These scripts will output `data-bin` files in `${DATA_DIR}/data-bin/`, which you can use to train fairseq LMs. 119 | 120 | 121 | 122 | ## Building multi-domain datasets 123 | 124 | 125 | Building a multi-domain dataset follows the same procedure as above; you simply add multiple domains to the same data-bin folder (i.e., `${DATA_DIR}/data-bin/`). 126 | 127 | You can apply the same process to all the other domains in the `example_domains` folder, e.g.: 128 | 129 | ```bash 130 | export DOMAIN=ag 131 | python -m domain_loader.shard_dataset \ 132 | --domain $DOMAIN \ 133 | --input-file example_domains/$DOMAIN/$DOMAIN.jsonl \ 134 | --batch-size 512 \ 135 | --text-field text 136 | python -m domain_loader.scan_filenames --domain $DOMAIN 137 | python -m domain_loader.count_words --domain $DOMAIN 138 | ## set token counts for "ag" in domain_loader/constants.py 139 | python -m domain_loader.make_splits \ 140 | --domain $DOMAIN \ 141 | --num-workers 0 \ 142 | --batch-size 1 \ 143 | --output-dir $DATA_DIR/$DOMAIN/splits 144 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 145 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 146 | ``` 147 | 148 | Check out `bash scripts/preprocess_example_domains.sh` for other examples. 149 | 150 | 151 | 152 | ## Train a multi-domain LM 153 | 154 | Check out the [DEMix](http://github.com/kernelmachine/demix) repo to see how to train an LM on these data-bins.
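## Input data format

As noted above, each input domain is expected to be a single `.jsonl` file with one JSON object per line. `domain_loader.shard_dataset` takes the document text from the field passed to `--text-field` (`text` by default) and keeps any other top-level fields as per-document metadata (the `metadata.jsonl` described above). A minimal sketch of how such a file could be produced (the `my_domain` name and the `source` field are made up for illustration):

```python
import json
from pathlib import Path

# Illustrative only: "my_domain" and "source" are placeholder names.
domain_dir = Path("example_domains/my_domain")
domain_dir.mkdir(parents=True, exist_ok=True)

docs = [
    {"text": "The quick brown fox jumps over the lazy dog.", "source": "example"},
    {"text": "Another short document.", "source": "example"},
]

# One JSON object per line; the "text" field matches the default --text-field,
# and the remaining fields are carried along as per-document metadata.
with open(domain_dir / "my_domain.jsonl", "w") as f:
    for doc in docs:
        f.write(json.dumps(doc) + "\n")
```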
155 | -------------------------------------------------------------------------------- /domain_loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kernelmachine/demix-data/ab513103640c7eae8172324f309b46769798b96c/domain_loader/__init__.py -------------------------------------------------------------------------------- /domain_loader/constants.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import os 3 | 4 | DATA_DIR = Path(os.environ['DATA_DIR']) 5 | 6 | 7 | TOKEN_COUNTS = {"1b": {'num_train_tokens': 700_000_000, 8 | 'num_dev_tokens': 10_000_000, 9 | 'num_test_tokens': 10_000_000}, 10 | "cs": {'num_train_tokens': 4_500_000_000, 11 | 'num_dev_tokens': 10_000_000, 12 | 'num_test_tokens': 10_000_000}, 13 | "reddit": {'num_train_tokens': 25_000_000_000, 14 | 'num_dev_tokens': 10_000_000, 15 | 'num_test_tokens': 10_000_000}, 16 | "reviews": {'num_train_tokens': 2_500_000_000, 17 | 'num_dev_tokens': 10_000_000, 18 | 'num_test_tokens': 10_000_000}, 19 | "realnews": {'num_train_tokens': 15_000_000_000, 20 | 'num_dev_tokens': 10_000_000, 21 | 'num_test_tokens': 10_000_000}, 22 | "openwebtext": {'num_train_tokens': 6_500_000_000, 23 | 'num_dev_tokens': 10_000_000, 24 | 'num_test_tokens': 10_000_000}, 25 | "legal": {'num_train_tokens': 10_500_000_000, 26 | 'num_dev_tokens': 10_000_000, 27 | 'num_test_tokens': 10_000_000}, 28 | "med": {'num_train_tokens': 9_500_000_000, 29 | 'num_dev_tokens': 10_000_000, 30 | 'num_test_tokens': 10_000_000}, 31 | "gutenberg": {'num_train_tokens': 3_000_000_000, 32 | 'num_dev_tokens': 10_000_000, 33 | 'num_test_tokens': 10_000_000}, 34 | "qasper": {'num_train_tokens': 1_000_000, 35 | 'num_dev_tokens': 1_000_000, 36 | 'num_test_tokens': 1_000_000}, 37 | "legal_contracts": {"num_train_tokens": 1_500_000, 38 | "num_dev_tokens": 1_000_000, 39 | "num_test_tokens": 1_000_000}, 40 | "cord19": {"num_train_tokens": 60_000_000, 41 | "num_dev_tokens": 10_000_000, 42 | "num_test_tokens": 10_000_000}, 43 | "github": {"num_train_tokens": 200_000_000, 44 | "num_dev_tokens": 10_000_000, 45 | "num_test_tokens": 10_000_000}, 46 | "tweets": {"num_train_tokens": 8_000_000, 47 | "num_dev_tokens": 1_000_000, 48 | "num_test_tokens": 1_000_000}, 49 | "yelp_reviews": {"num_train_tokens": 600_000_000, 50 | "num_dev_tokens": 10_000_000, 51 | "num_test_tokens": 10_000_000}, 52 | "latest_news": {"num_train_tokens":11_000_000, 53 | "num_dev_tokens": 1_000_000, 54 | "num_test_tokens": 1_000_000}, 55 | "ag": {"num_train_tokens": 100_000, 56 | "num_dev_tokens": 10_000, 57 | "num_test_tokens": 10_000}, 58 | "imdb": {"num_train_tokens": 800_000, 59 | "num_dev_tokens": 100_000, 60 | "num_test_tokens": 100_000}, 61 | "1b_test": {'num_train_tokens': 80_000, 62 | 'num_dev_tokens': 100_000, 63 | 'num_test_tokens': 100_000}, 64 | "hyperpartisan_news": {"num_train_tokens": 200_000, 65 | "num_dev_tokens": 10_000, 66 | "num_test_tokens": 10_000}, 67 | "chemprot": {"num_train_tokens": 100_000, 68 | "num_dev_tokens": 10_000, 69 | "num_test_tokens": 10_000}, 70 | "rct": {"num_train_tokens": 800_000, 71 | "num_dev_tokens": 100_000, 72 | "num_test_tokens": 100_000}, 73 | "citation_intent": {"num_train_tokens": 50_000, 74 | "num_dev_tokens": 10_000, 75 | "num_test_tokens": 10_000}, 76 | "amazon": {"num_train_tokens": 500_000, 77 | "num_dev_tokens": 100_000, 78 | "num_test_tokens": 100_000}, 79 | "acl_papers": {"num_train_tokens": 500_000, 80 | "num_dev_tokens": 
100_000, 81 | "num_test_tokens": 100_000} 82 | } 83 | -------------------------------------------------------------------------------- /domain_loader/count_words.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, List, Tuple 3 | 4 | import numpy as np 5 | from torch.utils.data import Dataset, DataLoader 6 | import torch 7 | import humanize 8 | from torch.utils.data.sampler import SubsetRandomSampler 9 | from transformers import GPT2Tokenizer 10 | import pandas as pd 11 | from pathlib import Path 12 | import gzip 13 | 14 | from domain_loader.constants import DATA_DIR 15 | from domain_loader.utils import take_n_tokens 16 | from tqdm.auto import tqdm 17 | import numpy as np 18 | from domain_loader.domain_loader import Domain 19 | import argparse 20 | 21 | 22 | if __name__ == '__main__': 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--domain', type=str) 26 | parser.add_argument('--num-workers', type=int, default=0) 27 | parser.add_argument('--batch-size', type=int, default=1) 28 | 29 | 30 | args = parser.parse_args() 31 | domain = args.domain 32 | 33 | dataset = Domain(DATA_DIR / domain/ "shards") 34 | 35 | dataloader = DataLoader(dataset, 36 | num_workers=args.num_workers, 37 | batch_size=args.batch_size) 38 | 39 | pbar = tqdm(dataloader) 40 | curr_tokens = 0 41 | for _, _, token_count, _ in pbar: 42 | curr_tokens += sum(token_count) 43 | pbar.set_description(f"{humanize.intword(curr_tokens)} tokens") 44 | 45 | print(f"Number of tokens in {str(DATA_DIR / domain / domain)}: {humanize.intword(curr_tokens)}") 46 | -------------------------------------------------------------------------------- /domain_loader/domain_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, List, Tuple, Iterable, Any 3 | from collections import defaultdict 4 | import json 5 | import random 6 | import numpy as np 7 | from torch.utils.data import Dataset, DataLoader 8 | import torch 9 | from torch.utils.data.sampler import SubsetRandomSampler 10 | from transformers import GPT2Tokenizer 11 | from sklearn.feature_extraction.text import TfidfVectorizer 12 | import pandas as pd 13 | from pathlib import Path 14 | import gzip 15 | from joblib import Parallel, delayed 16 | 17 | from domain_loader.constants import DATA_DIR, TOKEN_COUNTS 18 | from domain_loader.utils import take_n_tokens, REGEXES 19 | from tqdm.auto import tqdm 20 | import numpy as np 21 | from scipy import sparse 22 | import re 23 | 24 | 25 | def reservoir_sampling(iterator: Iterable[Any], K: int): 26 | """ 27 | Sample from an iterator without loading the iterator into memory. 
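    Keeps each of the N items seen so far with probability K/N, so the result is a uniform random sample of at most K items.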
28 | """ 29 | result = [] 30 | N = 0 31 | for item in iterator: 32 | N += 1 33 | if len( result ) < K: 34 | result.append( item ) 35 | else: 36 | s = int(random.random() * N) 37 | if s < K: 38 | result[ s ] = item 39 | return result 40 | 41 | 42 | class Domain(Dataset): 43 | def __init__(self, 44 | domain_directory: Path, 45 | filenames: Optional[List[str]] = None, 46 | add_bos_token: bool = False, 47 | bos_token: str = "<|endoftext|>", 48 | ignore_files: Optional[List[str]] = [], 49 | sample: int = None, 50 | sample_from_head: bool = False, 51 | track_token_count: bool = False, 52 | anonymize: bool = False, 53 | sample_by_metadata: Optional[Tuple[str, int]] = None, 54 | metadata_columns: List[str] = None, 55 | metadata_file: Optional[Path] = None, 56 | **metadata_filters): 57 | """ 58 | Basic domain dataset. 59 | 60 | Arguments 61 | ========= 62 | domain_directory -> root directory of the domain 63 | filenames -> list of filenames to use (avoids scanning the directory, which may take a while) 64 | add_bos_token -> prepend a beginning of sentence token to each document during loading 65 | ignore_files -> list of filenames to ignore (use to specify, for example, dev or test data you'd like to ignore) 66 | sample -> specify number of random documents from the domain to sample 67 | sample_from_head -> if set, will sample from the head of the domain, rather than doing reservoir sampling, which may take a while, though it is more unbiased. 68 | track_token_count -> if set, will track the number of tokens sampled during data loading. 69 | anonymize -> if set, will apply some basic regexes to loaded data to redact user identifiable information. 70 | sample_by_metadata -> if set, in the form (metadata_column, k), sample k documents that align with the metadata_column. 
71 | metadata_columns -> if set, return metadata columns (from metadata_file) for each document 72 | metadata_file -> if set, read metadata_file as well 73 | **metadata_filters -> if set, in the form {metadata_column: [item_1,item_2,...]}, will filter documents that satisfy these metadata filters 74 | """ 75 | super().__init__() 76 | self.add_bos_token = add_bos_token 77 | self.anonymize = anonymize 78 | 79 | self.anonymizer = {re.compile(regex['regex']): regex['repl'] for regex in REGEXES} 80 | 81 | self.bos_token = bos_token 82 | self.domain_directory = domain_directory 83 | self.files = {} 84 | 85 | if metadata_file: 86 | print(f'loading files from metadata in {metadata_file}') 87 | with open(metadata_file, 'r') as f: 88 | for ix, line in tqdm(enumerate(f)): 89 | if sample: 90 | if ix > sample: 91 | break 92 | z = json.loads(line) 93 | if filenames: 94 | if z['filename'] not in filenames: 95 | continue 96 | metadata_ = {metadata_column: z[metadata_column] for metadata_column in metadata_columns} 97 | if metadata_filters: 98 | for key, items in metadata_filters.items(): 99 | if metadata_[key] in items: 100 | self.files[z['filename']] = metadata_ 101 | else: 102 | self.files[z['filename']] = metadata_ 103 | 104 | if sample_by_metadata: 105 | files_ = {} 106 | self.metadata_counts = defaultdict(int) 107 | for file in self.files: 108 | if self.metadata_counts[self.files[file][sample_by_metadata['metadata_column']]] < sample_by_metadata['sample_size']: 109 | self.metadata_counts[self.files[file][sample_by_metadata['metadata_column']]] += 1 110 | files_[file] = self.files[file] 111 | self.files = files_ 112 | else: 113 | self.metadata_file = None 114 | if filenames: 115 | if isinstance(filenames[0], Tuple): 116 | self.files = dict(filenames) 117 | else: 118 | self.files = filenames 119 | else: 120 | fs = tqdm(domain_directory.glob("*/*")) 121 | if sample: 122 | print(f"Loading {sample} files from {domain_directory}...") 123 | if sample_from_head: 124 | sample_files = [] 125 | for ix, file in enumerate(fs): 126 | if ix < sample: 127 | sample_files.append(file) 128 | else: 129 | break 130 | else: 131 | sample_files = reservoir_sampling(fs, sample) 132 | self.files = sample_files 133 | else: 134 | print(f"Loading all files from {domain_directory}...") 135 | self.files = list(fs) 136 | 137 | if ignore_files: 138 | self.files = list(set(self.files) - set(ignore_files)) 139 | print(f"loaded {len(self.files)} files, ignoring {len(ignore_files)} files") 140 | 141 | def __getitem__(self, idx): 142 | if self.metadata_file: 143 | x = self.files[idx] 144 | file, metadata = str(x[0]), x[1] 145 | else: 146 | file = str(self.files[idx]) 147 | metadata = [] 148 | try: 149 | if file.endswith('.gz'): 150 | with gzip.open(file, 'rb') as f: 151 | text = f.read().decode('utf-8') 152 | else: 153 | with open(file, "r") as f: 154 | text = f.read() 155 | except: 156 | text = "" 157 | if self.add_bos_token: 158 | text = self.bos_token + " " + text 159 | if self.anonymize: 160 | for x,y in self.anonymizer.items(): 161 | text = x.sub(y, text) 162 | token_count = len(text.split()) 163 | return file, text, token_count, metadata 164 | 165 | def __len__(self): 166 | return len(self.files) 167 | 168 | 169 | 170 | class DomainTokenized(Domain): 171 | def __init__(self, 172 | domain_directory: Path, 173 | metadata_file: Optional[Path] = None, 174 | filenames: Optional[List[str]] = None, 175 | tokenizer: Optional[GPT2Tokenizer] = None, 176 | ignore_files: Optional[List[str]] = [], 177 | **metadata_filters): 178 | """ 179 | 
Domain dataset with tokenization built in. 180 | """ 181 | super().__init__(domain_directory=domain_directory, 182 | metadata_file=metadata_file, 183 | filenames=filenames, 184 | ignore_files=ignore_files, 185 | **metadata_filters) 186 | if not tokenizer: 187 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 188 | self.tokenizer = tokenizer 189 | 190 | def __getitem__(self, idx) -> Tuple[str, np.array]: 191 | filename, text, token_count, metadata = super().__getitem__(idx) 192 | tokenized_text = np.array(self.tokenizer.encode(text, truncation=True)) 193 | return filename, tokenized_text, token_count, metadata 194 | 195 | 196 | class DomainVectorized(Domain): 197 | def __init__(self, 198 | domain_directory: Path, 199 | vectorizer = None, 200 | filenames: Optional[List[str]] = None, 201 | add_bos_token: bool = False, 202 | metadata_columns: List[str] = None, 203 | ignore_files: Optional[List[str]] = [], 204 | sample_by_metadata: Optional[Tuple[str, int]] = None, 205 | metadata_file: Optional[Path] = None, 206 | tokenizer = None, 207 | sample: int = None, 208 | **metadata_filters): 209 | """ 210 | Domain dataset with document vectorization built in. 211 | """ 212 | super().__init__(domain_directory=domain_directory, 213 | filenames=filenames, 214 | add_bos_token=add_bos_token, 215 | metadata_columns=metadata_columns, 216 | ignore_files=ignore_files, 217 | sample_by_metadata=sample_by_metadata, 218 | metadata_file=metadata_file, 219 | sample=sample, 220 | **metadata_filters) 221 | if not tokenizer: 222 | tokenizer = GPT2Tokenizer.from_pretrained('gpt2') 223 | self.tokenizer = tokenizer 224 | if not vectorizer: 225 | self.vectorizer = TfidfVectorizer(vocabulary=self.tokenizer.encoder, max_features=len(self.tokenizer.encoder)) 226 | else: 227 | self.vectorizer = vectorizer 228 | print("fitting vectorizer...") 229 | 230 | def __getitem__(self, idx) -> Tuple[str, np.array]: 231 | filename, text, token_count, metadata = super().__getitem__(idx) 232 | vectorized_text = self.vectorizer(tokenized_text) 233 | return filename, vectorized_text, token_count, metadata 234 | 235 | 236 | def domain_dataloader(domain_directory: Path, 237 | metadata_file: Optional[Path] = None, 238 | filenames: Optional[List[str]] = None, 239 | num_workers: int = 16, 240 | batch_size: int = 16, 241 | **metadata_filters): 242 | 243 | if tokenized: 244 | dataset = DomainTokenized(domain_directory, metadata_file, filenames, **metadata_filters) 245 | else: 246 | dataset = Domain(domain_directory, metadata_file, filenames, **metadata_filters) 247 | dataloader = DataLoader(dataset, num_workers=16, batch_size=batch_size) 248 | return dataloader 249 | -------------------------------------------------------------------------------- /domain_loader/make_splits.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gzip 3 | import json 4 | import os 5 | import pickle 6 | import sys 7 | from collections import defaultdict 8 | from pathlib import Path 9 | from typing import List, Optional, Tuple 10 | 11 | import humanize 12 | import numpy as np 13 | import pandas as pd 14 | import torch 15 | import re 16 | from torch.utils.data import DataLoader, Dataset 17 | from tqdm import tqdm 18 | from typing import TypeVar, Iterable, List, Sequence, Union, Any 19 | 20 | 21 | from domain_loader.constants import DATA_DIR, TOKEN_COUNTS 22 | from domain_loader.domain_loader import Domain 23 | from domain_loader.utils import take_n_tokens 24 | 25 | T = TypeVar('T') 26 | 27 | def batchify(data: 
Iterable[T], batch_size: int) -> Iterable[List[T]]: 28 | assert batch_size > 0 29 | 30 | batch = [] 31 | for item in data: 32 | # Yield next batch 33 | if len(batch) == batch_size: 34 | yield batch 35 | batch = [] 36 | 37 | batch.append(item) 38 | 39 | # Yield last un-filled batch 40 | if len(batch) != 0: 41 | yield batch 42 | 43 | def get_cluster_id(clusterer, text): 44 | text = [x.replace("<|endoftext|>", "") for x in text] 45 | vec = clusterer['vectorizer'].transform(text) 46 | vec = clusterer['svd'].transform(vec) 47 | cluster_id = clusterer['kmeans'].predict(vec) 48 | return cluster_id 49 | 50 | 51 | def write_split(domain: str, 52 | output_dir: str, 53 | split: str, 54 | add_bos_token: bool, 55 | bos_token: str = "<|endoftext|>", 56 | num_workers: int = 16, 57 | batch_size: int = 16, 58 | files=None, 59 | ignore_files=[], 60 | clusterer=None, 61 | from_file=None, 62 | anonymize=False): 63 | 64 | if not from_file: 65 | dataset = Domain(DATA_DIR / domain / domain, 66 | filenames=files, 67 | add_bos_token=add_bos_token, 68 | bos_token=bos_token, 69 | ignore_files=ignore_files, 70 | anonymize=anonymize) 71 | loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size) 72 | else: 73 | fh = open(from_file, 'r') 74 | loader = batchify(fh.read().split('<|endoftext|>'), 10000) 75 | fh.close() 76 | files = [] 77 | done = False 78 | 79 | if clusterer: 80 | print("Detected clusterer. Clustering while loading splits.") 81 | filehandle = [open(output_dir / f"{split}.{i}.txt", "w+") for i in range(clusterer['kmeans'].n_clusters)] 82 | curr_tokens = dict((i, 0) for i in range(clusterer['kmeans'].n_clusters)) 83 | else: 84 | filehandle = [open(output_dir / f"{split}.txt", "w+")] 85 | curr_tokens = {0: 0} 86 | 87 | pbar = tqdm(loader) 88 | written = False 89 | 90 | if not from_file: 91 | for fname, text,_, _ in pbar: 92 | files.extend(fname) 93 | if clusterer: 94 | if domain in ['1b', 'reddit']: 95 | tt = [t.split("<|endoftext|>") for t in text] 96 | text = ["<|endoftext|> " + y for x in tt for y in x] 97 | cluster_ids = get_cluster_id(clusterer, text) 98 | iter_ = zip(cluster_ids, text) 99 | else: 100 | iter_ = text 101 | for item in iter_: 102 | if clusterer: 103 | cluster_id = item[0] 104 | doc = item[1] 105 | else: 106 | doc = item 107 | if not doc or doc == "<|endoftext|> ": 108 | continue 109 | if clusterer: 110 | s = f"{split}, " 111 | for i, tok in curr_tokens.items(): 112 | s += f"cluster {i}: {humanize.intword(tok)} || " 113 | else: 114 | s = f"{split}, num tokens: {humanize.intword(curr_tokens[0])}" 115 | pbar.set_description(s) 116 | if sum(curr_tokens.values()) > TOKEN_COUNTS[domain][f'num_{split}_tokens']: 117 | if not written: 118 | count_ = 0 119 | item = " ".join(doc.split()[:TOKEN_COUNTS[domain][f'num_{split}_tokens']]) 120 | if clusterer: 121 | filehandle[cluster_id].write(doc.strip() + "\n") 122 | curr_tokens[cluster_id] += len(doc.split()) 123 | else: 124 | filehandle[0].write(doc.strip() + "\n") 125 | curr_tokens[0] += len(doc.split()) 126 | written = True 127 | done = True 128 | break 129 | if clusterer: 130 | filehandle[cluster_id].write(doc.strip() + "\n") 131 | curr_tokens[cluster_id] += len(doc.split()) 132 | else: 133 | filehandle[0].write(doc.strip() + "\n") 134 | curr_tokens[0] += len(doc.split()) 135 | written = True 136 | if done: 137 | break 138 | else: 139 | for text in pbar: 140 | text = ["<|endoftext|> " + x for x in text] 141 | if clusterer: 142 | cluster_ids = get_cluster_id(clusterer, text) 143 | iter_ = zip(cluster_ids, text) 144 | else: 145 | 
iter_ = text 146 | for item in iter_: 147 | if clusterer: 148 | cluster_id = item[0] 149 | doc = item[1] 150 | else: 151 | doc = item 152 | if not doc: 153 | continue 154 | if clusterer: 155 | s = f"{split}, " 156 | for i, tok in curr_tokens.items(): 157 | s += f"cluster {i}: {humanize.intword(tok)} || " 158 | else: 159 | s = f"{split}, num tokens: {humanize.intword(curr_tokens[0])}" 160 | pbar.set_description(s) 161 | if sum(curr_tokens.values()) > TOKEN_COUNTS[domain][f'num_{split}_tokens']: 162 | if not written: 163 | count_ = 0 164 | item = " ".join(doc.split()[:TOKEN_COUNTS[domain][f'num_{split}_tokens']]) 165 | if clusterer: 166 | filehandle[cluster_id].write(doc.strip() + "\n") 167 | curr_tokens[cluster_id] += len(doc.split()) 168 | else: 169 | filehandle[0].write(doc.strip() + "\n") 170 | curr_tokens[0] += len(doc.split()) 171 | written = True 172 | done = True 173 | break 174 | if clusterer: 175 | filehandle[cluster_id].write(doc.strip() + "\n") 176 | curr_tokens[cluster_id] += len(doc.split()) 177 | else: 178 | filehandle[0].write(doc.strip() + "\n") 179 | curr_tokens[0] += len(doc.split()) 180 | written = True 181 | if done: 182 | break 183 | for fh in filehandle: 184 | fh.close() 185 | if from_file: 186 | return None, None, curr_tokens 187 | else: 188 | return dataset.files, files, curr_tokens 189 | 190 | if __name__ == '__main__': 191 | 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument("--domain", default=None) 194 | parser.add_argument("--add-bos-token", action='store_true') 195 | parser.add_argument("--num-workers", type=int, default=0) 196 | parser.add_argument("--batch-size", type=int, default=16) 197 | parser.add_argument("--output-dir", type=Path, default=None) 198 | parser.add_argument("--load", type=Path, default=None) 199 | parser.add_argument("--train-files", type=Path, default=None) 200 | parser.add_argument("--dev-files", type=Path, default=None) 201 | parser.add_argument("--test-files", type=Path, default=None) 202 | parser.add_argument("--pretrain-clusters-only", action='store_true') 203 | parser.add_argument("--pretrain-clusters", nargs="+", type=str) 204 | parser.add_argument("--output-clusters", type=Path) 205 | parser.add_argument("--from-file", type=Path) 206 | parser.add_argument("--train-only", action='store_true') 207 | parser.add_argument("--dev-only", action='store_true') 208 | parser.add_argument("--test-only", action='store_true') 209 | parser.add_argument("--anonymize", action='store_true') 210 | 211 | args = parser.parse_args() 212 | domain = args.domain 213 | 214 | if args.output_dir: 215 | output_dir = args.output_dir 216 | output_dir.mkdir(exist_ok=True) 217 | 218 | clusterer=None 219 | 220 | if args.train_files: 221 | with open(args.train_files, 'r') as f: 222 | args_train_files = [x.strip() for x in f.readlines()] 223 | else: 224 | args_train_files = None 225 | 226 | if args.dev_files: 227 | with open(args.dev_files, 'r') as f: 228 | args_dev_files = [x.strip() for x in f.readlines()] 229 | else: 230 | args_dev_files = None 231 | 232 | if args.test_files: 233 | with open(args.test_files, 'r') as f: 234 | args_test_files = [x.strip() for x in f.readlines()] 235 | else: 236 | args_test_files = None 237 | 238 | if not args.from_file: 239 | resolved_path = str(DATA_DIR / domain / domain) 240 | with open(DATA_DIR / args.domain / "metadata" / "filenames.txt", 'r') as f: 241 | domain_files = [] 242 | for x in tqdm(f.readlines()): 243 | fp = x.strip() 244 | domain_files.append(fp) 245 | else: 246 | domain_files = None 247 | 248 | if 
args.domain in ['reddit', '1b']: 249 | add_bos_token = False 250 | num_workers = args.num_workers 251 | batch_size = args.batch_size 252 | else: 253 | add_bos_token = True 254 | num_workers = args.num_workers 255 | batch_size = args.batch_size 256 | if not args.test_only and not args.dev_only: 257 | train_files, train_files_to_ignore, num_train_tokens = write_split(args.domain, 258 | output_dir, 259 | "train", 260 | add_bos_token, 261 | num_workers=num_workers, 262 | batch_size=batch_size, 263 | files=args_train_files or domain_files, 264 | clusterer=clusterer, 265 | anonymize=args.anonymize) 266 | else: 267 | train_files = None 268 | train_files_to_ignore = None 269 | num_train_tokens = None 270 | 271 | if not args.train_only: 272 | if not args.test_only: 273 | if args.from_file: 274 | train_files_to_ignore = None 275 | dev_files, dev_files_to_ignore, num_dev_tokens = write_split( 276 | args.domain, 277 | output_dir, 278 | "dev", 279 | add_bos_token, 280 | num_workers=num_workers, 281 | batch_size=batch_size, 282 | files=args_dev_files or domain_files, 283 | ignore_files=args_train_files or train_files_to_ignore, 284 | clusterer=clusterer, 285 | from_file=args.from_file) 286 | else: 287 | dev_files = None 288 | dev_files_to_ignore = None 289 | num_dev_tokens = None 290 | if not args.dev_only: 291 | if args.from_file: 292 | train_files_to_ignore = [] 293 | dev_files_to_ignore = [] 294 | if args_train_files and args_dev_files: 295 | ignore_files = args_train_files + args_dev_files 296 | else: 297 | ignore_files = train_files_to_ignore + dev_files_to_ignore 298 | 299 | test_files, test_files_to_ignore, num_test_tokens = write_split( 300 | args.domain, 301 | output_dir, 302 | "test", 303 | add_bos_token, 304 | num_workers=num_workers, 305 | batch_size=batch_size, 306 | files=args_test_files or domain_files, 307 | ignore_files=ignore_files, 308 | clusterer=clusterer, 309 | from_file=args.from_file) 310 | 311 | else: 312 | test_files = None 313 | test_files_to_ignore = None 314 | num_test_tokens = None 315 | if train_files_to_ignore: 316 | with open(args.output_dir / "train_files.txt", "w+") as f: 317 | for file in train_files_to_ignore: 318 | f.write(str(file) + "\n") 319 | if dev_files_to_ignore: 320 | with open(args.output_dir / "dev_files.txt", "w+") as f: 321 | for file in dev_files_to_ignore: 322 | f.write(str(file) + "\n") 323 | if test_files_to_ignore: 324 | with open(args.output_dir / "test_files.txt", "w+") as f: 325 | for file in test_files_to_ignore: 326 | f.write(str(file) + "\n") 327 | 328 | 329 | print("Finished successfully.") 330 | if num_train_tokens: 331 | with open(args.output_dir / "train_token_counts.txt", 'w+') as f: 332 | json.dump(num_train_tokens, f) 333 | if num_dev_tokens: 334 | with open(args.output_dir / "dev_token_counts.txt", 'w+') as f: 335 | json.dump(num_dev_tokens, f) 336 | if num_test_tokens: 337 | with open(args.output_dir / "test_token_counts.txt", 'w+') as f: 338 | json.dump(num_test_tokens, f) 339 | 340 | print(f"Num train tokens: {humanize.intword(sum(num_train_tokens.values()))}") 341 | print(f"Num dev tokens: {humanize.intword(sum(num_dev_tokens.values()))}") 342 | print(f"Num test tokens: {humanize.intword(sum(num_test_tokens.values()))}") 343 | -------------------------------------------------------------------------------- /domain_loader/scan_filenames.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, List, Tuple 3 | 4 | from torch.utils.data import Dataset, 
DataLoader 5 | import torch 6 | from torch.utils.data.sampler import SubsetRandomSampler 7 | from pathlib import Path 8 | 9 | from domain_loader.constants import DATA_DIR 10 | from tqdm.auto import tqdm 11 | from domain_loader.domain_loader import Domain 12 | import argparse 13 | 14 | 15 | if __name__ == '__main__': 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--domain", type=str) 19 | parser.add_argument('--num-workers', type=int, default=0) 20 | parser.add_argument('--batch-size', type=int, default=1) 21 | args = parser.parse_args() 22 | 23 | domain = args.domain 24 | 25 | dataset = Domain(DATA_DIR / domain / "shards") 26 | 27 | if domain in ["1b", "reddit"]: 28 | num_workers = 1 29 | batch_size = 1 30 | else: 31 | num_workers = args.num_workers 32 | batch_size = args.batch_size 33 | dataloader = DataLoader(dataset, 34 | num_workers=num_workers, 35 | batch_size=batch_size) 36 | 37 | pbar = tqdm(dataloader) 38 | curr_files = 0 39 | (DATA_DIR / domain / "metadata" ).mkdir(exist_ok=True) 40 | with open(DATA_DIR / domain / "metadata" / "filenames.txt", "w+") as f: 41 | for fname, _, _, _ in pbar: 42 | for fn in fname: 43 | f.write(fn + "\n") 44 | curr_files += 1 45 | print(f"Number of files in {str(DATA_DIR / domain / domain)}: {curr_files}") 46 | -------------------------------------------------------------------------------- /domain_loader/shard_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterable, List, TypeVar 3 | from tqdm.auto import tqdm 4 | from pathlib import Path 5 | import random 6 | import gzip 7 | import os 8 | import argparse 9 | import sys 10 | 11 | T = TypeVar('T') 12 | 13 | 14 | def batchify(data: Iterable[T], batch_size: int) -> Iterable[List[T]]: 15 | assert batch_size > 0 16 | 17 | batch = [] 18 | for item in data: 19 | # Yield next batch 20 | if len(batch) == batch_size: 21 | yield batch 22 | batch = [] 23 | 24 | batch.append(item) 25 | 26 | # Yield last un-filled batch 27 | if len(batch) != 0: 28 | yield batch 29 | 30 | 31 | def build(texts_dir, input_fh, metadata_fh, batch_size=512, text_field='text'): 32 | num_folders, num_files = 0, 0 33 | for ix, batch in tqdm(enumerate(batchify(f, batch_size=batch_size))): 34 | num_folders += 1 35 | batch = [json.loads(x) for x in batch] 36 | for x in batch: 37 | fname = random.getrandbits(128) 38 | x['filename'] = str(Path("shards") / f"subset_{ix}" / (str(random.getrandbits(128)) + ".txt")) 39 | text = [(x['filename'], x.pop(text_field)) for x in batch if x.get(text_field)] 40 | for line in batch: 41 | json.dump(line, g) 42 | g.write("\n") 43 | subset_dir = (texts_dir / "shards" / f"subset_{ix}") 44 | subset_dir.mkdir(parents=True, exist_ok=True) 45 | for fname, line in text: 46 | with open(texts_dir / fname, 'w+') as h: 47 | h.write(line) 48 | num_files += 1 49 | return num_folders, num_files 50 | 51 | if __name__ == '__main__': 52 | data_dir = Path(os.environ['DATA_DIR']) 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--domain", type=str) 55 | parser.add_argument("--input-file", type=str) 56 | parser.add_argument("--batch-size", type=int, default=512) 57 | parser.add_argument("--text-field", type=str, default='text') 58 | 59 | 60 | args = parser.parse_args() 61 | 62 | texts_dir = data_dir / args.domain 63 | 64 | if (texts_dir / 'shards').is_dir(): 65 | sys.exit(f"dataset at {texts_dir} already sharded.") 66 | else: 67 | (texts_dir / 'metadata').mkdir(parents=True, exist_ok=True) 68 | (texts_dir / 
'shards').mkdir(parents=True, exist_ok=True) 69 | 70 | if args.input_file.endswith(".gz"): 71 | with gzip.open(args.input_file, 'rb') as f, open(texts_dir / 'metadata' / 'metadata.jsonl', 'w+') as g: 72 | num_folders, num_files = build(texts_dir, f, g, args.batch_size, args.text_field) 73 | else: 74 | with open(args.input_file, 'r') as f, open(texts_dir/ 'metadata'/ 'metadata.jsonl', 'w+') as g: 75 | num_folders, num_files = build(texts_dir, f, g, args.batch_size, args.text_field) 76 | 77 | print(f"Sharded {args.input_file} into {num_folders} folders, {num_files} files, located at {texts_dir}") 78 | -------------------------------------------------------------------------------- /domain_loader/utils.py: -------------------------------------------------------------------------------- 1 | def take_n_tokens(dataloader, num_tokens): 2 | curr_num_tokens = 0 3 | for _, text in dataloader: 4 | curr_num_tokens += sum(len(x.split()) for x in text) 5 | if curr_num_tokens < num_tokens: 6 | yield curr_num_tokens, text 7 | 8 | re1 = { 9 | "regex": "[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", 10 | "repl": "<|EMAIL|>" 11 | } 12 | 13 | re2 = {"regex": "[0-9]{10}-[0-9A-Fa-f]{53}", "repl":"<|DART|>"} 14 | 15 | re3 = {"regex": "@\[[0-9]+:[0-9]+:(?=[^\]])(([^\\:\]]*(?:\\.)*)*)\]", "repl": "<|FBUSERID|>"} 16 | 17 | re4 ={"regex": "(?:(?"} 18 | 19 | re5 ={"regex": "(?:4\d{12}(?:\d{3})?|(?:5[1-5]\d{2}|222[1-9]|22[3-9]\d|2[3-6]\d{2}|27[01]\d|2720)\d{12}|3[47]\d{13}|5019\d{12}|3(?:0[0-5]|[68]\d)\d{11}|6(?:011|5\d{2})\d{12}|(?:2131|1800|35\d{3})\d{11})", "repl": ""} 20 | 21 | re6 ={"regex": "(?!(?:000|666|9))\d{3}-(?!00)\d{2}-(?!0000)\d{4}", "repl": "<|SSN|>"} 22 | 23 | re7 = {"regex": "\d+\s(?:(?:[a-z0-9.-]+[ ]?)+\s(?:Alley|Aly|Ave(?:nue)?|Boulevard|Blvd|Br(?:anch)?|Center|Ctr|Cir(?:cle)?|Court|Ct|Crossing|Xing|Dr(?:ive)?|Est(?:ate)?|Expressway|Expy|Freeway|Fwy|Highway|Hwy|Hills|Hls|Knoll|Knl|Landing|Lndg|Lane|Ln|Manor|Mnr|Meadow|Mdw|Parkway|Pkwy|Pass|Path|Plaza|Plz|Road|Rd|Run|Sq(?:uare)?|St(?:ation|reet|a)?|Ter(?:ace)?|Trail|Trl|Turnpike|Tpke|Valley|Vly|View|Vw|Village|Vlg|Vis(?:ta)?|Walk|Way)|(?:Route|Rte|Interstate|I)[- ]?\d{1,3})(?:\s(?:Apt[\.]?|Apartment|#)[ ]?\d+[a-z]?)?(?:\s(?:[a-z-]+[ ]?)+,?(?:\s(?:AK|AL(?:aska|abama)?|AR(?:kansas|izona)?|AZ|CA(?:lifornia)?|CO(?:lorado|nnecticut)?|CT|DC|DE(?:laware)?|FL(?:orida)?|GA|Georgia|GU(?:am)?|HI|Hawaii|IA|Iowa|ID(?:aho)?|IL(?:linois)?|IN(?:diana)?|KS|Kansas|KY|Kentucky|LA|Louisiana|MA(?:ssachusetts|ryland|ine)?|MD|ME|MI(?:chigan|nnesota|ssissippi|ssouri)|MN|MO(?:ntana)?|MS|MT|NC|North[ ]Carolina|ND|North[ ]Dakota|NH|New[ ]Hampshire|NJ|New[ ]Jersey|NM|New[ ]Mexico|NV|Nevada|NY|New[ ]York|OH(?:io)?|OK(?:lahoma)?|OR(?:egon)?|PA|Pennsylvania|PR|Puerto[ ]Rico|RI|Rhode[ ]Island|SC|South[ ]Carolina|SD|South[ ]Dakota|TN|Tennessee|TX|Texas|UT(?:ah)?|VA|Virginia|VI(?:rgin[ ]Islands)?|VT|Vermont|WA(?:shington(?:[ ]D[. 
]?C[.]?)?)?|WI(?:sconsin)?|WV|West[ ]Virginia|WY(?:oming)?)(?:\s\b\d{5}(?:-\d{4})?\b)?)?)?", 24 | "repl": "<|ADDRESS|>"} 25 | 26 | re8 = {"regex": "@[a-zA-Z0-9_\.\-]{1,30}", "repl": "@USER"} 27 | 28 | 29 | REGEXES = [re1,re2,re3,re4,re5,re6,re8] 30 | -------------------------------------------------------------------------------- /domains/legal/split_legal.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterable, List, TypeVar 3 | from tqdm.auto import tqdm 4 | from pathlib import Path 5 | import random 6 | 7 | 8 | T = TypeVar('T') 9 | 10 | 11 | def batchify(data: Iterable[T], batch_size: int) -> Iterable[List[T]]: 12 | assert batch_size > 0 13 | 14 | batch = [] 15 | for item in data: 16 | # Yield next batch 17 | if len(batch) == batch_size: 18 | yield batch 19 | batch = [] 20 | 21 | batch.append(item) 22 | 23 | # Yield last un-filled batch 24 | if len(batch) != 0: 25 | yield batch 26 | 27 | 28 | if __name__ == '__main__': 29 | texts_dir = Path('legal/') 30 | 31 | files = list(Path(".").glob('*/data/data.jsonl')) 32 | for file in files: 33 | with open(file, 'r') as f, open('metadata/metadata.jsonl', 'a+') as g: 34 | for ix, batch in tqdm(enumerate(batchify(f, batch_size=512))): 35 | batch = [json.loads(x) for x in batch] 36 | for x in batch: 37 | x['filenames'] = [] 38 | for y in x['casebody']['data']['opinions']: 39 | fname = random.getrandbits(128) 40 | x['filenames'].append(str(Path(f"subset_{ix}") / (str(random.getrandbits(128)) + ".txt"))) 41 | 42 | text = [zip(x['filenames'], [x['text'] for x in x['casebody']['data']['opinions']]) for x in batch] 43 | for line in batch: 44 | json.dump(line, g) 45 | g.write("\n") 46 | (texts_dir / f"subset_{ix}").mkdir(parents=True, exist_ok=True) 47 | for x in text: 48 | for fname, line in x: 49 | with open(texts_dir / fname, 'w') as h: 50 | h.write(line) 51 | -------------------------------------------------------------------------------- /domains/openwebtext/convert_filenames.py: -------------------------------------------------------------------------------- 1 | from tqdm.auto import tqdm 2 | import pandas as pd 3 | import json 4 | from pathlib import Path 5 | from copy import copy 6 | 7 | if __name__ == '__main__': 8 | 9 | fs = {} 10 | for file in tqdm(Path("openwebtext").glob("*/*")): 11 | fs[file.stem] = str(Path(str(file.parents[0]).split('/')[1]) / file.name) 12 | with open('metadata/metadata.jsonl', 'r') as f, open('metadata/metadata.1.jsonl', 'w+') as g: 13 | for line in tqdm(f): 14 | z = json.loads(line) 15 | z['stem'] = copy(z['filename']) 16 | z['filename'] = fs[Path(z['filename']).stem] 17 | json.dump(z, g) 18 | g.write('\n') -------------------------------------------------------------------------------- /domains/openwebtext/unpack_openwebtext.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pathlib import Path 3 | from shutil import unpack_archive, rmtree 4 | import multiprocessing as mp 5 | from itertools import chain 6 | 7 | import numpy as np 8 | import click 9 | from joblib import Parallel, delayed, dump 10 | from tqdm.auto import tqdm 11 | 12 | 13 | def text_files_in_dir(s): 14 | return [f for f in s.iterdir() if f.suffix == '.txt'] 15 | 16 | 17 | def read_text(file): 18 | return file.read_text() 19 | 20 | 21 | def shardify_openwebtext(input_dir: str, output_dir: str, n_processes: int, n_shards: int = 20): 22 | input_dir = Path(input_dir) 23 | assert input_dir.exists() 24 | 25 | output_dir = 
Path(output_dir) 26 | output_dir.mkdir() 27 | 28 | subset_dirs = [subset_dir for subset_dir in input_dir.iterdir() if subset_dir.is_dir()] 29 | with mp.Pool(processes=n_processes) as pool: 30 | # Get list of sorted files in OpenWebText 31 | text_files = sorted(list(chain.from_iterable(pool.map(text_files_in_dir, subset_dirs)))) 32 | 33 | # Split list of files into shards 34 | shard_files = np.array_split(text_files, n_shards) 35 | 36 | # Save shards and associated filenames 37 | for i, split in enumerate(tqdm(shard_files)): 38 | tqdm.write("Loading text in shard...") 39 | shard = pool.map(read_text, split) 40 | 41 | tqdm.write("Saving shard...") 42 | shard_name = f'owtc{i:02d}' 43 | dump(shard, output_dir / f'{shard_name}.joblib') 44 | filenames = map(lambda f: f.stem, split) 45 | with open(output_dir / f'{shard_name}_filenames.txt', 'w') as f: 46 | print(*filenames, file=f, sep='\n') 47 | 48 | 49 | @click.command() 50 | @click.option('--archive', required=True) 51 | @click.option('--n_jobs', default=16) 52 | @click.argument('out_dir') 53 | def unpack_openwebtext(archive: str, out_dir: str, n_jobs: int): 54 | out_dir = Path(out_dir) 55 | out_dir.mkdir() 56 | 57 | tmp_dir = Path(tempfile.mkdtemp(prefix='openwebtext')) 58 | print("Unpacking subset archives to", tmp_dir) 59 | unpack_archive(archive, extract_dir=tmp_dir) 60 | subset_tarfiles = [x for x in (tmp_dir / 'openwebtext').iterdir()] 61 | 62 | print("Unpacking corpus to", out_dir) 63 | subset_out_dirs = [out_dir / subset.stem for subset in subset_tarfiles] 64 | Parallel(n_jobs=n_jobs)( 65 | delayed(unpack_archive)(tarfile_path, subset_out_dir, 'xztar') 66 | for tarfile_path, subset_out_dir in zip(subset_tarfiles, subset_out_dirs) 67 | ) 68 | 69 | rmtree(tmp_dir) 70 | 71 | 72 | if __name__ == '__main__': 73 | unpack_openwebtext() 74 | -------------------------------------------------------------------------------- /domains/realnews/split_realnews.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterable, List, TypeVar 3 | from tqdm.auto import tqdm 4 | from pathlib import Path 5 | import random 6 | 7 | 8 | T = TypeVar('T') 9 | 10 | 11 | def batchify(data: Iterable[T], batch_size: int) -> Iterable[List[T]]: 12 | assert batch_size > 0 13 | 14 | batch = [] 15 | for item in data: 16 | # Yield next batch 17 | if len(batch) == batch_size: 18 | yield batch 19 | batch = [] 20 | 21 | batch.append(item) 22 | 23 | # Yield last un-filled batch 24 | if len(batch) != 0: 25 | yield batch 26 | 27 | 28 | if __name__ == '__main__': 29 | texts_dir = Path('realnews_shards/') 30 | with open('realnews/realnews.jsonl', 'r') as f, open('metadata/metadata.jsonl', 'w+') as g: 31 | for ix, batch in tqdm(enumerate(batchify(f, batch_size=512))): 32 | batch = [json.loads(x) for x in batch] 33 | for x in batch: 34 | fname = random.getrandbits(128) 35 | x['filename'] = str(Path(f"subset_{ix}") / (str(random.getrandbits(128)) + ".txt")) 36 | 37 | text = [(x['filename'], x.pop('text')) for x in batch] 38 | 39 | for line in batch: 40 | json.dump(line, g) 41 | g.write("\n") 42 | (texts_dir / f"subset_{ix}").mkdir(parents=True, exist_ok=True) 43 | for fname, line in text: 44 | with open(texts_dir / fname, 'w') as h: 45 | h.write(line) 46 | -------------------------------------------------------------------------------- /domains/reddit/download.sh: -------------------------------------------------------------------------------- 1 | for f in /checkpoint/parlai/tasks/meena_reddit/v1/*; do 2 | echo 
"processing ${f}..."; 3 | file=$(basename $f); 4 | pigz -dc $f | pv | parallel --pipe -q jq -rc '"<|endoftext|> " + .context + " " + .label' | pigz > $file; 5 | done 6 | -------------------------------------------------------------------------------- /domains/reviews/download.sh: -------------------------------------------------------------------------------- 1 | pigz -dc All_Amazon_Review.json.gz | parallel --pipe -q jq -rc '"<|endoftext|>" + .reviewText' | pv | pigz > reviews.txt.gz 2 | -------------------------------------------------------------------------------- /domains/reviews/split_reviews.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Iterable, List, TypeVar 3 | from tqdm.auto import tqdm 4 | from pathlib import Path 5 | import random 6 | import gzip 7 | 8 | 9 | T = TypeVar('T') 10 | 11 | 12 | def batchify(data: Iterable[T], batch_size: int) -> Iterable[List[T]]: 13 | assert batch_size > 0 14 | 15 | batch = [] 16 | for item in data: 17 | # Yield next batch 18 | if len(batch) == batch_size: 19 | yield batch 20 | batch = [] 21 | 22 | batch.append(item) 23 | 24 | # Yield last un-filled batch 25 | if len(batch) != 0: 26 | yield batch 27 | 28 | 29 | if __name__ == '__main__': 30 | texts_dir = Path('reviews/') 31 | with gzip.open('All_Amazon_Review.json.gz', 'rb') as f, open('metadata/metadata.jsonl', 'w+') as g: 32 | for ix, batch in tqdm(enumerate(batchify(f, batch_size=512))): 33 | batch = [json.loads(x) for x in batch] 34 | for x in batch: 35 | fname = random.getrandbits(128) 36 | x['filename'] = str(Path(f"subset_{ix}") / (str(random.getrandbits(128)) + ".txt")) 37 | text = [(x['filename'], x.pop('reviewText')) for x in batch if x.get('reviewText')] 38 | 39 | for line in batch: 40 | json.dump(line, g) 41 | g.write("\n") 42 | (texts_dir / f"subset_{ix}").mkdir(parents=True, exist_ok=True) 43 | for fname, line in text: 44 | with open(texts_dir / fname, 'w') as h: 45 | h.write(line) 46 | -------------------------------------------------------------------------------- /domains/s2orc/extract_papers.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from tqdm.auto import tqdm 3 | from pathlib import Path 4 | import pandas as pd 5 | import gzip 6 | import random 7 | import json 8 | from joblib import Parallel, delayed 9 | pd.options.mode.chained_assignment = None # default='warn' 10 | import argparse 11 | 12 | def get_papers(texts_dir, file, paper_ids_to_keep, ix): 13 | with gzip.open(file, 'rb') as f, open(texts_dir / 'metadata.jsonl', 'a+') as g: 14 | papers = [] 15 | pbar = tqdm(f) 16 | for line in pbar: 17 | if not line: 18 | continue 19 | z = json.loads(line) 20 | if paper_ids_to_keep.get(z['paper_id']): 21 | pbar.set_description(f"papers to extract: {len(paper_ids_to_keep)}, written {ix} shards") 22 | _ = paper_ids_to_keep.pop(z['paper_id']) 23 | z['filename'] = str(Path(f"subset_{ix}") / (str(random.getrandbits(128)) + ".txt")) 24 | text = " ".join([paper['text'] for paper in z['body_text']]) 25 | if text: 26 | papers.append((z['filename'], text)) 27 | json.dump(z, g) 28 | g.write("\n") 29 | if len(papers) > 512: 30 | (texts_dir / f"subset_{ix}").mkdir(parents=True, exist_ok=True) 31 | for fname, paper in papers: 32 | with open(texts_dir / fname, 'w') as h: 33 | h.write(paper) 34 | pbar.set_description(f"papers to extract: {len(paper_ids_to_keep)}, written {ix} shards") 35 | ix += 1 36 | papers = [] 37 | 38 | return ix, paper_ids_to_keep 39 | 
40 | 41 | def get_metadata(file, name): 42 | with gzip.open(file, 'rb') as f, open(f"metadata/{name}.jsonl", "a+") as g: 43 | for line in tqdm(f, leave=False): 44 | if not line: 45 | continue 46 | z = json.loads(line) 47 | if not z.get('mag_field_of_study'): 48 | continue 49 | if set(z['mag_field_of_study']) & fields_of_study: 50 | json.dump(z, g) 51 | f.write('\n') 52 | return metadata 53 | 54 | 55 | 56 | if __name__ == '__main__': 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument("--fields-of-study", type=str, nargs='+') 59 | args = parser.parse_args() 60 | 61 | fields_of_study = args.fields_of_study 62 | 63 | metadata_files = list(Path('20200705v1/full/metadata/').rglob('*')) 64 | pdf_parses = list(Path('20200705v1/full/pdf_parses/').rglob('*')) 65 | 66 | 67 | 68 | for name in fields_of_study: 69 | texts_dir = Path(f'{name}/') 70 | texts_dir.mkdir(exist_ok=True) 71 | paper_ids_to_keep = {} 72 | with open(f"metadata/{name}.jsonl", "r") as f: 73 | for line in tqdm(f): 74 | line = json.loads(line) 75 | paper_ids_to_keep[line['paper_id']] = 1 76 | ix, i = 0, 0 77 | files_pbar = tqdm(total=len(pdf_parses)) 78 | while paper_ids_to_keep and i < len(pdf_parses): 79 | ix, paper_ids_to_keep = get_papers(texts_dir, pdf_parses[i], paper_ids_to_keep, ix) 80 | i += 1 81 | files_pbar.update(1) 82 | if i >= len(pdf_parses): 83 | print(f"reached end of input") 84 | elif not paper_ids_to_keep: 85 | print(f"finished extracting all papers") 86 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: demix 2 | channels: 3 | - pytorch 4 | dependencies: 5 | - pytorch 6 | - pip 7 | - pip: 8 | - pandas 9 | - humanize 10 | - transformers 11 | - scikit-learn 12 | - fairseq 13 | - wandb 14 | - git+https://github.com/fadel/pytorch_ema 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas 2 | humanize 3 | transformers 4 | scikit-learn 5 | fairseq 6 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kernelmachine/demix-data/ab513103640c7eae8172324f309b46769798b96c/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/anonymize_file.py: -------------------------------------------------------------------------------- 1 | import re 2 | import argparse 3 | from tqdm.auto import tqdm 4 | 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--input-file', type=str) 9 | parser.add_argument('--output-file', type=str) 10 | args = parser.parse_args() 11 | re1 = { 12 | "regex": "[a-z0-9!#$%&'*+/=?^_`{|}~-]+(?:\.[a-z0-9!#$%&'*+/=?^_`{|}~-]+)*@(?:[a-z0-9](?:[a-z0-9-]*[a-z0-9])?\.)+[a-z0-9](?:[a-z0-9-]*[a-z0-9])?", 13 | "repl": "<|EMAIL|>" 14 | } 15 | 16 | re2 = {"regex": "[0-9]{10}-[0-9A-Fa-f]{53}", "repl":"<|DART|>"} 17 | 18 | re3 = {"regex": "@\[[0-9]+:[0-9]+:(?=[^\]])(([^\\:\]]*(?:\\.)*)*)\]", "repl": "<|FBUSERID|>"} 19 | 20 | re4 ={"regex": "(?:(?"} 21 | 22 | re5 ={"regex": "(?:4\d{12}(?:\d{3})?|(?:5[1-5]\d{2}|222[1-9]|22[3-9]\d|2[3-6]\d{2}|27[01]\d|2720)\d{12}|3[47]\d{13}|5019\d{12}|3(?:0[0-5]|[68]\d)\d{11}|6(?:011|5\d{2})\d{12}|(?:2131|1800|35\d{3})\d{11})", "repl": ""} 
23 | 24 | re6 ={"regex": "(?!(?:000|666|9))\d{3}-(?!00)\d{2}-(?!0000)\d{4}", "repl": "<|SSN|>"} 25 | 26 | re7 = {"regex": "\d+\s(?:(?:[a-z0-9.-]+[ ]?)+\s(?:Alley|Aly|Ave(?:nue)?|Boulevard|Blvd|Br(?:anch)?|Center|Ctr|Cir(?:cle)?|Court|Ct|Crossing|Xing|Dr(?:ive)?|Est(?:ate)?|Expressway|Expy|Freeway|Fwy|Highway|Hwy|Hills|Hls|Knoll|Knl|Landing|Lndg|Lane|Ln|Manor|Mnr|Meadow|Mdw|Parkway|Pkwy|Pass|Path|Plaza|Plz|Road|Rd|Run|Sq(?:uare)?|St(?:ation|reet|a)?|Ter(?:ace)?|Trail|Trl|Turnpike|Tpke|Valley|Vly|View|Vw|Village|Vlg|Vis(?:ta)?|Walk|Way)|(?:Route|Rte|Interstate|I)[- ]?\d{1,3})(?:\s(?:Apt[\.]?|Apartment|#)[ ]?\d+[a-z]?)?(?:\s(?:[a-z-]+[ ]?)+,?(?:\s(?:AK|AL(?:aska|abama)?|AR(?:kansas|izona)?|AZ|CA(?:lifornia)?|CO(?:lorado|nnecticut)?|CT|DC|DE(?:laware)?|FL(?:orida)?|GA|Georgia|GU(?:am)?|HI|Hawaii|IA|Iowa|ID(?:aho)?|IL(?:linois)?|IN(?:diana)?|KS|Kansas|KY|Kentucky|LA|Louisiana|MA(?:ssachusetts|ryland|ine)?|MD|ME|MI(?:chigan|nnesota|ssissippi|ssouri)|MN|MO(?:ntana)?|MS|MT|NC|North[ ]Carolina|ND|North[ ]Dakota|NH|New[ ]Hampshire|NJ|New[ ]Jersey|NM|New[ ]Mexico|NV|Nevada|NY|New[ ]York|OH(?:io)?|OK(?:lahoma)?|OR(?:egon)?|PA|Pennsylvania|PR|Puerto[ ]Rico|RI|Rhode[ ]Island|SC|South[ ]Carolina|SD|South[ ]Dakota|TN|Tennessee|TX|Texas|UT(?:ah)?|VA|Virginia|VI(?:rgin[ ]Islands)?|VT|Vermont|WA(?:shington(?:[ ]D[. ]?C[.]?)?)?|WI(?:sconsin)?|WV|West[ ]Virginia|WY(?:oming)?)(?:\s\b\d{5}(?:-\d{4})?\b)?)?)?", 27 | "repl": "<|ADDRESS|>"} 28 | 29 | re8 = {"regex": "@[a-zA-Z0-9_\.\-]{1,30}", "repl": "@USER"} 30 | 31 | 32 | re_list = [re1,re2,re3,re4,re5,re6,re8] 33 | 34 | 35 | anonymizer = {re.compile(x['regex']): x['repl'] for x in re_list} 36 | 37 | with open(args.input_file, 'r') as f, open(args.output_file, 'w+') as g: 38 | for line in tqdm(f.readlines()): 39 | for x,y in anonymizer.items(): 40 | line = x.sub(y, line) 41 | g.write(line) -------------------------------------------------------------------------------- /scripts/download_example_domains.sh: -------------------------------------------------------------------------------- 1 | for domain in "ag" "amazon" "imdb" "chemprot" "rct-20k" "hyperpartisan_news"; do 2 | if [ ! -d "example_domains/$domain/" ]; then 3 | echo "Processing $domain" 4 | mkdir example_domains/$domain/ 5 | curl -Lo example_domains/$domain/train.jsonl https://s3-us-west-2.amazonaws.com/allennlp/dont_stop_pretraining/data/$domain/train.jsonl; 6 | curl -Lo example_domains/$domain/dev.jsonl https://s3-us-west-2.amazonaws.com/allennlp/dont_stop_pretraining/data/$domain/dev.jsonl; 7 | curl -Lo example_domains/$domain/test.jsonl https://s3-us-west-2.amazonaws.com/allennlp/dont_stop_pretraining/data/$domain/test.jsonl; 8 | cat example_domains/$domain/train.jsonl example_domains/$domain/dev.jsonl example_domains/$domain/test.jsonl > example_domains/$domain/$domain.jsonl; 9 | else echo "$domain already exists." 
10 | fi; 11 | done; 12 | -------------------------------------------------------------------------------- /scripts/fetch_articles.py: -------------------------------------------------------------------------------- 1 | import newspaper 2 | import random 3 | import pandas as pd 4 | from tqdm.auto import tqdm 5 | import argparse 6 | from pathlib import Path 7 | from collections import defaultdict 8 | 9 | def parse(paper, i): 10 | article = paper.articles[i] 11 | article.download() 12 | try: 13 | article.parse() 14 | except Exception as e: 15 | print(e) 16 | return 17 | return {'text': article.text, 'title': article.title, 'paper': paper.url, 'url': article.url} 18 | 19 | if __name__ == '__main__': 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--num-articles-per-source', type=int, help='number of articles to fetch per source') 22 | parser.add_argument('--path-to-output', type=Path, help='path to output file') 23 | args = parser.parse_args() 24 | outputs = [] 25 | papers = [] 26 | # build papers 27 | df = pd.read_csv("corpus.tsv", sep='\t') 28 | sources = df.loc[df.fact == 'high'].source_url 29 | for source in tqdm(sources, desc="building sources"): 30 | try: 31 | papers.append(newspaper.build(source, language='en', memoize_articles=False)) 32 | except: 33 | print(f"could not build {source}, skipping...") 34 | continue 35 | # parse downloaded articles 36 | errors = defaultdict(int) 37 | for paper in tqdm(papers, desc='parsing articles'): 38 | try: 39 | random_indexes = random.choices(range(paper.size()), k=args.num_articles_per_source) 40 | for i in tqdm(random_indexes): 41 | output = parse(paper, i) 42 | if output: 43 | outputs.append(output) 44 | else: 45 | errors[paper.url] += 1 46 | except: 47 | continue 48 | 49 | pd.DataFrame(outputs).to_json(args.path_to_output, lines=True, orient='records') 50 | if not errors: 51 | print(f"Completed. No errors!") 52 | else: 53 | print(f"Completed. Errors: {dict(errors)}") 54 | -------------------------------------------------------------------------------- /scripts/multiprocessing_bpe_encoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | ## Pulled from fairseq library! https://raw.githubusercontent.com/pytorch/fairseq/master/examples/roberta/multiprocessing_bpe_encoder.py 9 | 10 | import argparse 11 | import contextlib 12 | import sys 13 | from collections import Counter 14 | from multiprocessing import Pool 15 | 16 | from fairseq.data.encoders.gpt2_bpe import get_encoder 17 | 18 | 19 | def main(): 20 | """ 21 | Helper script to encode raw text with the GPT-2 BPE using multiple processes. 
22 | 23 | The encoder.json and vocab.bpe files can be obtained here: 24 | - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json 25 | - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe 26 | """ 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument( 29 | "--encoder-json", 30 | help="path to encoder.json", 31 | ) 32 | parser.add_argument( 33 | "--vocab-bpe", 34 | type=str, 35 | help="path to vocab.bpe", 36 | ) 37 | parser.add_argument( 38 | "--inputs", 39 | nargs="+", 40 | default=["-"], 41 | help="input files to filter/encode", 42 | ) 43 | parser.add_argument( 44 | "--outputs", 45 | nargs="+", 46 | default=["-"], 47 | help="path to save encoded outputs", 48 | ) 49 | parser.add_argument( 50 | "--keep-empty", 51 | action="store_true", 52 | help="keep empty lines", 53 | ) 54 | parser.add_argument("--workers", type=int, default=20) 55 | args = parser.parse_args() 56 | 57 | assert len(args.inputs) == len( 58 | args.outputs 59 | ), "number of input and output paths should match" 60 | 61 | with contextlib.ExitStack() as stack: 62 | inputs = [ 63 | stack.enter_context(open(input, "r", encoding="utf-8")) 64 | if input != "-" 65 | else sys.stdin 66 | for input in args.inputs 67 | ] 68 | outputs = [ 69 | stack.enter_context(open(output, "w", encoding="utf-8")) 70 | if output != "-" 71 | else sys.stdout 72 | for output in args.outputs 73 | ] 74 | 75 | encoder = MultiprocessingEncoder(args) 76 | pool = Pool(args.workers, initializer=encoder.initializer) 77 | encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100) 78 | 79 | stats = Counter() 80 | for i, (filt, enc_lines) in enumerate(encoded_lines, start=1): 81 | if filt == "PASS": 82 | for enc_line, output_h in zip(enc_lines, outputs): 83 | print(enc_line, file=output_h) 84 | else: 85 | stats["num_filtered_" + filt] += 1 86 | if i % 10000 == 0: 87 | print("processed {} lines".format(i), file=sys.stderr) 88 | 89 | for k, v in stats.most_common(): 90 | print("[{}] filtered {} lines".format(k, v), file=sys.stderr) 91 | 92 | 93 | class MultiprocessingEncoder(object): 94 | def __init__(self, args): 95 | self.args = args 96 | 97 | def initializer(self): 98 | global bpe 99 | bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe) 100 | 101 | def encode(self, line): 102 | global bpe 103 | ids = bpe.encode(line) 104 | return list(map(str, ids)) 105 | 106 | def decode(self, tokens): 107 | global bpe 108 | return bpe.decode(tokens) 109 | 110 | def encode_lines(self, lines): 111 | """ 112 | Encode a set of lines. All lines will be encoded together. 
113 | """ 114 | enc_lines = [] 115 | for line in lines: 116 | line = line.strip() 117 | if len(line) == 0 and not self.args.keep_empty: 118 | return ["EMPTY", None] 119 | tokens = self.encode(line) 120 | enc_lines.append(" ".join(tokens)) 121 | return ["PASS", enc_lines] 122 | 123 | def decode_lines(self, lines): 124 | dec_lines = [] 125 | for line in lines: 126 | tokens = map(int, line.strip().split()) 127 | dec_lines.append(self.decode(tokens)) 128 | return ["PASS", dec_lines] 129 | 130 | 131 | if __name__ == "__main__": 132 | main() 133 | -------------------------------------------------------------------------------- /scripts/prepare.sh: -------------------------------------------------------------------------------- 1 | DOMAIN=$1 2 | NUM_WORKERS=$2 3 | BATCH_SIZE=$3 4 | OUTPUT_DIR=$4 5 | python -m loader.make_splits --domain $DOMAIN --num-workers $NUM_WORKERS --batch-size $BATCH_SIZE --output-dir $DOMAIN/splits-big/ 6 | bash scripts/pretokenize.sh $DOMAIN 7 | bash scripts/preprocess.sh $DOMAIN $OUTPUT_DIR 8 | -------------------------------------------------------------------------------- /scripts/preprocess.sh: -------------------------------------------------------------------------------- 1 | INPUT_DIR=$1 2 | DOMAIN=$2 3 | OUTPUT_DIR=$3 4 | 5 | fairseq-preprocess \ 6 | --only-source \ 7 | --srcdict ${DATA_DIR}/gpt2_bpe/dict.txt \ 8 | --trainpref ${INPUT_DIR}/train.txt.bpe \ 9 | --validpref ${INPUT_DIR}/dev.txt.bpe \ 10 | --testpref ${INPUT_DIR}/test.txt.bpe \ 11 | --destdir ${OUTPUT_DIR}/${DOMAIN} \ 12 | --workers 60; 13 | mv ${OUTPUT_DIR}/${DOMAIN}/valid.bin ${OUTPUT_DIR}/${DOMAIN}/valid_${DOMAIN}.bin 14 | mv ${OUTPUT_DIR}/${DOMAIN}/valid.idx ${OUTPUT_DIR}/${DOMAIN}/valid_${DOMAIN}.idx 15 | mv ${OUTPUT_DIR}/${DOMAIN}/test.bin ${OUTPUT_DIR}/${DOMAIN}/test_${DOMAIN}.bin 16 | mv ${OUTPUT_DIR}/${DOMAIN}/test.idx ${OUTPUT_DIR}/${DOMAIN}/test_${DOMAIN}.idx 17 | cp ${DATA_DIR}/gpt2_bpe/dict.txt ${OUTPUT_DIR}/dict.txt 18 | -------------------------------------------------------------------------------- /scripts/preprocess_example_domains.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | export DOMAIN=ag 5 | 6 | python -m domain_loader.shard_dataset --domain $DOMAIN --input-file example_domains/$DOMAIN/$DOMAIN.jsonl --batch-size 512 --text-field text 7 | python -m domain_loader.scan_filenames --domain $DOMAIN 8 | python -m domain_loader.count_words --domain $DOMAIN 9 | ## set token counts for "ag" in domain_loader/constants.py 10 | python -m domain_loader.make_splits --domain $DOMAIN --num-workers 0 --batch-size 1 --output-dir $DATA_DIR/$DOMAIN/splits 11 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 12 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 13 | 14 | 15 | export DOMAIN=imdb 16 | 17 | python -m domain_loader.shard_dataset --domain $DOMAIN --input-file example_domains/$DOMAIN/$DOMAIN.jsonl --batch-size 512 --text-field text 18 | python -m domain_loader.scan_filenames --domain $DOMAIN 19 | python -m domain_loader.count_words --domain $DOMAIN 20 | ## set token counts for "imdb" in domain_loader/constants.py 21 | python -m domain_loader.make_splits --domain $DOMAIN --num-workers 0 --batch-size 1 --output-dir $DATA_DIR/$DOMAIN/splits 22 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 23 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 24 | 25 | 26 | export DOMAIN=rct-20k 27 | 28 | python -m domain_loader.shard_dataset --domain $DOMAIN 
--input-file example_domains/$DOMAIN/$DOMAIN.jsonl --batch-size 512 --text-field text 29 | python -m domain_loader.scan_filenames --domain $DOMAIN 30 | python -m domain_loader.count_words --domain $DOMAIN 31 | ## set token counts for "rct-20k" in domain_loader/constants.py 32 | python -m domain_loader.make_splits --domain $DOMAIN --num-workers 0 --batch-size 1 --output-dir $DATA_DIR/$DOMAIN/splits 33 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 34 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 35 | 36 | 37 | export DOMAIN=hyperpartisan_news 38 | 39 | python -m domain_loader.shard_dataset --domain $DOMAIN --input-file example_domains/$DOMAIN/$DOMAIN.jsonl --batch-size 512 --text-field text 40 | python -m domain_loader.scan_filenames --domain $DOMAIN 41 | python -m domain_loader.count_words --domain $DOMAIN 42 | ## set token counts for "hyperpartisan_news" in domain_loader/constants.py 43 | python -m domain_loader.make_splits --domain $DOMAIN --num-workers 0 --batch-size 1 --output-dir $DATA_DIR/$DOMAIN/splits 44 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 45 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 46 | 47 | 48 | export DOMAIN=acl_papers 49 | 50 | python -m domain_loader.shard_dataset --domain $DOMAIN --input-file example_domains/$DOMAIN/$DOMAIN.jsonl --batch-size 512 --text-field text 51 | python -m domain_loader.scan_filenames --domain $DOMAIN 52 | python -m domain_loader.count_words --domain $DOMAIN 53 | ## set token counts for "chemprot" in domain_loader/constants.py 54 | python -m domain_loader.make_splits --domain $DOMAIN --num-workers 0 --batch-size 1 --output-dir $DATA_DIR/$DOMAIN/splits 55 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 56 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 57 | 58 | 59 | export DOMAIN=legal_contracts 60 | python -m domain_loader.shard_dataset --domain $DOMAIN --input-file example_domains/$DOMAIN/$DOMAIN.jsonl --batch-size 512 --text-field text 61 | python -m domain_loader.scan_filenames --domain $DOMAIN 62 | python -m domain_loader.count_words --domain $DOMAIN 63 | ## set token counts for "rct" in domain_loader/constants.py 64 | python -m domain_loader.make_splits --domain $DOMAIN --num-workers 16 --batch-size 16 --output-dir $DATA_DIR/$DOMAIN/splits 65 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 66 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 67 | 68 | export DOMAIN=citation_intent 69 | 70 | python -m domain_loader.shard_dataset --domain $DOMAIN --input-file example_domains/$DOMAIN/$DOMAIN.jsonl --batch-size 512 --text-field text 71 | python -m domain_loader.scan_filenames --domain $DOMAIN 72 | python -m domain_loader.count_words --domain $DOMAIN 73 | ## set token counts for "citation_intent" in domain_loader/constants.py 74 | python -m domain_loader.make_splits --domain $DOMAIN --num-workers 0 --batch-size 1 --output-dir $DATA_DIR/$DOMAIN/splits 75 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 76 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 77 | 78 | export DOMAIN=amazon 79 | 80 | python -m domain_loader.shard_dataset --domain $DOMAIN --input-file example_domains/$DOMAIN/$DOMAIN.jsonl --batch-size 512 --text-field text 81 | python -m domain_loader.scan_filenames --domain $DOMAIN 82 | python -m domain_loader.count_words --domain $DOMAIN 83 | ## set token counts for "amazon" in 
domain_loader/constants.py 84 | python -m domain_loader.make_splits --domain $DOMAIN --num-workers 0 --batch-size 1 --output-dir $DATA_DIR/$DOMAIN/splits 85 | bash scripts/pretokenize.sh ${DATA_DIR}/$DOMAIN/splits 86 | bash scripts/preprocess.sh ${DATA_DIR}/$DOMAIN/splits $DOMAIN ${DATA_DIR}/data-bin/ 87 | -------------------------------------------------------------------------------- /scripts/pretokenize.sh: -------------------------------------------------------------------------------- 1 | DIR=$1 2 | for SPLIT in train dev test; do \ 3 | python -m scripts.multiprocessing_bpe_encoder \ 4 | --encoder-json ${DATA_DIR}/gpt2_bpe/encoder.json \ 5 | --vocab-bpe ${DATA_DIR}/gpt2_bpe/vocab.bpe \ 6 | --inputs ${DIR}/${SPLIT}.txt \ 7 | --outputs ${DIR}/${SPLIT}.txt.bpe \ 8 | --keep-empty \ 9 | --workers 60; 10 | done 11 | -------------------------------------------------------------------------------- /scripts/vocab_overlap.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | from sklearn.feature_extraction.text import CountVectorizer 4 | import json 5 | from collections import defaultdict 6 | import humanize 7 | from typing import List 8 | import itertools 9 | import seaborn as sns 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import matplotlib 13 | import argparse 14 | import pandas as pd 15 | from domain_loader.domain_loader import Domain 16 | from torch.utils.data import DataLoader 17 | from domain_loader.constants import DATA_DIR 18 | sns.set(context="paper", style="white", font_scale=1.4) 19 | 20 | def load_data(data_path: str, sample: int=None) -> List[str]: 21 | examples = [] 22 | with tqdm(open(data_path, "r"), desc=f"loading {data_path}", disable=sample is None) as f: 23 | for line in f: 24 | if sample: 25 | if len(examples) > sample: 26 | break 27 | line = line.strip() 28 | if line: 29 | if data_path.endswith(".jsonl") or data_path.endswith(".json"): 30 | example = json.loads(line) 31 | else: 32 | example = {"text": line} 33 | text = example['text'] 34 | if sample: 35 | if np.random.binomial(1, 0.5): 36 | examples.append(text) 37 | else: 38 | examples.append(text) 39 | if sample: 40 | examples = np.random.choice(examples, size=sample) 41 | return examples 42 | 43 | def load_text(domain, add_bos_token=False, num_workers=1, batch_size=1, num_expected_tokens=None, num_expected_docs=None): 44 | with open(DATA_DIR / domain / "splits-final" / "train_files.txt", 'r') as f: 45 | files = [x.strip() for x in tqdm(f.readlines())] 46 | np.random.shuffle(files) 47 | 48 | dataset = Domain(DATA_DIR / domain / domain, 49 | filenames=files if domain not in ['1b', 'reddit'] else None, 50 | add_bos_token=add_bos_token, 51 | track_token_count=True) 52 | 53 | loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size) 54 | 55 | pbar = tqdm(loader) 56 | texts = [] 57 | 58 | curr_tokens = 0 59 | curr_docs = 0 60 | written = False 61 | for _, text ,token_count, _ in pbar: 62 | s = f"{domain}, num tokens: {humanize.intword(curr_tokens)}, num docs: {humanize.intword(curr_docs)}" 63 | pbar.set_description(s) 64 | if (num_expected_docs and curr_docs > num_expected_docs): 65 | texts = texts[:num_expected_docs] 66 | break 67 | if domain in ['1b', 'reddit']: 68 | tt = [t.split("<|endoftext|>") for t in text] 69 | text = [y for x in tt for y in x] 70 | token_count = [len(x.split()) for x in text] 71 | if (num_expected_tokens and curr_tokens > num_expected_tokens): 72 | if not written: 73 | text = " 
".join(text)[:num_expected_tokens] 74 | texts.extend(text) 75 | else: 76 | texts = "\n".join(texts)[:num_expected_tokens] 77 | texts = texts.split('\n') 78 | curr_tokens = num_expected_tokens 79 | break 80 | else: 81 | texts.extend(text) 82 | curr_tokens += sum(token_count) 83 | curr_docs += len(text) 84 | written = True 85 | return texts, curr_tokens, curr_docs 86 | 87 | def load_vocab(loader): 88 | count_vectorizer = CountVectorizer(min_df=3, max_features=10000, stop_words="english", ngram_range=(2,2)) 89 | pbar = tqdm(text) 90 | pbar.set_description(file) 91 | count_vectorizer.fit(pbar) 92 | vocab = set(count_vectorizer.vocabulary_.keys()) 93 | return vocab 94 | 95 | 96 | if __name__ == '__main__': 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--output_plot_file", help="path to save heatmap", required=True) 99 | parser.add_argument("--output_data_file", help="path to save heatmap data", required=True) 100 | parser.add_argument("--sample", type=int, help="sample tokens", required=False) 101 | args = parser.parse_args() 102 | vocabs = {} 103 | for domain in ['1b', 'cs', 'legal', 'med', 'openwebtext', 'realnews', 'reviews', 'reddit']: 104 | texts, curr_tokens, curr_docs = load_text(domain, 105 | add_bos_token=True if domain not in ['1b', 'reddit'] else False, 106 | num_workers=16 if domain not in ['1b', 'reddit'] else 1, 107 | batch_size=16 if domain not in ['1b', 'reddit'] else 1, 108 | num_expected_tokens=args.sample, 109 | num_expected_docs=None) 110 | count_vectorizer = CountVectorizer(stop_words="english", min_df=3, ngram_range=(2,2)) 111 | count_vectorizer.fit(tqdm(texts)) 112 | vocabs[domain] = set(count_vectorizer.vocabulary_.keys()) 113 | 114 | 115 | file_pairs = itertools.combinations(list(vocabs.keys()), 2) 116 | 117 | overlaps = {} 118 | for x, y in tqdm(file_pairs): 119 | intersection = vocabs[x] & vocabs[y] 120 | union = (vocabs[x] | vocabs[y]) 121 | overlaps[x + "_" + y] = len(intersection) / len(union) 122 | 123 | data = [] 124 | 125 | z = {} 126 | for key in tqdm(overlaps.keys()): 127 | file_1, file_2 = key.split('_') 128 | if not z.get(file_1): 129 | z[file_1] = {} 130 | z[file_1][file_2] = overlaps[key] 131 | if not z.get(file_2): 132 | z[file_2] = {} 133 | z[file_2][file_1] = overlaps[key] 134 | 135 | labels = list(vocabs.keys()) 136 | 137 | for ix, key in tqdm(enumerate(z)): 138 | items = [] 139 | for subkey in labels: 140 | if not z[key].get(subkey): 141 | items.append(1.0) 142 | else: 143 | items.append(z[key][subkey]) 144 | data.append(items) 145 | 146 | data = np.array(data) * 100 147 | 148 | if args.output_data_file: 149 | print('saving data...') 150 | np.save(args.output_data_file, data) 151 | 152 | print('generating fig...') 153 | fig, ax = plt.subplots(1,1,figsize=(8,8)) 154 | sns.heatmap(data, cmap="Blues", cbar=True, annot=True, ax=ax) 155 | plt.yticks(rotation=0) 156 | plt.xticks(rotation=90) 157 | plt.tight_layout() 158 | print('saving fig...') 159 | plt.savefig(args.output_plot_file, dpi=300) 160 | --------------------------------------------------------------------------------