├── .gitignore ├── CODEOWNERS ├── LICENSE ├── README.md ├── package.bat ├── processing_scripts ├── README.md ├── ablation_dedupe │ ├── make_deduped.py │ └── make_excludes_lambada_wikitext.py ├── dedupe_train.py ├── fix_dm_math.py ├── fix_empty_lines.py ├── github_reduce.py ├── join.py ├── lang_len_analysis_pass1.py ├── lang_len_analysis_pass2.py ├── pass2_shuffle_holdout.py ├── pile_proportions_sanitycheck.py ├── profanity_analysis_pass1.py └── repack_arxiv.py ├── pytest.ini ├── requirements-dev.txt ├── setup.py ├── test └── test_deterministic.py └── the_pile ├── __init__.py ├── datasets.py ├── pile.py ├── tfds_pile.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | components/*/ 2 | __pycache__ 3 | .pypirc 4 | *build 5 | *dist 6 | *egg-info* 7 | test.py 8 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @leogao2 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Pile Replication Code 2 | 3 | ## The official website for the Pile is [here](http://pile.eleuther.ai/). 4 | 5 | The Pile is a large, diverse, open source language modelling data set that consists of many smaller datasets combined together. The objective is to obtain text from as many modalities as possible to ensure that models trained using The Pile will have much broader generalization abilities. 6 | 7 | 8 | **This repository is for replicating or making variants of the Pile. IF YOU ARE HERE TO USE THE PILE DATASET, THIS REPO IS PROBABLY NOT WHAT YOU ARE LOOKING FOR.
A copy of the Pile can be downloaded [here](https://the-eye.eu/public/AI/pile/).** 9 | 10 | | Component | Raw Size |Weight|Epochs|Effective Size|Mean Document Size| 11 | |-----------------|----------|------|-----:|--------------|------------------| 12 | |[Pile-CC](https://github.com/leogao2/commoncrawl_downloader) |227.12 GiB|18.11%| 1.0|227.12 GiB |4.33 KiB | 13 | |[PubMed Central](https://github.com/EleutherAI/pile-pubmedcentral) |90.27 GiB |14.40%| 2.0|180.55 GiB |30.55 KiB | 14 | |[Books3](https://twitter.com/theshawwn/status/1320282149329784833) |100.96 GiB|12.07%| 1.5|151.44 GiB |538.36 KiB | 15 | |[OpenWebText2](https://github.com/EleutherAI/openwebtext2) |62.77 GiB |10.01%| 2.0|125.54 GiB |3.85 KiB | 16 | |[ArXiv](https://gist.github.com/leogao2/e09b64eae3b987925ccf3b86401624c6) |56.21 GiB |8.96% | 2.0|112.42 GiB |46.61 KiB | 17 | |[Github](https://github.com/EleutherAI/github-downloader) |95.16 GiB |7.59% | 1.0|95.16 GiB |5.25 KiB | 18 | |[FreeLaw](https://github.com/thoppe/The-Pile-FreeLaw) |51.15 GiB |6.12% | 1.5|76.73 GiB |15.06 KiB | 19 | |[StackExchange](https://github.com/EleutherAI/stackexchange-dataset) |32.20 GiB |5.13% | 2.0|64.39 GiB |2.16 KiB | 20 | |[USPTO Backgrounds](https://github.com/EleutherAI/pile-uspto) |22.90 GiB |3.65% | 2.0|45.81 GiB |4.08 KiB | 21 | |[PubMed Abstracts](https://github.com/thoppe/The-Pile-PubMed) |19.26 GiB |3.07% | 2.0|38.53 GiB |1.30 KiB | 22 | |[Gutenberg (PG-19)](https://github.com/deepmind/pg19)|10.88 GiB |2.17% | 2.5|27.19 GiB |398.73 KiB | 23 | |[OpenSubtitles](https://github.com/sdtblck/Opensubtitles_dataset) |12.98 GiB |1.55% | 1.5|19.47 GiB |30.48 KiB | 24 | |[Wikipedia (en)](https://github.com/noanabeshima/wikipedia-downloader) |6.38 GiB |1.53% | 3.0|19.13 GiB |1.11 KiB | 25 | |[DM Mathematics](https://github.com/deepmind/mathematics_dataset) |7.75 GiB |1.24% | 2.0|15.49 GiB |8.00 KiB | 26 | |[Ubuntu IRC](https://github.com/EleutherAI/pile-ubuntu-irc) |5.52 GiB |0.88% | 2.0|11.03 GiB |545.48 KiB | 27 | |[BookCorpus2](https://github.com/shawwn/scrap/blob/master/epub2txt-all) |6.30 GiB |0.75% | 1.5|9.45 GiB |369.87 KiB | 28 | |[EuroParl](https://github.com/thoppe/The-Pile-EuroParl) |4.59 GiB |0.73% | 2.0|9.17 GiB |68.87 KiB | 29 | |[HackerNews](https://github.com/EleutherAI/hn-scraper) |3.90 GiB |0.62% | 2.0|7.80 GiB |4.92 KiB | 30 | |[YoutubeSubtitles](https://github.com/sdtblck/youtube_subtitle_dataset) |3.73 GiB |0.60% | 2.0|7.47 GiB |22.55 KiB | 31 | |[PhilPapers](https://github.com/thoppe/The-Pile-PhilPapers) |2.38 GiB |0.38% | 2.0|4.76 GiB |73.37 KiB | 32 | |[NIH ExPorter](https://github.com/thoppe/The-Pile-NIH-ExPORTER) |1.89 GiB |0.30% | 2.0|3.79 GiB |2.11 KiB | 33 | |[Enron Emails](https://github.com/EleutherAI/pile-enron-emails) |0.88 GiB |0.14% | 2.0|1.76 GiB |1.78 KiB | 34 | |**Total** | | | |1254.20 GiB |5.91 KiB | 35 | 36 | 37 | (Epochs refers to the number of epochs elapsed after 1.2TB) 38 | 39 | 40 | ## Usage 41 | 42 | 43 | Install: 44 | 45 | ``` 46 | pip install -e . 47 | ``` 48 | 49 | ### To replicate pile 50 | 51 | ``` 52 | python the_pile/pile.py --interleave_output 30 --using pile_reprod 53 | ``` 54 | 55 | Use the pass 2 script [here](https://github.com/EleutherAI/The-Pile/tree/master/processing_scripts) to complete shuffling. 
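For reference, the pass 2 script (`processing_scripts/pass2_shuffle_holdout.py`) reads each interleaved chunk produced by the command above, shuffles it in memory, and splits off a 1% holdout from which the validation and test sets are later built. Below is a minimal sketch of the idea, assuming zstd-compressed jsonl chunks; the chunk path in the example call is illustrative, while the `pile_output`/`pile_holdout` destinations and the 1% split mirror the actual script:

```python
import io
import random
import zstandard

def shuffle_and_split(chunk_path, out_path, holdout_path, holdout_frac=0.01):
    # Read every document (one jsonl line each) from an interleaved chunk.
    with open(chunk_path, 'rb') as fh:
        reader = io.BufferedReader(zstandard.ZstdDecompressor().stream_reader(fh))
        lines = list(reader)

    # Shuffle the chunk, then carve off the holdout used to build val/test.
    random.shuffle(lines)
    pivot = int(len(lines) * holdout_frac)
    holdout, train = lines[:pivot], lines[pivot:]

    for path, part in [(out_path, train), (holdout_path, holdout)]:
        with open(path, 'wb') as fh:
            writer = zstandard.ZstdCompressor(level=3, threads=8).stream_writer(fh)
            for line in part:
                writer.write(line)
            writer.flush(zstandard.FLUSH_FRAME)

# Illustrative invocation; the real script walks pile_pass1/ itself.
shuffle_and_split('pile_pass1/chunk0/chunk.jsonl.zst',
                  'pile_output/0.jsonl.zst',
                  'pile_holdout/0.jsonl.zst')
```

Because the interleave step already distributes documents across the output chunks, shuffling each chunk independently is enough to complete the shuffle (see the processing scripts README for the reasoning).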
56 | 57 | 58 | ### Other 59 | 60 | To force download all data: 61 | ``` 62 | python the_pile/pile.py --force_download 63 | ``` 64 | 65 | To generate fasttext training data for CC filtering (OWT2 only): 66 | ``` 67 | sudo apt install build-essential 68 | python the_pile/pile.py --using owt2 --make_fasttext 69 | ``` 70 | 71 | ## Manual Download Components 72 | 73 | The following components need manual downloading. Either download them or comment them out in `pile.py`. 74 | 75 | - **Bibliotik**: `books3.tar.gz` needs to be in the current directory. Download temporarily unavailable. 76 | 77 | ## Workflow 78 | 79 | To propose a new dataset be added to the Pile, [open an issue](https://github.com/EleutherAI/The-Pile/issues/new). Your issue should include a description of the dataset, its size, what language(s) it is in, a link to the data, and any other relevant information. If a project manager approves your proposal, they will change its label to [![Datasets](https://img.shields.io/github/labels/EleutherAI/The-Pile/Dataset)](https://github.com/EleutherAI/The-Pile/labels/Dataset) and add it to [![Project: Datasets](https://img.shields.io/badge/Project-Datasets-lightgrey)](https://github.com/EleutherAI/The-Pile/projects/2). Datasets that we elect not to include in the current version of the Pile will receive a [![Deferred](https://img.shields.io/github/labels/EleutherAI/The-Pile/Deferred%20to%20v2)](https://github.com/EleutherAI/The-Pile/labels/Deferred%20to%20v2) or [![Declined](https://img.shields.io/github/labels/EleutherAI/The-Pile/Declined)](https://github.com/EleutherAI/The-Pile/labels/Declined) label. While we welcome multilingual datasets and plan on including non-English datasets in the future, the initial release of the Pile will be English-only and all submissions of non-English datasets will be deferred. 80 | 81 | To claim responsibility for implementing an unclaimed dataset, leave a comment on one of our unassigned issues. Once a dataset has been assigned to you, make the necessary changes to `datasets.py` and `pile.py` in a fork and submit a pull request. If needed, you can also submit a script for processing the data as shown [here](https://github.com/EleutherAI/pile_enron_emails). 82 | 83 | To raise an issue that is not proposing a new dataset, open an issue with the tag [![Feature Request](https://img.shields.io/github/labels/EleutherAI/The-Pile/Feature%20Request)](https://github.com/EleutherAI/The-Pile/labels/Feature%20Request) or [![Bug](https://img.shields.io/github/labels/EleutherAI/The-Pile/Bug)](https://github.com/EleutherAI/The-Pile/labels/Bug) as appropriate. 84 | 85 | Data ready for final implementation should meet the following criteria: 86 | 87 | - The data must be in [lm_dataformat](https://github.com/leogao2/lm_dataformat/) format. 88 | - The data must be shuffled. 89 | 90 | **In preparation for the initial release, we are no longer accepting additions to the *master* branch.
If you would like to contribute a dataset, please submit your pull request to the *Version2* branch.** 91 | -------------------------------------------------------------------------------- /package.bat: -------------------------------------------------------------------------------- 1 | del dist\* /Q 2 | python setup.py sdist bdist_wheel 3 | python -m twine upload --repository pypi dist/* -------------------------------------------------------------------------------- /processing_scripts/README.md: -------------------------------------------------------------------------------- 1 | In this directory are all of the scripts, including now-obsolete or one-off ones, used for processing, analyzing, and ablating the Pile. 2 | 3 | ## Replication 4 | 5 | Replication scripts are listed in approximate order needed for replication. 6 | 7 | - `pass2_shuffle_holdout.py`: Script for pass 2 of the shuffling. The first pass is handled in the Pile repo if `--interleave` is used. Pass 2 is basically going through each of the interleaved outputs and shuffling it. For more info on why this works, see https://blog.janestreet.com/how-to-shuffle-a-big-dataset/. This step also creates the holdout set, from which val and test are created. 8 | - `dedupe_train.py`: This script removes all exact-match data in the held-out sets (including test and val) from the training set. This is very important because otherwise there's leakage between train and val/test. Fuzzy matching is out of the scope of this script. A minimal sketch of this step is included at the end of this README. 9 | 10 | ## Analysis & Ablation 11 | 12 | - `lang_len_analysis_pass1.py`: Runs analysis for length in {chars, bytes, tokens, words} and language. Saves the result as .jsonl.zst files which need a second pass to aggregate, but this first pass is the more expensive one anyway, and this means we can make nice histograms and stuff. Should be run with `TOKENIZERS_PARALLELISM=false` for max performance since it prevents thread thrashing. This script would be a useful template for other future analyses. 13 | - `lang_len_analysis_pass2.py`: Pass 2 for length/language analysis. Aggregates and makes plots. 14 | - `profanity_analysis_pass1.py`: Profanity analysis pass 1. 15 | - `ablation_dedupe/make_excludes_lambada_wikitext.py`: For ablation; detokenizes LAMBADA and wikitext in preparation for eval-dedupe. This script should be obsolete now; `write_out.py` in lm_evaluation_harness handles many more sets. TODO: write a detailed guide on how to use `write_out.py`. 16 | - `ablation_dedupe/make_deduped.py`: For ablation; performs decontamination of training data against validation/test data. Run `make_excludes_lambada_wikitext` or `write_out.py` first. TODO: clean up and make an official validation-dedupe script. 17 | 18 | ## Miscellaneous 19 | 20 | - `repack_arxiv.py`: packages the arxiv tar.gz into an lmd archive. 21 | - `pile_proportions_sanitycheck.py`: shows the proportions of a sample of a Pile output to make sure the proportions are about right. 22 | - `github_reduce.py`: One-off script for cutting down github to a manageable size. The Pile repo used to pull all 600GB of github each time, but that's kinda ridiculous since we only use 95GB of it. 23 | - `join.py`: Script for joining multiple lmd archives. Much faster than actually using lmd because we're not parsing the json. 24 | - `fix_empty_lines.py`: One-off script for fixing extra newlines in lmd archives. Shouldn't be too useful for replication but included for completeness.
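For reference, the exact-match dedupe performed by `dedupe_train.py` boils down to two passes: hash every document in the holdout files, then stream the training files and drop any document whose hash appears in that set. A minimal sketch of the idea follows; the `pile_holdout`/`pile_output` directory names follow the pass-2 script's outputs, while `train_deduped/` and the local helpers are illustrative rather than the exact paths and utilities the script uses.

```python
import hashlib
import io
import os
from glob import glob

import zstandard


def read_docs(path):
    # Stream raw jsonl lines (one document per line) from a .zst file.
    with open(path, 'rb') as fh:
        yield from io.BufferedReader(zstandard.ZstdDecompressor().stream_reader(fh))


def sha256str(doc):
    return hashlib.sha256(doc).hexdigest()


# Pass 1: collect the hash of every held-out document (val and test are built from the holdout).
holdout_hashes = set()
for path in glob('pile_holdout/*.jsonl.zst'):
    holdout_hashes.update(sha256str(doc) for doc in read_docs(path))

# Pass 2: rewrite each training chunk, dropping exact matches against the holdout.
os.makedirs('train_deduped', exist_ok=True)
for path in glob('pile_output/*.jsonl.zst'):
    with open(os.path.join('train_deduped', os.path.basename(path)), 'wb') as fh:
        writer = zstandard.ZstdCompressor(level=3, threads=8).stream_writer(fh)
        for doc in read_docs(path):
            if sha256str(doc) not in holdout_hashes:
                writer.write(doc)
        writer.flush(zstandard.FLUSH_FRAME)
```

Exact sha256 matching keeps this pass cheap and streaming; fuzzier, n-gram-level decontamination against eval sets is handled separately by the `ablation_dedupe` scripts.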
-------------------------------------------------------------------------------- /processing_scripts/ablation_dedupe/make_deduped.py: -------------------------------------------------------------------------------- 1 | import lm_dataformat as lmd 2 | from glob import glob 3 | import re 4 | from tqdm import tqdm 5 | 6 | ngrams_to_exclude = set() 7 | 8 | 9 | def hash(s): 10 | return 11 | 12 | 13 | def ngrams(txt, n): 14 | pos = 0 15 | words = re.split(r'\s', txt.lower()) 16 | 17 | arr = [] 18 | for word in words: 19 | arr.append((word, pos)) 20 | pos += len(word) 21 | 22 | arr = list(filter(lambda x: x[0].strip(), arr)) 23 | 24 | for i in range(len(arr) - n): 25 | yield arr[i:i+n] 26 | 27 | 28 | for f in glob('excludes/*'): 29 | with open(f) as fh: 30 | doc = fh.read() 31 | ngs = ngrams(doc, 13) 32 | for ngram in ngs: 33 | ngrams_to_exclude.add(' '.join([x[0] for x in ngram])) 34 | 35 | 36 | def process_doc(txt, n=13): 37 | ngs = ngrams(txt, n) 38 | res = [] 39 | delspans = [] 40 | for ngram in ngs: 41 | joinedn = ' '.join([x[0] for x in ngram]) 42 | if joinedn in ngrams_to_exclude: 43 | delspans.append((max(0, ngram[0][1] - 200), min(len(txt), ngram[-1][1] + 200))) 44 | 45 | if len(delspans) == 0: return [txt] 46 | 47 | ptr = 0 48 | result = [] 49 | for l, r in delspans: 50 | if ptr < l: 51 | result.append(txt[ptr:l]) 52 | ptr = r 53 | if l <= ptr < r: 54 | ptr = r 55 | if r < ptr: 56 | raise AssertionError() 57 | 58 | result.append(txt[ptr:]) 59 | 60 | result = list(filter(lambda x: len(x) > 200, result)) 61 | 62 | if len(result) > 10: return [] 63 | 64 | return result 65 | 66 | 67 | chunk_docs = 50000 68 | 69 | 70 | dsets = [ 71 | ('output_pile', '00.jsonl.zst'), 72 | ('output_owt', '/data/datasets/openwebtext'), 73 | ] 74 | 75 | for outdir, source in dsets: 76 | ar = lmd.Archive(outdir) 77 | for i, doc in enumerate(tqdm(lmd.Reader(source).stream_data())): 78 | for piece in process_doc(doc): 79 | ar.add_data(piece) 80 | 81 | if (i + 1) % chunk_docs == 0: 82 | ar.commit() 83 | 84 | ar.commit() 85 | -------------------------------------------------------------------------------- /processing_scripts/ablation_dedupe/make_excludes_lambada_wikitext.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | def wikitext_detokenizer(string): 4 | # contractions 5 | string = string.replace("s '", "s'") 6 | string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) 7 | # number separators 8 | string = string.replace(" @-@ ", "-") 9 | string = string.replace(" @,@ ", ",") 10 | string = string.replace(" @.@ ", ".") 11 | # punctuation 12 | string = string.replace(" : ", ": ") 13 | string = string.replace(" ; ", "; ") 14 | string = string.replace(" . ", ". ") 15 | string = string.replace(" ! ", "! ") 16 | string = string.replace(" ? ", "? 
") 17 | string = string.replace(" , ", ", ") 18 | # double brackets 19 | string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) 20 | string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) 21 | string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) 22 | string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) 23 | string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) 24 | # miscellaneous 25 | string = string.replace("= = = =", "====") 26 | string = string.replace("= = =", "===") 27 | string = string.replace("= =", "==") 28 | string = string.replace(" " + chr(176) + " ", chr(176)) 29 | string = string.replace(" \n", "\n") 30 | string = string.replace("\n ", "\n") 31 | string = string.replace(" N ", " 1 ") 32 | string = string.replace(" 's", "'s") 33 | 34 | return string 35 | 36 | 37 | def lambada_detokenizer(text): 38 | text = text.replace("“", '"') 39 | text = text.replace("”", '"') 40 | return '\n'+text.strip() 41 | 42 | 43 | from glob import glob 44 | for f in [*glob('wikitext-2-raw/*'), *glob('wikitext-103-raw/*')]: 45 | if 'train' in f: continue 46 | with open(f) as fh,open('excludes/' + '_'.join(f.split('/')[-2:]) + '.txt', 'w') as fout: 47 | fout.write(wikitext_detokenizer(fh.read())) 48 | 49 | import json 50 | with open('lambada_test.jsonl') as fh, open('excludes/lambada.txt', 'w') as fout: 51 | for line in fh: 52 | fout.write(json.loads(line)['text'] + '\n') -------------------------------------------------------------------------------- /processing_scripts/dedupe_train.py: -------------------------------------------------------------------------------- 1 | from the_pile.utils import * 2 | import random 3 | import os 4 | from tqdm import tqdm 5 | import shutil 6 | import zstandard 7 | import io 8 | import parse 9 | 10 | import sys 11 | 12 | 13 | from glob import glob 14 | 15 | 16 | 17 | def readf(f): 18 | with open(f, 'rb') as fh: 19 | cctx = zstandard.ZstdDecompressor() 20 | reader = io.BufferedReader(cctx.stream_reader(fh)) 21 | yield from tqdm(reader) 22 | 23 | 24 | def writef(f, lines): 25 | with open(f, 'wb') as fh: 26 | cctx = zstandard.ZstdCompressor(level=3, threads=8) 27 | compressor = cctx.stream_writer(fh) 28 | for line in tqdm(lines): 29 | compressor.write(line) 30 | compressor.flush(zstandard.FLUSH_FRAME) 31 | 32 | seen = set() 33 | 34 | if os.path.exists('hashes.txt'): 35 | with open('hashes.txt') as fh: 36 | for line in tqdm(fh): 37 | seen.add(line.strip()) 38 | else: 39 | hashf = open('hashes.txt', 'w') 40 | for f in tqdm(glob('/mnt/data/pile_holdout/*.zst')): 41 | for doc in readf(f): 42 | hash = sha256str(doc) 43 | hashf.write(hash + '\n') 44 | seen.add(hash) 45 | hashf.close() 46 | 47 | os.makedirs('train', exist_ok=True) 48 | 49 | for f in tqdm(glob('train/*')): 50 | def filtered_docs(): 51 | removed = 0 52 | for doc in readf(f): 53 | hash = sha256str(doc) 54 | if hash in seen: 55 | removed += 1 56 | if removed % 1000 == 0: 57 | print(removed) 58 | else: 59 | yield doc 60 | writef('train2/' + f.split('/')[-1], filtered_docs()) 61 | 62 | -------------------------------------------------------------------------------- /processing_scripts/fix_dm_math.py: -------------------------------------------------------------------------------- 1 | from the_pile.utils import * 2 | import random 3 | import os 4 | from tqdm import tqdm 5 | import shutil 6 | import zstandard 7 | import io 8 | import json 9 | 10 | import sys 11 | 12 | 13 | f = sys.argv[1] 14 | 15 | fout = 'tmp' 16 | 17 | def readf(f): 18 | with open(f, 'rb') as fh: 19 | cctx = zstandard.ZstdDecompressor() 20 | 
reader = io.BufferedReader(cctx.stream_reader(fh)) 21 | yield from tqdm(reader) 22 | 23 | 24 | def writef(f, lines): 25 | with open(f, 'wb') as fh: 26 | cctx = zstandard.ZstdCompressor(level=3, threads=8) 27 | compressor = cctx.stream_writer(fh) 28 | for line in tqdm(lines): 29 | compressor.write(line) 30 | compressor.flush(zstandard.FLUSH_FRAME) 31 | 32 | 33 | def despace(x): 34 | res = [] 35 | for i, char in enumerate(x): 36 | if i % 2 == 1: 37 | if char != '\n': 38 | print(x) 39 | raise AssertionError() 40 | else: 41 | res.append(char) 42 | 43 | return ''.join(res) 44 | 45 | 46 | def fix(x): 47 | # optimization 48 | if b'DM Mathematics' not in x: return x 49 | 50 | ob = json.loads(x) 51 | 52 | if ob['meta']['pile_set_name'] != 'DM Mathematics': return x 53 | 54 | ob['text'] = despace(ob['text']) 55 | 56 | return (json.dumps(ob).strip() + '\n').encode('utf-8') 57 | 58 | 59 | writef(fout, map(fix, readf(f))) 60 | sh('mv tmp ' + f) 61 | 62 | -------------------------------------------------------------------------------- /processing_scripts/fix_empty_lines.py: -------------------------------------------------------------------------------- 1 | from the_pile.utils import * 2 | import random 3 | import os 4 | from tqdm import tqdm 5 | import shutil 6 | import zstandard 7 | import io 8 | import parse 9 | 10 | import sys 11 | 12 | 13 | f = sys.argv[1] 14 | 15 | fout = 'tmp' 16 | 17 | def readf(f): 18 | with open(f, 'rb') as fh: 19 | cctx = zstandard.ZstdDecompressor() 20 | reader = io.BufferedReader(cctx.stream_reader(fh)) 21 | yield from tqdm(reader) 22 | 23 | 24 | def writef(f, lines): 25 | with open(f, 'wb') as fh: 26 | cctx = zstandard.ZstdCompressor(level=3, threads=8) 27 | compressor = cctx.stream_writer(fh) 28 | for line in tqdm(lines): 29 | compressor.write(line) 30 | compressor.flush(zstandard.FLUSH_FRAME) 31 | 32 | 33 | def cont(x): 34 | return x.strip() 35 | 36 | 37 | writef(fout, filter(cont, readf(f))) 38 | sh('mv tmp ' + f) 39 | -------------------------------------------------------------------------------- /processing_scripts/github_reduce.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import lm_dataformat as lmd 3 | import random 4 | 5 | 6 | random.seed(42) 7 | 8 | def utf8len(s): 9 | return len(s.encode('utf-8')) 10 | 11 | out = lmd.Archive('github_min') 12 | 13 | n = 0 14 | size = 0 15 | for data, meta in tqdm(filter(lambda x: len(x[0]) < 100000, lmd.Reader('components/github/github.jsonl.zst.tar').stream_data(get_meta=True)), total=56626342): 16 | if random.random() < 0.16: 17 | out.add_data(data, meta) 18 | n += 1 19 | size += utf8len(data) 20 | 21 | out.commit() 22 | 23 | print('size', size) 24 | print('ndocs', n) -------------------------------------------------------------------------------- /processing_scripts/join.py: -------------------------------------------------------------------------------- 1 | from the_pile.utils import * 2 | import random 3 | import os 4 | from tqdm import tqdm 5 | import shutil 6 | import zstandard 7 | import io 8 | import parse 9 | 10 | def readf(f): 11 | with open(f, 'rb') as fh: 12 | cctx = zstandard.ZstdDecompressor() 13 | reader = io.BufferedReader(cctx.stream_reader(fh)) 14 | lines = [] 15 | for line in tqdm(reader): 16 | lines.append(line) 17 | 18 | return lines 19 | 20 | 21 | def writef(f, lines): 22 | with open(f, 'wb') as fh: 23 | cctx = zstandard.ZstdCompressor(level=3, threads=8) 24 | compressor = cctx.stream_writer(fh) 25 | for line in tqdm(lines): 
compressor.write(line) 26 | compressor.flush(zstandard.FLUSH_FRAME) 27 | 28 | import sys 29 | lines = [] 30 | for f in sys.argv[2:]: 31 | print(f) 32 | lines.extend(readf(f)) 33 | 34 | writef(sys.argv[1], lines) 35 | -------------------------------------------------------------------------------- /processing_scripts/lang_len_analysis_pass1.py: -------------------------------------------------------------------------------- 1 | import lm_dataformat as lmd 2 | from glob import glob 3 | import os 4 | import json 5 | import collections 6 | from tqdm import tqdm 7 | 8 | import transformers 9 | import re 10 | from best_download import download_file 11 | import fasttext 12 | 13 | import zstandard 14 | import multiprocessing as mp 15 | 16 | 17 | in_path = 'pile' 18 | out_path = 'langlen_stage1' 19 | 20 | 21 | def lengths(doc): 22 | global tok 23 | return { 24 | 'len_char': len(doc), 25 | 'len_utf8bytes': len(doc.encode('utf-8')), 26 | 'len_words': len(re.split(r'\s+', doc)), 27 | 'len_tokens': len(tok.encode(doc)), 28 | } 29 | 30 | 31 | download_file('https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin', 'lid.176.bin', '7e69ec5451bc261cc7844e49e4792a85d7f09c06789ec800fc4a44aec362764e') 32 | 33 | 34 | def language(doc): 35 | global langdet 36 | details = langdet.predict(doc.replace('\n', ' '), k=1) 37 | 38 | return { 39 | 'lang': details[0][0].replace('__label__', '') 40 | } 41 | 42 | 43 | def writef(f, lines): 44 | with open(f, 'wb') as fh: 45 | cctx = zstandard.ZstdCompressor(level=3, threads=8) 46 | compressor = cctx.stream_writer(fh) 47 | for line in tqdm(lines): 48 | compressor.write(line) 49 | compressor.flush(zstandard.FLUSH_FRAME) 50 | 51 | 52 | def analyze(ob): 53 | doc, meta = ob 54 | res = { 55 | 'pile_set_name': meta['pile_set_name'] 56 | } 57 | for metric in metrics: 58 | res = {**res, **metric(doc)} 59 | return json.dumps(res).encode('utf-8') 60 | 61 | 62 | metrics = [ 63 | lengths, 64 | language, 65 | ] 66 | 67 | def init_process(): 68 | global langdet 69 | global tok 70 | 71 | langdet = fasttext.load_model("lid.176.bin") 72 | tok = transformers.GPT2TokenizerFast.from_pretrained('gpt2') 73 | 74 | 75 | pool = mp.Pool(30, initializer=init_process) 76 | 77 | 78 | for f in tqdm(sorted(glob(in_path + '/*'))): 79 | if os.path.exists(out_path + '/analysis_' + f.split('/')[-1]): continue 80 | def meta_items(): 81 | rdr = lmd.Reader(f) 82 | return pool.imap(analyze, rdr.stream_data(get_meta=True)) 83 | 84 | writef(out_path + '/tmp_analysis_' + f.split('/')[-1], meta_items()) 85 | os.rename(out_path + '/tmp_analysis_' + f.split('/')[-1], out_path + '/analysis_' + f.split('/')[-1]) 86 | -------------------------------------------------------------------------------- /processing_scripts/lang_len_analysis_pass2.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import random 3 | import os 4 | from tqdm import tqdm 5 | import shutil 6 | import zstandard 7 | import json 8 | import io 9 | 10 | import sys 11 | import math 12 | 13 | 14 | def readf(f): 15 | with open(f, 'rb') as fh: 16 | cctx = zstandard.ZstdDecompressor() 17 | reader = io.BufferedReader(cctx.stream_reader(fh)) 18 | yield from reader 19 | 20 | 21 | rewritenames = { 22 | 'CommonCrawl': 'Pile-CC', 23 | 'Bibliotik': 'Books3', 24 | 'USPTO': 'USPTO Backgrounds', 25 | 'BookCorpus': 'BookCorpus2', 26 | } 27 | 28 | 29 | def rewrite_name(n): 30 | if n in rewritenames: return rewritenames[n] 31 | 32 | return n 33 | 34 | import collections 35 | 36 | dat = 
collections.defaultdict(list) 37 | 38 | set_names = set() 39 | 40 | 41 | for f in tqdm(glob('langlen_stage1/*')[:1]): 42 | # forgot to add \n in stage1 >.> 43 | for x in map(lambda x: x + b'}', (list(next(readf(f)).split(b'}'))[:-1])): 44 | ob = json.loads(x) 45 | setname = rewrite_name(ob['pile_set_name']) 46 | #print(ob) 47 | for attr in ['len_char', 'len_utf8bytes', 'len_words', 'len_tokens', 'lang']: 48 | dat[(setname, attr)].append(ob[attr]) 49 | dat[('Pile', attr)].append(ob[attr]) 50 | set_names.add(setname) 51 | 52 | if ob['len_tokens'] > 0: 53 | dat[(setname, 'bytes_per_token')].append(ob['len_utf8bytes'] / ob['len_tokens']) 54 | dat[('Pile', 'bytes_per_token')].append(ob['len_utf8bytes'] / ob['len_tokens']) 55 | 56 | dat[(setname, 'words_per_token')].append(ob['len_words'] / ob['len_tokens']) 57 | dat[('Pile', 'words_per_token')].append(ob['len_words'] / ob['len_tokens']) 58 | 59 | set_names = list(set_names) 60 | set_names.append('Pile') 61 | 62 | def mean(x): 63 | return sum(x) / len(x) 64 | 65 | def stddev(x): 66 | mu = mean(x) 67 | return math.sqrt(mean([(v - mu) ** 2 for v in x])) 68 | 69 | def freqs(x): 70 | ret = collections.defaultdict(int) 71 | for v in x: 72 | ret[v] += 1 73 | 74 | return ret 75 | 76 | def filter_freqs(x, minpass): 77 | total = sum(x.values()) 78 | torm = [] 79 | for k, v in x.items(): 80 | if v / total < minpass: 81 | torm.append(k) 82 | 83 | for k in torm: del x[k] 84 | 85 | return x 86 | 87 | nicename = { 88 | 'len_char': 'Length in characters', 89 | 'len_utf8bytes': 'Length in bytes', 90 | 'len_words': 'Length in words', 91 | 'len_tokens': 'Length in tokens', 92 | 'bytes_per_token': 'Mean bytes per token', 93 | 'words_per_token': 'Mean words per token', 94 | 'lang': 'Language' 95 | } 96 | 97 | 98 | import matplotlib.pyplot as plt 99 | import numpy as np 100 | 101 | def histogram(x, sname, attr): 102 | plt.clf() 103 | plt.cla() 104 | plt.hist(x, density=True, bins=100) 105 | #plt.ylabel('Probability Density') 106 | plt.xlabel('{} ({})'.format(nicename[attr], sname)) 107 | plt.savefig('figures/analysis_{}_{}.png'.format(sname, attr),bbox_inches='tight') 108 | 109 | 110 | def barplot(d, sname, attr, normalize=True, yerr=False): 111 | x, y = zip(*sorted(d.items(), key=lambda x: x[1], reverse=True)) 112 | yerrs = None 113 | if yerr: 114 | yerrs = [v[1] for v in y] 115 | y = [v[0] for v in y] 116 | if normalize: 117 | total = sum(d.values()) 118 | y = [val / total for val in y] 119 | plt.clf() 120 | plt.cla() 121 | if yerr: 122 | plt.errorbar(x, y, yerr=yerrs, fmt='o') 123 | plt.xticks(rotation=45, ha="right") 124 | 125 | #ymin = None 126 | #ymax = None 127 | 128 | #if attr == 'len_char': 129 | # ymin, ymax = -30000, 1200000 130 | #if attr == 'len_tokens': 131 | # ymin, ymax = -30000, 300000 132 | #if attr == 'len_utf8bytes': 133 | # ymin, ymax = -30000, 1200000 134 | #if ymin and ymax: 135 | # axes = plt.gca() 136 | # axes.set_ylim([ymin,ymax]) 137 | else: 138 | plt.bar(x, y) 139 | #plt.ylabel('Proportion') 140 | # plt.xlabel('{} ({})'.format(nicename[attr], sname)) 141 | plt.xlabel('Pile component') 142 | plt.ylabel(nicename[attr]) 143 | plt.savefig('figures/analysis_{}_{}.png'.format(sname, attr),bbox_inches='tight', dpi=600) 144 | 145 | 146 | def format_freqs(d): 147 | res = [] 148 | total = sum(d.values()) 149 | for k,v in sorted(d.items(), key=lambda x: -x[1]): 150 | res.append(' {}: {:2f}%'.format(k, v / total * 100)) 151 | return '\n'.join(res) 152 | 153 | 154 | def rm_outliers_trunc_1p(x): 155 | x = list(sorted(x)) 156 | return x[:len(x)*99//100] 
157 | 158 | 159 | summary = collections.defaultdict(dict) 160 | 161 | print('bytes per token, all:', sum(dat[('Pile', 'len_utf8bytes')]) / sum(dat[('Pile', 'len_tokens')])) 162 | 163 | for sname in set_names: 164 | print('**' + sname + '**') 165 | for attr in ['len_char', 'len_utf8bytes', 'len_words', 'len_tokens', 'bytes_per_token', 'words_per_token']: 166 | mu, sigma = mean(dat[(sname,attr)]), stddev(dat[(sname,attr)]) 167 | print('{}: {:.4f}±{:.4f}'.format(nicename[attr], mu, sigma)) 168 | #histogram(rm_outliers_trunc_1p(dat[(sname,attr)]), sname, attr) 169 | if sname != 'Pile' and (sname != 'Ubuntu IRC' or 'len_' not in attr): summary[attr][sname] = (mu, sigma) 170 | 171 | #barplot(filter_freqs(freqs(dat[(sname,'lang')]), 0.001), sname, 'lang') 172 | 173 | print('Langs:') 174 | print(format_freqs(freqs(dat[(sname, 'lang')]))) 175 | 176 | 177 | for attr in ['len_char', 'len_utf8bytes', 'len_words', 'len_tokens', 'bytes_per_token', 'words_per_token']: 178 | barplot(summary[attr], 'overview', attr, normalize=False, yerr=True) -------------------------------------------------------------------------------- /processing_scripts/pass2_shuffle_holdout.py: -------------------------------------------------------------------------------- 1 | from the_pile.utils import * 2 | import random 3 | import os 4 | from tqdm import tqdm 5 | import shutil 6 | import zstandard 7 | import io 8 | import parse 9 | 10 | random.seed(42) 11 | 12 | for f in ls('pile_pass1'): 13 | f = [x for x in ls(f) if 'current_chunk_incomplete' not in x] 14 | if len(f) == 1: break 15 | 16 | f, = f 17 | print(f) 18 | 19 | fout = 'pile_output/' + parse.parse('{}chunk{}/{}', f)[1] + '.jsonl.zst' 20 | fouth = 'pile_holdout/' + parse.parse('{}chunk{}/{}', f)[1] + '.jsonl.zst' 21 | 22 | with open(f, 'rb') as fh, open(fout, 'wb') as fout, open(fouth, 'wb') as fouth: 23 | print('reading') 24 | cctx = zstandard.ZstdDecompressor() 25 | reader = io.BufferedReader(cctx.stream_reader(fh)) 26 | lines = [] 27 | for line in tqdm(reader): 28 | lines.append(line) 29 | 30 | random.shuffle(lines) 31 | 32 | pivot = int(len(lines) * 0.01) 33 | 34 | holdout = lines[:pivot] 35 | lines = lines[pivot:] 36 | 37 | cctx = zstandard.ZstdCompressor(level=3, threads=8) 38 | compressor = cctx.stream_writer(fout) 39 | print('writing') 40 | for line in tqdm(lines): compressor.write(line) 41 | compressor.flush(zstandard.FLUSH_FRAME) 42 | 43 | cctx = zstandard.ZstdCompressor(level=3, threads=8) 44 | compressor = cctx.stream_writer(fouth) 45 | print('writing holdout') 46 | for line in tqdm(holdout): compressor.write(line) 47 | compressor.flush(zstandard.FLUSH_FRAME) 48 | 49 | del lines 50 | 51 | rm_if_exists(f) -------------------------------------------------------------------------------- /processing_scripts/pile_proportions_sanitycheck.py: -------------------------------------------------------------------------------- 1 | from the_pile.utils import * 2 | import lm_dataformat as lmd 3 | import random 4 | import os 5 | from tqdm import tqdm 6 | import shutil 7 | 8 | 9 | fnames = ls('pile_output')[:10] 10 | 11 | import collections 12 | 13 | total_bytes = 0 14 | bytes_per_subset = collections.defaultdict(int) 15 | i = 0 16 | 17 | for f in fnames: 18 | rdr = lmd.Reader(f) 19 | for doc, meta in rdr.stream_data(get_meta=True): 20 | size = utf8len(doc) 21 | bytes_per_subset[meta['pile_set_name']] += size 22 | total_bytes += size 23 | 24 | i += 1 25 | if i % 10000 == 0: 26 | l = list(bytes_per_subset.items()) 27 | l.sort(key=lambda x: -x[1]) 28 | for s, n in l: 29 | 
print(s, n / total_bytes) 30 | print('==============') 31 | -------------------------------------------------------------------------------- /processing_scripts/profanity_analysis_pass1.py: -------------------------------------------------------------------------------- 1 | import lm_dataformat as lmd 2 | from glob import glob 3 | import os 4 | import json 5 | import collections 6 | from tqdm import tqdm 7 | 8 | import re 9 | from best_download import download_file 10 | import fasttext 11 | 12 | import zstandard 13 | import multiprocessing as mp 14 | from profanity_check import predict 15 | 16 | in_path = 'pile' 17 | out_path = 'prof_analysis' 18 | 19 | # From https://stackoverflow.com/a/31505798 20 | import re 21 | alphabets= "([A-Za-z])" 22 | prefixes = "(Mr|St|Mrs|Ms|Dr)[.]" 23 | suffixes = "(Inc|Ltd|Jr|Sr|Co)" 24 | starters = "(Mr|Mrs|Ms|Dr|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)" 25 | acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)" 26 | websites = "[.](com|net|org|io|gov)" 27 | 28 | def split_into_sentences(text): 29 | text = " " + text + " " 30 | text = text.replace("\n"," ") 31 | text = re.sub(prefixes,"\\1",text) 32 | text = re.sub(websites,"\\1",text) 33 | if "Ph.D" in text: text = text.replace("Ph.D.","PhD") 34 | text = re.sub("\s" + alphabets + "[.] "," \\1 ",text) 35 | text = re.sub(acronyms+" "+starters,"\\1 \\2",text) 36 | text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1\\2\\3",text) 37 | text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1\\2",text) 38 | text = re.sub(" "+suffixes+"[.] "+starters," \\1 \\2",text) 39 | text = re.sub(" "+suffixes+"[.]"," \\1",text) 40 | text = re.sub(" " + alphabets + "[.]"," \\1",text) 41 | if "”" in text: text = text.replace(".”","”.") 42 | if "\"" in text: text = text.replace(".\"","\".") 43 | if "!" in text: text = text.replace("!\"","\"!") 44 | if "?" 
in text: text = text.replace("?\"","\"?") 45 | text = text.replace(".",".") 46 | text = text.replace("?","?") 47 | text = text.replace("!","!") 48 | text = text.replace("",".") 49 | 50 | # return quotes to normal 51 | text = text.replace("\".", ".\"") 52 | text = text.replace("”.", ".”") 53 | text = text.replace("\"!", "!\"") 54 | text = text.replace("”!", "!”") 55 | text = text.replace("\"?", "?\"") 56 | text = text.replace("”?", "?”") 57 | sentences = text.split("") 58 | sentences = sentences[:-1] 59 | sentences = [s.strip() for s in sentences] 60 | return sentences 61 | 62 | from best_download import download_file 63 | import fasttext 64 | download_file('https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin', 'lid.176.bin', '7e69ec5451bc261cc7844e49e4792a85d7f09c06789ec800fc4a44aec362764e') 65 | 66 | langdet = fasttext.load_model("lid.176.bin") 67 | 68 | def language(doc): 69 | details = langdet.predict(doc.replace('\n', ' '), k=1) 70 | 71 | return { 72 | 'lang': details[0][0].replace('__label__', '') 73 | } 74 | 75 | def is_english(doc): return doc != '' and language(doc) == 'en' 76 | 77 | def words(sent): return re.split(r'\s+', sent) 78 | 79 | def join(arr): 80 | ret = [] 81 | for val in arr: 82 | ret.extend(val) 83 | 84 | return ret 85 | 86 | def unjoin(arr, lens): 87 | ret = [] 88 | 89 | last = 0 90 | for l in lens: 91 | ret.append(arr[last:last+l]) 92 | last += l 93 | assert last == len(arr) 94 | 95 | return ret 96 | 97 | 98 | def is_profane(docs): 99 | if len(docs) == 0: return [] 100 | return list(map(int, predict(docs))) 101 | 102 | 103 | def profanity(doc): 104 | sents = list(filter(is_english, split_into_sentences(doc))) 105 | p_sents = is_profane(sents) 106 | 107 | sentwords = list(map(words, sents)) 108 | sentlens = list(map(len, sentwords)) 109 | 110 | lwords = join(sentwords) 111 | p_words = list(map(is_profane, lwords)) 112 | p_words = unjoin(p_words, sentlens) 113 | n_prof = list(map(sum, p_words)) 114 | 115 | res = list(zip(pred, sentlens, n_prof)) 116 | return { 117 | 'sentences': res, 118 | 'num_bytes': len(doc.encode('utf-8')) 119 | } 120 | 121 | def writef(f, lines): 122 | with open(f, 'wb') as fh: 123 | cctx = zstandard.ZstdCompressor(level=3, threads=8) 124 | compressor = cctx.stream_writer(fh) 125 | for line in tqdm(lines): 126 | compressor.write(line) 127 | compressor.flush(zstandard.FLUSH_FRAME) 128 | 129 | 130 | def analyze(ob): 131 | doc, meta = ob 132 | res = { 133 | 'pile_set_name': meta['pile_set_name'] 134 | } 135 | for metric in metrics: 136 | res = {**res, **metric(doc)} 137 | return json.dumps(res).encode('utf-8') 138 | 139 | 140 | metrics = [ 141 | profanity 142 | ] 143 | 144 | pool = mp.Pool(24) 145 | 146 | 147 | for f in tqdm(sorted(glob(in_path + '/*'))): 148 | if os.path.exists(out_path + '/analysis_' + f.split('/')[-1]): continue 149 | def meta_items(): 150 | rdr = lmd.Reader(f) 151 | return pool.imap(analyze, rdr.stream_data(get_meta=True)) 152 | 153 | writef(out_path + '/tmp_analysis_' + f.split('/')[-1], meta_items()) 154 | os.rename(out_path + '/tmp_analysis_' + f.split('/')[-1], out_path + '/analysis_' + f.split('/')[-1]) -------------------------------------------------------------------------------- /processing_scripts/repack_arxiv.py: -------------------------------------------------------------------------------- 1 | import lm_dataformat as lmd 2 | import os 3 | import hashlib 4 | import re 5 | from tqdm import tqdm 6 | 7 | def sha256str(s): 8 | h = hashlib.sha256() 9 | h.update(s) 10 | return h.hexdigest() 11 | 12 | def 
stableorder(x): 13 | arr = [(elem, sha256str(elem.encode('utf-8'))) for elem in x] 14 | arr.sort(key=lambda x: x[1]) 15 | return [elem for elem,_ in arr] 16 | 17 | def ls(x): 18 | return [x + '/' + fn for fn in stableorder(os.listdir(x))] 19 | 20 | def fread(fname): 21 | with open(fname) as fh: 22 | return fh.read() 23 | 24 | def strip_markdown_colons(x): 25 | return re.sub(r'^:::.*?\n', '', x, flags=re.MULTILINE) 26 | 27 | def compose(*fs): 28 | def _f(x): 29 | for f in reversed(fs): 30 | x = f(x) 31 | return x 32 | 33 | return _f 34 | 35 | 36 | ar = lmd.Archive('arxiv_lmd') 37 | 38 | for doc in map(compose(strip_markdown_colons, fread), tqdm(ls('documents'))): 39 | ar.add_data(doc) 40 | 41 | ar.commit() -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | testpaths = test 3 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | lm_dataformat 2 | tqdm 3 | gdown 4 | concurrent_iterator 5 | pytablewriter 6 | gitpython 7 | fasttext 8 | best-download 9 | gsutil 10 | virtualenv 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | import os 3 | from io import open as io_open 4 | 5 | src_dir = os.path.abspath(os.path.dirname(__file__)) 6 | 7 | with open("README.md", "r") as fh: 8 | long_description = fh.read() 9 | 10 | # Build requirements 11 | extras_require = {} 12 | requirements_dev = os.path.join(src_dir, 'requirements-dev.txt') 13 | with io_open(requirements_dev, mode='r') as fd: 14 | extras_require['dev'] = [i.strip().split('#', 1)[0].strip() 15 | for i in fd.read().strip().split('\n')] 16 | 17 | setuptools.setup( 18 | name="the-pile", 19 | version="0.0.1", 20 | author="EleutherAI", 21 | author_email="leogao31@gmail.com", 22 | description="The Pile is a large, diverse, open source language modelling data set that consists of many smaller datasets combined together.", 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | url="https://github.com/EleutherAI/the-pile", 26 | classifiers=[ 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: MIT License", 29 | "Operating System :: OS Independent", 30 | ], 31 | 32 | python_requires='>=3.6', 33 | extras_require=extras_require, 34 | packages=['the_pile'], 35 | package_data={'the_pile': ['LICENCE', 'requirements-dev.txt']}, 36 | ) 37 | -------------------------------------------------------------------------------- /test/test_deterministic.py: -------------------------------------------------------------------------------- 1 | import the_pile 2 | import hashlib 3 | 4 | 5 | def limit_iter(x, ct): 6 | for i in range(ct): 7 | yield next(x) 8 | 9 | 10 | def test_deterministic(): 11 | total_docs = 100000 12 | 13 | # remember to update this hash every time the pile is modified 14 | expected = 'e197f27c3061a73123277cd79d641681b6abc92b2ea9e69710c045bb73bd8b28' 15 | 16 | # run twice just to make sure it doesn't change 17 | for i in range(2): 18 | h1 = hashlib.sha256() 19 | pile = the_pile.pile() 20 | for doc in limit_iter(pile.documents(), total_docs): 21 | h1.update(doc.encode('utf-8')) 22 | assert h1.hexdigest() == expected 
-------------------------------------------------------------------------------- /the_pile/__init__.py: -------------------------------------------------------------------------------- 1 | from the_pile.pile import ThePile 2 | from the_pile.datasets import * 3 | import hashlib 4 | 5 | def pile(): 6 | return ThePile() -------------------------------------------------------------------------------- /the_pile/datasets.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import os 3 | import json 4 | 5 | import gdown 6 | import lm_dataformat as lmd 7 | from tqdm import tqdm 8 | 9 | from .utils import * 10 | 11 | class Dataset(abc.ABC): 12 | @abc.abstractmethod 13 | def name(self): 14 | """ Human-readable name of the dataset """ 15 | pass 16 | 17 | @abc.abstractmethod 18 | def documents(self): 19 | """ A generator producing all documents in the dataset. """ 20 | pass 21 | 22 | @abc.abstractmethod 23 | def clean(self): 24 | """ Remove any dataset files. """ 25 | pass 26 | 27 | def size(self): 28 | """ Return an estimate of the dataset size. Implementations may use a faster, less accurate estimate. """ 29 | 30 | size = sum(map(utf8len, tqdm(self.documents()))) 31 | print('size', self.name(), size) 32 | return size 33 | 34 | def num_docs(self): 35 | """ Return an estimate of the number of documents in the dataset. Implementations may use a faster, less accurate estimate. """ 36 | 37 | size = len(list(map(lambda x: None, tqdm(self.documents())))) 38 | print('docs', self.name(), size) 39 | return size 40 | 41 | def already_shuffled(self): 42 | """ Datasets where the source is already shuffled should override this to return True so that it isn't shuffled again. """ 43 | return False 44 | 45 | 46 | class WikipediaDataset(Dataset): 47 | def name(self): 48 | return "Wikipedia (en)" 49 | 50 | def _download(self): 51 | download('components/wikipedia_en/output/wikipedia-en.tar.gz', '87b78787f71297250bca644ab9d8e3992346eeb2e2ad91101487109e3d01e644', [ 52 | Source('direct', 'http://eaidata.bmk.sh/data/wikipedia-en.tar.gz'), 53 | ], extract=True) 54 | 55 | def documents(self): 56 | self._download() 57 | 58 | for file in ls('components/wikipedia_en/output'): 59 | if not file.endswith('.json'): 60 | continue 61 | 62 | with open(file) as fh: 63 | ob = json.load(fh) 64 | yield from dummy_meta(ob) 65 | 66 | def clean(self): 67 | rm_if_exists('components/wikipedia_en') 68 | 69 | def size(self): 70 | return 6847462907 71 | 72 | def num_docs(self): 73 | return 6033151 74 | 75 | 76 | class OpensubtitlesDataset(Dataset): 77 | def name(self): 78 | return "OpenSubtitles" 79 | 80 | def _download(self): 81 | download('components/opensubtitles/opensubtitles_out.tar', 'f3039709677292f899bb0a8fa2dbc6ae785f9e33ffd7613f94f4f722f2dfd95c', [ 82 | Source('direct', 'http://eaidata.bmk.sh/data/opensubtitles_out.tar'), 83 | ], extract=True) 84 | 85 | def documents(self): 86 | self._download() 87 | 88 | return dummy_meta(lmd.Reader('components/opensubtitles/out').stream_data()) 89 | 90 | def clean(self): 91 | rm_if_exists('components/opensubtitles') 92 | 93 | 94 | def size(self): 95 | return 13940478112 96 | 97 | def num_docs(self): 98 | return 446612 99 | 100 | 101 | class BookCorpusDataset(Dataset): 102 | def name(self): 103 | return "BookCorpus" 104 | 105 | def _download(self): 106 | download('components/bookcorpus/books1.tar.gz', 'e3c993cc825df2bdf0f78ef592f5c09236f0b9cd6bb1877142281acc50f446f9', [ 107 | Source('direct', 
'https://the-eye.eu/public/AI/pile_preliminary_components/books1.tar.gz'), 108 | Source('direct', 'http://battle.shawwn.com/sdb/books1/books1.tar.gz'), 109 | ], extract=True) 110 | 111 | def documents(self): 112 | self._download() 113 | 114 | return dummy_meta(map(fread, ls('components/bookcorpus/books1/epubtxt'))) 115 | 116 | def clean(self): 117 | rm_if_exists('components/bookcorpus') 118 | 119 | def size(self): 120 | return 6767414779 121 | 122 | def num_docs(self): 123 | return 17868 124 | 125 | def already_shuffled(self): 126 | return True 127 | 128 | 129 | class OpenWebTextDataset(Dataset): 130 | def name(self): 131 | return "OpenWebText" 132 | 133 | def _download(self): 134 | # todo: convert 135 | download_directory = "components/openwebtext" 136 | done_file = os.path.join(download_directory, "download.done") 137 | if not os.path.exists(done_file): 138 | os.makedirs(download_directory, exist_ok=True) 139 | url = "https://drive.google.com/uc?id=1EA5V0oetDCOke7afsktL_JDQ-ETtNOvx" 140 | output_file = os.path.join(download_directory, "openwebtext.tar.xz") 141 | gdown.download(url, output_file, quiet=False) 142 | sha256sum(output_file,'9fe39d154c5bc67da8c359415372b79510eb1e2edb0d035fe4f7fc3a732b9336') 143 | 144 | with open(done_file, "w") as fh: 145 | fh.write("done!") 146 | 147 | def documents(self): 148 | self._download() 149 | 150 | return dummy_meta(lmd.Reader('components/openwebtext/openwebtext').stream_data()) 151 | 152 | def clean(self): 153 | rm_if_exists('components/openwebtext') 154 | 155 | 156 | def size(self): 157 | return 39757465434 158 | 159 | def num_docs(self): 160 | return 8013769 161 | 162 | 163 | class GutenbergDataset(Dataset): 164 | def name(self): 165 | return "Gutenberg (PG-19)" 166 | 167 | def _download(self): 168 | if not os.path.exists('components/gutenberg'): 169 | # todo: convert after gcloud download is implemented 170 | sh(""" 171 | mkdir -p components/gutenberg 172 | cd components/gutenberg 173 | virtualenv env 174 | . env/bin/activate 175 | pip install gsutil 176 | mkdir -p pg19_train 177 | gsutil -m rsync gs://deepmind-gutenberg/train ./pg19_train 178 | """) 179 | 180 | def documents(self): 181 | self._download() 182 | 183 | return dummy_meta(map(fread, ls('components/gutenberg/pg19_train'))) 184 | 185 | def clean(self): 186 | rm_if_exists('components/gutenberg') 187 | 188 | def size(self): 189 | return 11678184672 190 | 191 | def num_docs(self): 192 | return 28602 193 | 194 | def already_shuffled(self): 195 | return True 196 | 197 | 198 | class DMMathDataset(Dataset): 199 | def name(self): 200 | return "DM Mathematics" 201 | 202 | def _download(self): 203 | if not os.path.exists('components/dm_math'): 204 | # todo: convert after gcloud download is implemented 205 | sh(""" 206 | mkdir -p components/dm_math 207 | cd components/dm_math 208 | virtualenv env 209 | . 
env/bin/activate 210 | pip install gsutil 211 | gsutil -m rsync gs://mathematics-dataset/ $PWD 212 | tar xf mathematics_dataset-v1.0.tar.gz 213 | """) 214 | sha256sum('components/dm_math/mathematics_dataset-v1.0.tar.gz', 'def638343403cb9ed60437d6b684c859dd23b72779f5cc5661b0a31e67c58576') 215 | 216 | def documents(self): 217 | self._download() 218 | 219 | return dummy_meta(chunk_at_even_lines(concat( 220 | map( 221 | lambda x: map(fread, ls('components/dm_math/mathematics_dataset-v1.0/train-' + x)), 222 | ['easy', 'medium', 'hard']) 223 | ), 8192)) 224 | 225 | def clean(self): 226 | rm_if_exists('components/dm_math') 227 | 228 | def size(self): 229 | return 8316165951 230 | 231 | def num_docs(self): 232 | return 1014997 233 | 234 | 235 | class EnronEmailsDataset(Dataset): 236 | def name(self): 237 | return "Enron Emails" 238 | 239 | def _download(self): 240 | download('components/enron_emails/enron_emails.jsonl.zst', '6968dd2d6d9c4328ee3b77b263aad38401b77c326f693ce051c98a3f215bf583', [ 241 | Source('direct', 'http://eaidata.bmk.sh/data/enron_emails.jsonl.zst'), 242 | ]) 243 | 244 | def documents(self): 245 | self._download() 246 | 247 | return lmd.Reader('components/enron_emails/enron_emails.jsonl.zst').stream_data(get_meta=True) 248 | 249 | def clean(self): 250 | rm_if_exists('components/enron_emails') 251 | 252 | def size(self): 253 | return 945212874 254 | 255 | def num_docs(self): 256 | return 517401 257 | 258 | class LiteroticaDataset(Dataset): 259 | """ Source: https://www.reddit.com/r/literotica/comments/6xvxvh/i_downloaded_all_380000_stories_on_literotica/?utm_source=share&utm_medium=ios_app&utm_name=iossmf """ 260 | def name(self): 261 | return "Literotica" 262 | 263 | def _download(self): 264 | download('components/literotica/Literotica.jsonl.zst', '3c6b968f851831c6345f175b394416f7521da3bacd90fdc827093f0d310bd4ef', [ 265 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/Literotica.jsonl.zst'), 266 | Source('gdrive', 'https://drive.google.com/uc?id=1Nx63w9BFZZSI_s2pmJnhcBU9c-y803T7'), 267 | ]) 268 | 269 | def documents(self): 270 | self._download() 271 | 272 | return lmd.Reader('components/literotica/Literotica.jsonl.zst').stream_data(get_meta=True) 273 | 274 | def clean(self): 275 | rm_if_exists('components/literotica') 276 | 277 | def size(self): 278 | return 12458318640 279 | 280 | def num_docs(self): 281 | return 473653 282 | 283 | 284 | class BibliotikDataset(Dataset): 285 | def name(self): 286 | return "Bibliotik" 287 | 288 | def _download(self): 289 | raise NotImplementedError('bibliotik temporarily unavailable') 290 | download('components/bibliotik/Bibliotik.jsonl.zst', '1aa43653f6de7ad074796bb6ca949beab584d91c5e188a66d994643838373b06', [ 291 | ]) 292 | 293 | def documents(self): 294 | self._download() 295 | 296 | yield from lmd.Reader('components/bibliotik/Bibliotik.jsonl.zst').stream_data(get_meta=True) 297 | 298 | def clean(self): 299 | rm_if_exists('components/bibliotik') 300 | 301 | def size(self): 302 | return 108404259563 303 | 304 | def num_docs(self): 305 | return 196640 306 | 307 | def already_shuffled(self): 308 | return True 309 | 310 | 311 | class CORD19Dataset(Dataset): 312 | def name(self): 313 | return "CORD-19" 314 | 315 | def _download(self): 316 | 317 | if not os.path.exists('components/cord19'): 318 | if not os.path.exists('document_parses'): 319 | raise AssertionError('Must download document_parses manually!') 320 | 321 | sh(""" 322 | mkdir -p components/cord19 323 | cd components/cord19 324 | 325 | git clone 
https://github.com/EleutherAI/pile_cord19 . 326 | virtualenv env 327 | . env/bin/activate 328 | 329 | mv ../../document_parses . 330 | 331 | pip install -r requirements.txt 332 | python main.py 333 | """) 334 | 335 | def documents(self): 336 | self._download() 337 | 338 | return lmd.Reader('components/cord19/out').stream_data(get_meta=True) 339 | 340 | def clean(self): 341 | rm_if_exists('components/cord19') 342 | 343 | def size(self): 344 | return 4573360967 345 | 346 | def num_docs(self): 347 | return 174560 348 | 349 | 350 | class UbuntuIRCDataset(Dataset): 351 | def name(self): 352 | return "Ubuntu IRC" 353 | 354 | def _download(self): 355 | download('components/ubuntu_irc/ubuntu_irc_weekly.jsonl.zst', 'b744a253c5406f32c7a9c76ba4cf7888fdeb4b5b6bdc368ca9359a0238b968c9', [ 356 | Source('direct', 'http://eaidata.bmk.sh/data/ubuntu_irc_weekly.jsonl.zst'), 357 | ]) 358 | 359 | def documents(self): 360 | self._download() 361 | 362 | return lmd.Reader('components/ubuntu_irc/ubuntu_irc_weekly.jsonl.zst').stream_data(get_meta=True) 363 | 364 | def clean(self): 365 | rm_if_exists('components/ubuntu_irc') 366 | 367 | def size(self): 368 | return 5923631555 369 | 370 | def num_docs(self): 371 | return 10605 372 | 373 | 374 | class ArXivDataset(Dataset): 375 | def name(self): 376 | return "ArXiv" 377 | 378 | def _download(self): 379 | download('components/arxiv/arxiv.jsonl.zst', '084b894f513986076a7d97e5c323c7fa8ebef1733f151a7fbdb139c19c07b571', [ 380 | Source('direct', 'http://eaidata.bmk.sh/data/arxiv.jsonl.zst'), 381 | ]) 382 | 383 | def documents(self): 384 | self._download() 385 | 386 | return lmd.Reader('components/arxiv/arxiv.jsonl.zst').stream_data(get_meta=True) 387 | 388 | def clean(self): 389 | rm_if_exists('components/arxiv') 390 | 391 | def size(self): 392 | return 60353358395 393 | 394 | def num_docs(self): 395 | return 1264405 396 | 397 | def already_shuffled(self): 398 | return True 399 | 400 | 401 | class PubMedDataset(Dataset): 402 | def name(self): 403 | return "PubMed Abstracts" 404 | 405 | def _download(self): 406 | download('components/pubmed/PUBMED_title_abstracts_2019_baseline.jsonl.zst', '15c26a83ac2b11378b8e6ba5a16bab92428de29bacb85709834948cfcf1f029b', [ 407 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/PUBMED_title_abstracts_2019_baseline.jsonl.zst'), 408 | Source('direct', 'http://eaidata.bmk.sh/data/PUBMED_title_abstracts_2019_baseline.jsonl.zst'), 409 | ]) 410 | 411 | def documents(self): 412 | self._download() 413 | 414 | return lmd.Reader('components/pubmed/PUBMED_title_abstracts_2019_baseline.jsonl.zst').stream_data(get_meta=True) 415 | 416 | def clean(self): 417 | rm_if_exists('components/pubmed') 418 | 419 | def size(self): 420 | return 20684050384 421 | 422 | def num_docs(self): 423 | return 15518009 424 | 425 | 426 | class ExPorterDataset(Dataset): 427 | def name(self): 428 | return "NIH ExPorter" 429 | 430 | def _download(self): 431 | download('components/exporter/NIH_ExPORTER_awarded_grant_text.jsonl.zst', 'be7fc69b9a3652391b6567891b99277ac99e7dfd5892ba19cb312f909357c458', [ 432 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/NIH_ExPORTER_awarded_grant_text.jsonl.zst'), 433 | Source('gdrive', 'https://drive.google.com/uc?id=11mO-0LuL2YeKoqqWXyHPHf3d2ODnjVPP'), 434 | ]) 435 | 436 | def documents(self): 437 | self._download() 438 | 439 | return lmd.Reader('components/exporter/NIH_ExPORTER_awarded_grant_text.jsonl.zst').stream_data(get_meta=True) 440 | 441 | def clean(self): 442 | 
rm_if_exists('components/exporter') 443 | 444 | def size(self): 445 | return 2034579138 446 | 447 | def num_docs(self): 448 | return 939661 449 | 450 | 451 | class StackExchangeDataset(Dataset): 452 | def name(self): 453 | return "StackExchange" 454 | 455 | def _download(self): 456 | download('components/stackexchange/stackexchange_dataset.tar', 'f64f31d20db8d8692c1a019314a14974b4911a34ffef126feaf42da88860c666', [ 457 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/stackexchange_dataset.tar'), 458 | Source('direct', 'http://eaidata.bmk.sh/data/stackexchange_dataset.tar'), 459 | ], extract=True) 460 | 461 | def documents(self): 462 | self._download() 463 | 464 | return dummy_meta(lmd.Reader('components/stackexchange/out').stream_data()) 465 | 466 | def clean(self): 467 | rm_if_exists('components/stackexchange/out') 468 | 469 | def size(self): 470 | return 34571286358 471 | 472 | def num_docs(self): 473 | return 15622475 474 | 475 | 476 | class FreeLawDataset(Dataset): 477 | def name(self): 478 | return "FreeLaw" 479 | 480 | def _download(self): 481 | download('components/freelaw/FreeLaw_Opinions.jsonl.zst', '7d7ba907cf397e8585bb3ef148b3e9678edbf142b2247460f907c16aecbaed2d', [ 482 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/FreeLaw_Opinions.jsonl.zst'), 483 | Source('gdrive', 'https://drive.google.com/uc?id=1L-x3g3V888gHNUVHQWDkJBJHs5N02Kjz'), 484 | ]) 485 | 486 | def documents(self): 487 | self._download() 488 | 489 | return lmd.Reader('components/freelaw/FreeLaw_Opinions.jsonl.zst').stream_data(get_meta=True) 490 | 491 | def clean(self): 492 | rm_if_exists('components/freelaw') 493 | 494 | def size(self): 495 | return 54923939791 496 | 497 | def num_docs(self): 498 | return 3562015 499 | 500 | 501 | class PubMedCentralDataset(Dataset): 502 | def name(self): 503 | return "PubMed Central" 504 | 505 | def _download(self): 506 | download('components/pubmedcentral/PMC_extracts.tar.gz', 'dd2ecc79480bd5b78c29ea78af96941c69f6bda3d36a7d510019ccc4848fb867', [ 507 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/PMC_extracts.tar.gz'), 508 | Source('direct', 'http://eaidata.bmk.sh/data/PMC_extracts.tar.gz'), 509 | ]) 510 | 511 | def documents(self): 512 | self._download() 513 | 514 | return dummy_meta(map(strip_markdown_colons, lmd.Reader('components/pubmedcentral/PMC_extracts.tar.gz').stream_data())) 515 | 516 | def clean(self): 517 | rm_if_exists('components/pubmedcentral') 518 | 519 | def size(self): 520 | return 96929951580 521 | 522 | def num_docs(self): 523 | return 3098931 524 | 525 | 526 | class CZICDataset(Dataset): 527 | def name(self): 528 | return "CZIC" 529 | 530 | def _download(self): 531 | # todo: convert CZIC 532 | if not os.path.exists('components/czic'): 533 | sh(""" 534 | mkdir -p components/czic 535 | cd components/czic 536 | virtualenv env 537 | . 
env/bin/activate 538 | pip install gdown 539 | gdown https://drive.google.com/uc?id=1qjZZTqS-m63TMKBYB1eNRc5Bh4W--SYQ 540 | """) 541 | sha256sum('components/czic/GOVINFO_CZIC_KL.jsonl.zst', 'c7a46f5af12789fc8b2105b208e22fa400c63ac720c72073e90ee91af6744f00') 542 | 543 | def documents(self): 544 | self._download() 545 | 546 | return lmd.Reader('components/czic/GOVINFO_CZIC_KL.jsonl.zst').stream_data(get_meta=True) 547 | 548 | def clean(self): 549 | rm_if_exists('components/czic') 550 | 551 | def size(self): 552 | return 837798818 553 | 554 | def num_docs(self): 555 | return 4774 556 | 557 | 558 | class PhilPapersDataset(Dataset): 559 | def name(self): 560 | return "PhilPapers" 561 | 562 | def _download(self): 563 | download('components/philpapers/PhilArchive.jsonl.zst', 'e90529b9b3961328d1e34b60534a8e0f73d5ad1f104e22a217de53cd53c41fea', [ 564 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/PhilArchive.jsonl.zst'), 565 | Source('gdrive', 'https://drive.google.com/uc?id=1u01vkBNAS8jtu0AZeQW56bzf-6QbeSRB'), 566 | ]) 567 | 568 | def documents(self): 569 | self._download() 570 | 571 | return lmd.Reader('components/philpapers/PhilArchive.jsonl.zst').stream_data(get_meta=True) 572 | 573 | def clean(self): 574 | rm_if_exists('components/philpapers') 575 | 576 | def size(self): 577 | return 2553543227 578 | 579 | def num_docs(self): 580 | return 33990 581 | 582 | 583 | class USPTODataset(Dataset): 584 | def name(self): 585 | return "USPTO" 586 | 587 | def _download(self): 588 | download('components/uspto/pile_uspto.jsonl.zst.tar', '7a7d2c8e21df2ad0324810a8e675f4d8bdc5ee40b17914a6c0542ddfda1560fd', [ 589 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/pile_uspto.tar'), 590 | Source('direct', 'http://eaidata.bmk.sh/data/pile_uspto.tar'), 591 | ]) 592 | 593 | def documents(self): 594 | self._download() 595 | 596 | return lmd.Reader('components/uspto/pile_uspto.jsonl.zst.tar').stream_data(get_meta=True) 597 | 598 | def clean(self): 599 | rm_if_exists('components/uspto') 600 | 601 | def size(self): 602 | return 24593538339 603 | 604 | def num_docs(self): 605 | return 5883037 606 | 607 | 608 | class EuroParlDataset(Dataset): 609 | def name(self): 610 | return "EuroParl" 611 | 612 | def _download(self): 613 | download('components/europarl/EuroParliamentProceedings_1996_2011.jsonl.zst', '6111400e7b7f75ce91fed1b5fc0a3630b8263217bd01ce75f7d8701f26ac0e98', [ 614 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/EuroParliamentProceedings_1996_2011.jsonl.zst'), 615 | Source('gdrive', 'https://drive.google.com/uc?id=12Q23Y7IKQyjF28xH0Aw6yZaYEx2YIOiB'), 616 | ]) 617 | 618 | def documents(self): 619 | self._download() 620 | 621 | return lmd.Reader('components/europarl/EuroParliamentProceedings_1996_2011.jsonl.zst').stream_data(get_meta=True) 622 | 623 | def clean(self): 624 | rm_if_exists('components/europarl') 625 | 626 | def size(self): 627 | return 4923130035 628 | 629 | def num_docs(self): 630 | return 69814 631 | 632 | 633 | class YTSubtitlesDataset(Dataset): 634 | def name(self): 635 | return "YoutubeSubtitles" 636 | 637 | def _download(self): 638 | download('components/youtubesubtitles/yt_subs.jsonl.zst', '0b9130b8c92290eba360337fea90c2617721f65d955f785f8755cb5f4a8e319c', [ 639 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/yt_subs.jsonl.zst'), 640 | Source('direct', 'http://eaidata.bmk.sh/data/yt_subs.jsonl.zst'), 641 | ]) 642 | 643 | def documents(self): 644 | self._download() 645 | 646 | 
return lmd.Reader('components/youtubesubtitles/yt_subs.jsonl.zst').stream_data(get_meta=True) 647 | 648 | def clean(self): 649 | rm_if_exists('components/youtubesubtitles') 650 | 651 | def size(self): 652 | return 4010420381 653 | 654 | def num_docs(self): 655 | return 173651 656 | 657 | 658 | class HackerNewsDataset(Dataset): 659 | def name(self): 660 | return "HackerNews" 661 | 662 | def _download(self): 663 | download('components/hackernews/hn.jsonl.zst', '9fbc978c92a466b1653cd578700eeb8b417ddcf8c66c7c468d5c1d11ef82aed7', [ 664 | Source('direct', 'http://eaidata.bmk.sh/data/hn.jsonl.zst'), 665 | ]) 666 | 667 | def documents(self): 668 | self._download() 669 | 670 | return lmd.Reader('components/hackernews/hn.jsonl.zst').stream_data(get_meta=True) 671 | 672 | def clean(self): 673 | rm_if_exists('components/hackernews') 674 | 675 | def size(self): 676 | return 4185091916 677 | 678 | def num_docs(self): 679 | return 831198 680 | 681 | 682 | class FullGithubDataset(Dataset): 683 | def name(self): 684 | return "Github" 685 | 686 | def _download(self): 687 | download('components/github/github.jsonl.zst.tar', 'f7a66e8226baf075a42628d10d8eba234460da73b0ffd300736036db9be3b3c3', [ 688 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/github.tar'), 689 | Source('direct', 'http://eaidata.bmk.sh/data/github.tar'), 690 | ]) 691 | 692 | def documents(self): 693 | self._download() 694 | 695 | return filter(lambda x: len(x[0]) < 100000, lmd.Reader('components/github/github.jsonl.zst.tar').stream_data(get_meta=True)) 696 | 697 | def clean(self): 698 | rm_if_exists('components/github') 699 | 700 | def size(self): 701 | return 677143668214 702 | 703 | def num_docs(self): 704 | return 56626342 705 | 706 | 707 | class GithubDataset(Dataset): 708 | def name(self): 709 | return "Github" 710 | 711 | def _download(self): 712 | download('components/github/github_small.jsonl.zst', '4323250bed817466de868f752b7685350123cff1f1363e87dfb6f22585b97f96', [ 713 | Source('direct', 'http://eaidata.bmk.sh/data/github_small.jsonl.zst'), 714 | ]) 715 | 716 | def documents(self): 717 | self._download() 718 | 719 | return lmd.Reader('components/github/github_small.jsonl.zst').stream_data(get_meta=True) 720 | 721 | def clean(self): 722 | rm_if_exists('components/github') 723 | 724 | def size(self): 725 | return 102180233200 726 | 727 | def num_docs(self): 728 | return 19021454 729 | 730 | 731 | class OpenWebText2Dataset(Dataset): 732 | def name(self): 733 | return "OpenWebText2" 734 | 735 | def _download(self): 736 | download('components/openwebtext2/openwebtext2.jsonl.zst.tar', '9043d1b93c35ff1a38a17e16c73c009d4617dcaab6da15adc0faf4779739a027', [ 737 | Source('direct', 'https://the-eye.eu/public/AI/pile_preliminary_components/openwebtext2.jsonl.zst.tar'), 738 | Source('direct', 'http://eaidata.bmk.sh/data/openwebtext2.jsonl.zst.tar'), 739 | ]) 740 | 741 | def documents(self): 742 | self._download() 743 | 744 | return map(lambda x: (remove_advertisement(x[0]), x[1]), lmd.Reader('components/openwebtext2/openwebtext2.jsonl.zst.tar').stream_data(get_meta=True)) 745 | 746 | def clean(self): 747 | rm_if_exists('components/openwebtext2') 748 | 749 | def size(self): 750 | return 67396380547 751 | 752 | def num_docs(self): 753 | return 17103059 754 | 755 | 756 | class CommonCrawlDataset(Dataset): 757 | def name(self): 758 | return "CommonCrawl" 759 | 760 | def _download(self): 761 | download('components/commoncrawl/pile_cc_filtered_deduped.jsonl.zst', 
'4906a6731a7d2d9182c40a13d681078ed537508cf75b1d32ad7f7c491b2f272a', [ 762 | Source('direct', 'http://eaidata.bmk.sh/data/pile_cc_filtered_deduped.jsonl.zst'), 763 | ]) 764 | 765 | def documents(self): 766 | self._download() 767 | 768 | return lmd.Reader('components/commoncrawl/pile_cc_filtered_deduped.jsonl.zst').stream_data(get_meta=True) 769 | 770 | def clean(self): 771 | rm_if_exists('components/commoncrawl') 772 | 773 | def size(self): 774 | return 243872121726 775 | 776 | def num_docs(self): 777 | return 54953117 778 | -------------------------------------------------------------------------------- /the_pile/pile.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | import fasttext 4 | 5 | from pytablewriter import MarkdownTableWriter 6 | from tqdm import tqdm 7 | 8 | from the_pile.utils import humanbytes, parse_size 9 | from the_pile.datasets import * 10 | 11 | 12 | datasets = [ 13 | # Academic 14 | (PubMedCentralDataset(), 2. ), 15 | (ArXivDataset() , 2. ), 16 | (FreeLawDataset() , 1.5 ), 17 | (USPTODataset() , 2. ), 18 | (PubMedDataset() , 2. ), 19 | (PhilPapersDataset() , 2. ), 20 | (ExPorterDataset() , 2. ), 21 | 22 | # General internet 23 | (OpenWebText2Dataset() , 2. ), 24 | (StackExchangeDataset(), 2. ), 25 | (WikipediaDataset() , 3. ), 26 | 27 | # Prose 28 | (BibliotikDataset() , 1.5 ), 29 | (GutenbergDataset() , 2.5 ), 30 | (BookCorpusDataset() , 1.5 ), 31 | 32 | # Github 33 | (GithubDataset() , 1. ), 34 | 35 | # Dialogue 36 | (UbuntuIRCDataset() , 2. ), 37 | (HackerNewsDataset() , 2. ), 38 | (EuroParlDataset() , 2. ), 39 | (YTSubtitlesDataset() , 2. ), 40 | (OpensubtitlesDataset(), 1.5 ), 41 | 42 | # Misc 43 | (DMMathDataset() , 2. ), 44 | (EnronEmailsDataset() , 2. ), 45 | 46 | ] 47 | 48 | 49 | def take(n, iter): 50 | ret = [] 51 | for i in range(n): 52 | try: 53 | ret.append(next(iter)) 54 | except StopIteration: 55 | break 56 | return ret 57 | 58 | def mk_table(datasets, train_chars, print_latex=False): 59 | values = [] 60 | 61 | total_weight = sum([x[1] * x[0].size() for x in datasets]) 62 | 63 | for dataset, weight in datasets: 64 | size = dataset.size() 65 | relative_weight = size * weight / total_weight 66 | values.append([dataset.name(), size, '{:.2%}'.format(relative_weight), '{:.4f}'.format(train_chars / size * relative_weight), size * weight, humanbytes(size / dataset.num_docs(), 'KiB')]) 67 | 68 | values.sort(key=lambda x: -x[4]) 69 | values.append(['**Total**', "", "", "", sum([x[4] for x in values]), humanbytes(sum([x[1] for x in values]) / sum(x[0].num_docs() for x in datasets), 'KiB')]) 70 | values = [[x[0], humanbytes(x[1], 'GiB') if x[1] else "", x[2], x[3], humanbytes(x[4], 'GiB'), x[5]] for x in values] 71 | 72 | writer = MarkdownTableWriter() 73 | writer.table_name = "The Pile™" 74 | writer.headers = ["Component", "Raw Size", "Weight", "Epochs", "Effective Size", "Mean Document Size"] 75 | writer.value_matrix = values 76 | 77 | if print_latex: 78 | rows = [] 79 | for row in values[:-1]: 80 | rows.append(" " + " & ".join(map(lambda x: str(x).replace('%', r'\%'), row)) + r" \\") 81 | totalrow = " & ".join(map(lambda x: r'\textbf{%s}' % str(x).replace('%', r'\%') if x else "", values[-1][1:])) + r" \\" 82 | latex = r""" 83 | \begin{table*}[t!] 
84 | \centering 85 | \begin{tabular}{l r r r r r} 86 | \toprule 87 | \textbf{Component} & \textbf{Raw Size} & \textbf{Weight} & \textbf{Copies} & \textbf{Effective Size} & \textbf{Mean Document Size} \\ 88 | \midrule 89 | """ + "\n".join(rows) + r""" 90 | \midrule 91 | \textbf{The Pile} & """ + totalrow + r""" 92 | \bottomrule 93 | \end{tabular} 94 | \caption{Overview of datasets in \textit{The Pile} before deduplication. The Pile is distributed with a predefined up/down-sampling of the different constituent datasets.} 95 | \label{table:pile_overview} 96 | \end{table*} 97 | """ 98 | print(latex) 99 | return writer.dumps() 100 | 101 | 102 | def dataset_tqdm(dset): 103 | if isinstance(dset, PileReplication): 104 | return dset.documents() 105 | pbar = tqdm(total=dset.size(), unit='B', unit_scale=True, unit_divisor=1024) 106 | for doc in dset.documents(): 107 | pbar.update(utf8len(doc)) 108 | yield doc 109 | 110 | 111 | class Profiler: 112 | def __init__(self, profile): 113 | self.i = 0 114 | self.profile = profile 115 | self.time_per_dataset = collections.defaultdict(lambda: [0, 0]) 116 | 117 | def measured_next(self, name, iter): 118 | if not self.profile: 119 | # no-op 120 | return next(iter) 121 | else: 122 | self.i += 1 123 | start = time.time() 124 | doc = next(iter) 125 | elapsed = time.time() - start 126 | 127 | self.time_per_dataset[name][0] += elapsed 128 | self.time_per_dataset[name][1] += 1 129 | 130 | if self.i % 100000 == 0: 131 | times = [(dsname, total, ct) for dsname, (total, ct) in self.time_per_dataset.items()] 132 | times.sort(key=lambda x: x[1]) 133 | for name, total, ct in times: 134 | print(name.ljust(22), '{:.8f}'.format(total / ct), str(ct).rjust(8), '{:.4f}'.format(total)) 135 | 136 | return doc 137 | 138 | 139 | class PileReplication(Dataset): 140 | def __init__(self, datasets, dataset_bytes, profile=False): 141 | self.datasets = datasets 142 | self.dataset_bytes = dataset_bytes 143 | self.profile = profile 144 | self.rnd = random.Random(42) 145 | 146 | def name(self): 147 | return "Custom Pile" 148 | 149 | def documents(self): 150 | datasets = [] 151 | weights = [] 152 | 153 | # calculate relative_weight for each 154 | total_weight = sum([x[1] * x[0].num_docs() for x in self.datasets]) 155 | for dataset, weight in self.datasets: 156 | size = dataset.size() 157 | relative_weight = weight * dataset.num_docs() / total_weight 158 | datasets.append((dataset.name(), cycle_documents(dataset))) 159 | weights.append(relative_weight) 160 | 161 | # yield from dataset until right number of bytes 162 | total_bytes = 0 163 | pbar = tqdm(total=self.dataset_bytes, unit='B', unit_scale=True, unit_divisor=1024) 164 | 165 | 166 | profiler = Profiler(profile=self.profile) 167 | while True: 168 | chunk = self.rnd.choices(population=datasets, weights=weights, k=1000) 169 | for name, dset in chunk: 170 | doc, meta = profiler.measured_next(name, dset) 171 | 172 | size = utf8len(doc) 173 | total_bytes += size 174 | pbar.update(size) 175 | 176 | meta['pile_set_name'] = name 177 | 178 | yield doc, meta 179 | 180 | if total_bytes > self.dataset_bytes: 181 | return 182 | 183 | def clean(self): 184 | for dataset, _ in self.datasets: dataset.clean() 185 | 186 | def size(self): 187 | return self.dataset_bytes 188 | 189 | 190 | class ThePile(Dataset): 191 | def name(self): 192 | return "The Pile" 193 | 194 | def _download(self): 195 | # TODO: host final pile 196 | pass 197 | 198 | def documents(self): 199 | self._download() 200 | 201 | return lmd.Reader('pile_output').stream_data(get_meta=True) 202 
| 203 | def clean(self): 204 | rm_if_exists('pile_output') 205 | 206 | def size(self): 207 | return 1200 * 1024 * 1024 * 1024 208 | 209 | class LimitedDataset(Dataset): 210 | def __init__(self, dataset, limit_size): 211 | self.dataset = dataset 212 | self.limit_size = limit_size 213 | self.rnd = random.Random(42) 214 | 215 | def name(self): 216 | return self.dataset.name() + " (truncated)" 217 | 218 | def documents(self): 219 | numer = self.limit_size 220 | denom = self.dataset.size() 221 | for doc, meta in dataset_tqdm(self.dataset): 222 | docsize = utf8len(doc) 223 | if self.rnd.random() < numer / denom: 224 | yield doc, meta 225 | numer -= docsize 226 | denom -= docsize 227 | 228 | if numer <= 0 or denom <= 0: 229 | break 230 | 231 | def clean(self): 232 | self.dataset.clean() 233 | 234 | def size(self): 235 | return self.limit_size 236 | 237 | 238 | def preprocess_for_fasttext(x): 239 | return x.replace('\n', ' ').replace('\r', ' ')[:4000][-1500:] 240 | 241 | 242 | import collections 243 | import argparse 244 | import json 245 | 246 | def make_fasttext(pile, keep_frac): 247 | with open('fasttext_pile.txt', 'w') as fh, open('pile_sample.txt', 'w') as fh2: 248 | for x, _ in pile: 249 | if random.random() < keep_frac: 250 | p = preprocess_for_fasttext(x) 251 | if len(p) > 100: 252 | fh.write('__label__pile ' + p + '\n') 253 | if random.random() < 0.001: 254 | fh2.write(x + '<|endoftext|>\n') 255 | 256 | def lang_stats(pile): 257 | download_file('https://dl.fbaipublicfiles.com/fasttext/supervised-models/lid.176.bin', 'lid.176.bin', '7e69ec5451bc261cc7844e49e4792a85d7f09c06789ec800fc4a44aec362764e') 258 | 259 | langdet = fasttext.load_model("lid.176.bin") 260 | langs = collections.defaultdict(lambda: collections.defaultdict(int)) 261 | for i, (data, meta) in enumerate(pile.documents()): 262 | details = langdet.predict(data.replace('\n', ' ')[:3000], k=1) 263 | 264 | langs[meta['pile_set_name']][details[0][0].replace('__label__', '')] += 1 265 | 266 | if (i+1) % 100000 == 0: 267 | for name, x in langs.items(): 268 | print('========= {} =========='.format(name)) 269 | print('\n'.join([k + ',' + str(v / sum(x.values())).ljust(9) for k,v in sorted(list(x.items()), key=lambda x: -x[1])])) 270 | 271 | ob = { 272 | 'langs': langs 273 | } 274 | with open('language_stats.json', 'w') as fout: 275 | fout.write(json.dumps(ob)) 276 | 277 | 278 | def sample_from_sets(datasets, n_docs): 279 | random.seed(42) 280 | for dset, _ in datasets: 281 | print(dset.name()) 282 | fname = 'dataset_samples/{}.json'.format(dset.name().replace(' ', '_')) 283 | if os.path.exists(fname): continue 284 | 285 | n = dset.num_docs() 286 | 287 | indices = set(random.sample(range(n), n_docs)) 288 | pbar = tqdm(total=n_docs) 289 | 290 | docs = [] 291 | for i, (doc, meta) in enumerate(dset.documents()): 292 | if i > max(indices): break 293 | if i in indices: 294 | docs.append((doc, meta)) 295 | pbar.update(1) 296 | 297 | try: 298 | os.mkdir('dataset_samples') 299 | except: 300 | pass 301 | 302 | with open(fname, 'w') as fh: 303 | json.dump(docs, fh) 304 | 305 | pbar.close() 306 | 307 | 308 | def docs_for_dedupe(): 309 | # format: ((priority, offset, sha256sum), document) 310 | dset = CommonCrawlDataset() 311 | for i, (doc, _) in dset.documents(): 312 | yield (100, i, sha256str(doc)), doc 313 | 314 | 315 | if __name__ == '__main__': 316 | 317 | parser = argparse.ArgumentParser(description='Process some integers.') 318 | parser.add_argument('--force_download', action='store_true', help='force download all') 319 | 
parser.add_argument('--limit', type=str, help='limit output size - this option causes read_amount tokens to be generated and then limit tokens to be sampled') 320 | parser.add_argument('--using', type=str, default='pile', help='the dataset to use') 321 | parser.add_argument('--chunk', type=str, help='output chunk size (for make_lmd)') 322 | parser.add_argument('--interleave_output', type=int, help='output interleaved chunks (for make_lmd)') 323 | parser.add_argument('--make_dummy', action='store_true', help='dummy consumer') 324 | parser.add_argument('--make_lmd', action='store_true', help='generate lm_dataformat') 325 | parser.add_argument('--make_fasttext', action='store_true', help='make data for fasttext') 326 | parser.add_argument('--make_lang_analysis', action='store_true', help='make language analysis data') 327 | parser.add_argument('--make_dataset_samples', type=int, help='make dataset sample data') 328 | parser.add_argument('--profile', action='store_true', help='turn on profiler') 329 | parser.add_argument('--read_amount', type=str, help='the size of the data read from the set') 330 | 331 | args = parser.parse_args() 332 | random.seed(42) 333 | 334 | if args.using != 'pile_reprod_no_cc': 335 | # add CC 336 | datasets.append((CommonCrawlDataset(), 1.)) 337 | 338 | if args.read_amount is None: 339 | args.read_amount = sum([ds.size() * epochs for ds, epochs in datasets]) 340 | else: 341 | args.read_amount = parse_size(args.read_amount) 342 | 343 | print(mk_table(datasets, args.read_amount)) 344 | 345 | if args.using == 'pile_reprod' or args.using == 'pile_reprod_no_cc': 346 | pile = PileReplication(datasets, args.read_amount, profile=args.profile) 347 | elif args.using == 'cc': 348 | pile = CommonCrawlDataset() 349 | elif args.using == 'pile': 350 | pile = ThePile() 351 | elif args.using == 'owt2': 352 | pile = OpenWebText2Dataset() 353 | elif args.using == 'bibliotik': 354 | pile = BibliotikDataset() 355 | else: 356 | print('We don\'t have a shortcut for that yet!') 357 | 358 | if args.force_download: 359 | for dset, _ in datasets: 360 | dset._download() 361 | 362 | if args.limit: 363 | size_limit = parse_size(args.limit) 364 | pile = LimitedDataset(pile, size_limit) 365 | 366 | if args.make_lmd: 367 | assert not (args.interleave_output and args.chunk) # can't chunk and interleave 368 | 369 | if args.interleave_output: 370 | ars = [lmd.Archive('pile_pass1/chunk{}'.format(i)) for i in range(args.interleave_output)] 371 | else: 372 | ar = lmd.Archive('pile_output') 373 | 374 | if args.chunk: 375 | chunk_size = parse_size(args.chunk) 376 | 377 | cursize = 0 378 | for doc, meta in pile.documents(): 379 | if args.interleave_output: 380 | ar = random.choice(ars) 381 | 382 | ar.add_data(doc, meta) 383 | 384 | cursize += len(doc) 385 | if args.chunk and cursize > chunk_size: 386 | # interleave will not be on 387 | cursize = 0 388 | ar.commit(archive_name=args.using) 389 | 390 | if args.interleave_output: 391 | for ar in ars: ar.commit(archive_name=args.using) 392 | else: 393 | ar.commit(archive_name=args.using) 394 | 395 | if args.make_fasttext: 396 | make_fasttext(pile.documents(), 0.1) 397 | 398 | if args.make_dataset_samples: 399 | sample_from_sets(datasets, args.make_dataset_samples) 400 | 401 | if args.make_lang_analysis: 402 | lang_stats(pile) 403 | 404 | if args.make_dummy: 405 | pbar = tqdm(total=pile.size(), unit="B", unit_scale=1) 406 | for doc, meta in pile.documents(): 407 | pbar.update(utf8len(doc)) 408 | pbar.close() 
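The mixing logic above (`PileReplication.documents`) cycles every constituent dataset forever and samples between them with a fixed-seed weighted choice (weight × num_docs) until a target byte budget is reached. The following is a minimal, self-contained sketch of that idea only — it is not code from this repo, and `toy_mix` and its toy inputs are hypothetical stand-ins:

```
import itertools
import random

def toy_mix(named_docs, weights, byte_budget, seed=42):
    # Hypothetical stand-in for the weighted-mixing loop: cycle each dataset
    # endlessly and draw between them with a fixed-seed weighted choice until
    # roughly byte_budget bytes of UTF-8 text have been yielded.
    rnd = random.Random(seed)
    pools = [(name, itertools.cycle(docs)) for name, docs in named_docs]
    total = 0
    while total < byte_budget:
        # the real code draws k=1000 choices per chunk; k=10 keeps the toy small
        for name, pool in rnd.choices(population=pools, weights=weights, k=10):
            doc = next(pool)
            total += len(doc.encode('utf-8'))   # same accounting as utf8len()
            yield doc, {'pile_set_name': name}
            if total >= byte_budget:
                return

if __name__ == '__main__':
    toy_sets = [('A', ['aaaa']), ('B', ['bb'])]
    for doc, meta in toy_mix(toy_sets, weights=[0.75, 0.25], byte_budget=40):
        print(meta['pile_set_name'], doc)
```

The actual implementation additionally tags each document with `pile_set_name`, tracks per-dataset timing via `Profiler`, and reports progress with tqdm.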
-------------------------------------------------------------------------------- /the_pile/tfds_pile.py: -------------------------------------------------------------------------------- 1 | """the_pile dataset""" 2 | 3 | import tensorflow_datasets as tfds 4 | import tensorflow as tf 5 | import io 6 | import zstandard 7 | import jsonlines 8 | import os 9 | 10 | """ 11 | Tips for Colab - Change _PILE_SPLITS below to increments of 8 to allow downloading and storing in GCS 12 | After every 8 parts, tfds will flush the tempfiles from local and it will be cached on GCS, allowing reuse 13 | preventing th need to redownload again. Example below 14 | 15 | _download: Skipping download of http://eaidata.bmk.sh/data/pile/train/26.jsonl.zst: File cached in gs://your_bucket/datasets/cached/downloads/eaidata.bmk.sh_pile_train_26.jsonlCue2aNl9cxodxAvl9vIacuexGWYSoJAt4Rpcy19pqds.zst 16 | _download: Skipping download of http://eaidata.bmk.sh/data/pile/train/27.jsonl.zst: File cached in gs://your_bucket/datasets/cached/downloads/eaidata.bmk.sh_pile_train_27.jsonlt8W_PLYeC4bZeaNMqMhe0-lhS3ijPL7RjvILWsMZlhQ.zst 17 | _download: Downloading http://eaidata.bmk.sh/data/pile/train/28.jsonl.zst into gs://your_bucket/datasets/cached/downloads/eaidata.bmk.sh_pile_train_28.jsonl7Fj9nvI6std-e0H2ScxDKMpTWEC8iJMI8OT2vxLw2A4.zst.tmp.576c9ac11d30419b8ea8f30a5157ee53... 18 | _download: Downloading http://eaidata.bmk.sh/data/pile/train/29.jsonl.zst into gs://your_bucket/datasets/cached/downloads/eaidata.bmk.sh_pile_train_29.jsonl1syFpl-ESnwk__9_6Xrj_OO5mRxpmaxQG7bZ_5d2sZc.zst.tmp.2f7f6afb86d74e988dcdb71d59b0d3f2... 19 | 20 | 21 | Use tfds.disable_progress_bar() to prevent javascript issues 22 | This uses pysimdjson for faster parsing of json. The entire dataset should be completed in around 12 hours on Colab. 23 | 24 | """ 25 | 26 | _USAGE_EXAMPLE = """ 27 | This can be run in a script or in a notebook. 28 | 29 | _GCS_BUCKET = 'gs://your_gcs_bucket/path' 30 | 31 | import os 32 | os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = '/path/to/adc.json' # if building to store in GCS 33 | os.environ['TFDS_DATA_DIR'] = _GCS_BUCKET 34 | 35 | import tensorflow_datasets as tfds 36 | from the_pile import tfds_pile 37 | from transformers import GPT2TokenizerFast 38 | 39 | tokenizer = GPT2TokenizerFast.from_pretrained('gpt2') 40 | tokenizer.add_special_tokens({'pad_token': '<|padding|>'}) 41 | 42 | def simple_tokenization(item): 43 | return tokenizer.encode(item['text'], return_tensors='tf') 44 | 45 | tfds.disable_progress_bar() # optional - will help with colab since tqdm breaks often 46 | 47 | ds = tfds.load(name="ThePile", try_gcs=True) 48 | 49 | # Have not tested below 50 | ds.map(simple_tokenization, num_parallel_calls=tf.data.experimental.AUTOTUNE) 51 | # or 52 | ds.map(lambda item: simple_tokenization(item), num_parallel_calls=tf.data.experimental.AUTOTUNE) 53 | 54 | """ 55 | 56 | try: 57 | import simdjson as json 58 | except ImportError: 59 | print('Installing simdjson library') 60 | os.system('pip install -q pysimdjson') 61 | import simdjson as json 62 | parser = json.Parser() 63 | 64 | 65 | _DESCRIPTION = """ 66 | The Pile is a large, diverse, open source language modelling data set 67 | that consists of many smaller datasets combined together. 68 | The objective is to obtain text from as many modalities as possible to 69 | ensure that models trained using The Pile will have much broader generalization abilities. 70 | We are currently developing Version 1, with an ultimate goal of 1 TiB of English text. 
71 | After the completion of Version 1, our next goal is a fully-multilingual, 10TiB text dataset. 72 | """ 73 | 74 | _CITATION = """ 75 | """ 76 | _DATASET_MODES = ["lm"] 77 | 78 | _PILE_URL = 'http://eaidata.bmk.sh/data/pile/train/{}.jsonl.zst' 79 | _PILE_SPLITS = 30 80 | 81 | _URLS = { 82 | 'the_pile': { 83 | 'train': [_PILE_URL.format(str(i).zfill(2)) for i in range(_PILE_SPLITS)], 84 | 'test': 'http://eaidata.bmk.sh/data/pile/test.jsonl.zst', 85 | 'validation': 'http://eaidata.bmk.sh/data/pile/val.jsonl.zst', 86 | } 87 | } 88 | 89 | 90 | _VERSION = tfds.core.Version('1.0.0') 91 | _RELEASE_NOTES = { 92 | '1.0.0': 'Initial release.', 93 | } 94 | 95 | _NAME = 'the_pile' 96 | _FILE_FORMAT = 'jsonlines' 97 | 98 | def json_parser(x): 99 | try: 100 | line = parser.parse(x).as_dict() 101 | return line 102 | except ValueError: 103 | return x 104 | 105 | class PileReader: 106 | def __init__(self, filenames, para_joiner='\n\n'): 107 | if not isinstance(filenames, list): 108 | filenames = [filenames] 109 | self.filenames = filenames 110 | self.para_joiner = para_joiner 111 | 112 | def _read_fn(self, filename): 113 | with tf.io.gfile.GFile(filename, 'rb+') as f: 114 | cctx = zstandard.ZstdDecompressor() 115 | reader_stream = io.BufferedReader(cctx.stream_reader(f)) 116 | reader = jsonlines.Reader(reader_stream, loads=json_parser) 117 | for item in reader: 118 | result = dict() 119 | if isinstance(item, str): 120 | result['text'] = item 121 | else: 122 | text = item['text'] 123 | if isinstance(text, list): 124 | text = self.para_joiner.join(text) 125 | result['text'] = text 126 | yield result 127 | 128 | def __iter__(self): 129 | for filename in self.filenames: 130 | return self._read_fn(filename) 131 | 132 | 133 | class ThePileConfig(tfds.core.BuilderConfig): 134 | def __init__(self, *, mode=None, **kwargs): 135 | super(ThePileConfig, self).__init__( 136 | name=mode, 137 | description="The Pile dataset", 138 | **kwargs) 139 | 140 | class ThePile(tfds.core.GeneratorBasedBuilder): 141 | BUILDER_CONFIGS = [ 142 | ThePileConfig(version=_VERSION, mode=mode) for mode in _DATASET_MODES 143 | ] 144 | def _info(self) -> tfds.core.DatasetInfo: 145 | return tfds.core.DatasetInfo( 146 | builder=self, 147 | description=_DESCRIPTION, 148 | features=tfds.features.FeaturesDict({ 149 | 'text': tfds.features.Text() 150 | }), 151 | supervised_keys=("text", "text"), 152 | homepage='https://github.com/EleutherAI/The-Pile', 153 | citation=_CITATION, 154 | ) 155 | 156 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 157 | dl_manager.verify_ssl = False 158 | dl_paths = dl_manager.download(_URLS['the_pile']) 159 | return [ 160 | tfds.core.SplitGenerator( 161 | name=tfds.Split.TRAIN, 162 | gen_kwargs={"paths": dl_paths['train']}), 163 | tfds.core.SplitGenerator( 164 | name=tfds.Split.VALIDATION, 165 | gen_kwargs={"paths": dl_paths['validation']}), 166 | tfds.core.SplitGenerator( 167 | name=tfds.Split.TEST, 168 | gen_kwargs={"paths": dl_paths['test']}), 169 | ] 170 | 171 | def _generate_examples(self, paths): 172 | pipeline = PileReader(paths) 173 | for x, result in enumerate(pipeline): 174 | if result: 175 | idx = f'{x}_the_pile' 176 | yield idx, {'text': result['text']} -------------------------------------------------------------------------------- /the_pile/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import hashlib 4 | from functools import reduce 5 | import operator 6 | import collections 7 | from pathlib import Path 
8 | import tarfile 9 | import shutil 10 | 11 | import gdown 12 | from tqdm import tqdm 13 | from best_download import download_file 14 | 15 | def touch(x): 16 | Path(x).touch() 17 | 18 | Source = collections.namedtuple('Source', ['type', 'url']) 19 | 20 | def download(fname, checksum, sources, extract=False): 21 | if os.path.exists(fname + '.done'): return 22 | 23 | print('Finding source for', fname) 24 | 25 | parentdir = Path(fname).parent 26 | os.makedirs(parentdir, exist_ok=True) 27 | 28 | for source in sources: 29 | try: 30 | # todo: implement torrent handling 31 | if source.type == 'direct': 32 | download_file(source.url, fname, checksum) 33 | elif source.type == 'gdrive': 34 | if os.path.exists(fname): 35 | try: 36 | print(fname, 'already exists.') 37 | sha256sum(fname, expected=checksum) 38 | touch(fname + '.done') 39 | return 40 | except AssertionError: 41 | print('{} exists but doesn\'t match checksum!'.format(fname)) 42 | rm_if_exists(fname) 43 | 44 | gdown.download(source.url, fname, quiet=False) 45 | sha256sum(fname, expected=checksum) 46 | elif source.type == 'gcloud': 47 | raise NotImplementedError('gcloud download not implemented!') 48 | 49 | if extract: 50 | tar_xf(fname) 51 | rm_if_exists(fname) 52 | touch(fname + '.done') 53 | return 54 | except SystemExit: 55 | raise 56 | except KeyboardInterrupt: 57 | raise 58 | except: 59 | import traceback 60 | traceback.print_exc() 61 | print('Download method [{}] {} failed, trying next option'.format(source.type, source.url)) 62 | # rm_if_exists(fname) 63 | continue 64 | 65 | break 66 | 67 | raise Exception('Failed to download {} from any source'.format(fname)) 68 | 69 | 70 | def tar_xf(x): 71 | parentdir = Path(x).parent 72 | tf = tarfile.open(x) 73 | tf.extractall(parentdir) 74 | 75 | class ExitCodeError(Exception): pass 76 | 77 | 78 | def stableorder(x): 79 | arr = [(elem, sha256str(elem.encode('utf-8'))) for elem in x] 80 | arr.sort(key=lambda x: x[1]) 81 | return [elem for elem,_ in arr] 82 | 83 | def id(x): 84 | return x 85 | 86 | def utf8len(s): 87 | return len(s.encode('utf-8')) 88 | 89 | def sh(x): 90 | if os.system(x): raise ExitCodeError() 91 | 92 | def fwrite(fname, content): 93 | with open(fname, 'w') as fh: 94 | fh.write(content) 95 | 96 | def fread(fname): 97 | with open(fname) as fh: 98 | return fh.read() 99 | 100 | def ls(x): 101 | return [x + '/' + fn for fn in stableorder(os.listdir(x))] 102 | 103 | 104 | def cycle_documents(dataset): 105 | while True: 106 | yield from filter(id, dataset.documents()) 107 | 108 | def concat(xs): 109 | for x in xs: 110 | yield from x 111 | 112 | 113 | def flatMap(f, x): 114 | return reduce(operator.add, map(f, x), []) 115 | 116 | 117 | def sha256str(s): 118 | h = hashlib.sha256() 119 | h.update(s) 120 | return h.hexdigest() 121 | 122 | def sha256sum(filename, expected=None): 123 | h = hashlib.sha256() 124 | b = bytearray(128*1024) 125 | mv = memoryview(b) 126 | progress = tqdm(total=os.path.getsize(filename), unit="byte", unit_scale=1) 127 | tqdm.write(f"Verifying checksum for {filename}") 128 | with open(filename, 'rb', buffering=0) as f: 129 | for n in iter(lambda : f.readinto(mv), 0): 130 | h.update(mv[:n]) 131 | progress.update(n) 132 | progress.close() 133 | 134 | if expected: 135 | assert h.hexdigest() == expected 136 | print('CHECKSUM OK', filename) 137 | else: 138 | print(filename, h.hexdigest()) 139 | 140 | 141 | def rm_if_exists(path): 142 | try: 143 | if os.path.exists(path): 144 | shutil.rmtree(path) 145 | except NotADirectoryError: 146 | os.remove(path) 147 | 148 | 149 
| # https://stackoverflow.com/questions/12523586/python-format-size-application-converting-b-to-kb-mb-gb-tb/37423778 150 | def humanbytes(B, units=None): 151 | 'Return the given bytes as a human friendly KB, MB, GB, or TB string' 152 | B = float(B) 153 | KB = float(1024) 154 | MB = float(KB ** 2) # 1,048,576 155 | GB = float(KB ** 3) # 1,073,741,824 156 | TB = float(KB ** 4) # 1,099,511,627,776 157 | 158 | if (B < KB and units is None) or units == "B": 159 | return '{0} {1}'.format(B,'Bytes' if 0 == B > 1 else 'Byte') 160 | elif (KB <= B < MB and units is None) or units == "KiB": 161 | return '{0:.2f} KiB'.format(B/KB) 162 | elif (MB <= B < GB and units is None) or units == "MiB": 163 | return '{0:.2f} MiB'.format(B/MB) 164 | elif (GB <= B < TB and units is None) or units == "GiB": 165 | return '{0:.2f} GiB'.format(B/GB) 166 | elif (TB <= B and units is None) or units == "TiB": 167 | return '{0:.2f} TiB'.format(B/TB) 168 | 169 | 170 | def strip_markdown_colons(x): 171 | return re.sub(r'^:::.*?\n', '', x, flags=re.MULTILINE) 172 | 173 | def remove_advertisement(x): 174 | return re.sub(r'^Advertisement\n', '', x, flags=re.MULTILINE) 175 | 176 | 177 | def compose(*fs): 178 | def _f(x): 179 | for f in reversed(fs): 180 | x = f(x) 181 | return x 182 | 183 | return _f 184 | 185 | 186 | def parse_size(sizestr): 187 | unit = sizestr[-1] 188 | size = float(sizestr[:-1]) 189 | 190 | if unit.upper() == 'B': 191 | return size 192 | if unit.upper() == 'K': 193 | return size * 1024 194 | if unit.upper() == 'M': 195 | return size * 1024 * 1024 196 | if unit.upper() == 'G': 197 | return size * 1024 * 1024 * 1024 198 | if unit.upper() == 'T': 199 | return size * 1024 * 1024 * 1024 * 1024 200 | 201 | def dummy_meta(xs): 202 | return ((x, {}) for x in xs) 203 | 204 | def chunk_at_even_lines(it, chunksize): 205 | for doc in it: 206 | totlen = 0 207 | res = [] 208 | for i, line in enumerate(doc.split('\n')): 209 | res.append(line) 210 | totlen += len(line) 211 | 212 | if totlen > chunksize and i % 2 == 1: 213 | yield '\n'.join(res) 214 | totlen = 0 215 | res = [] 216 | if res: yield '\n'.join(res) 217 | 218 | --------------------------------------------------------------------------------
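As a small usage sketch of the helpers above (assuming `the_pile` is importable; the literal values are only illustrative): `parse_size` converts the CLI-style size strings accepted by pile.py's `--limit`, `--read_amount`, and `--chunk` flags into byte counts, and `humanbytes` renders byte counts in the units used elsewhere in this repo.

```
from the_pile.utils import parse_size, humanbytes

if __name__ == '__main__':
    n = parse_size('30G')                  # 30 * 1024**3 == 32212254720.0 bytes
    print(humanbytes(n, 'GiB'))            # '30.00 GiB'
    print(humanbytes(1254.20 * 1024**3))   # units auto-selected -> '1.22 TiB'
```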