├── .gitattributes ├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── README.md ├── alembic.ini.template ├── alembic ├── README ├── env.py └── script.py.mako ├── cleaning ├── __init__.py ├── dedupe_from_indexes.py ├── filter_from_reddit_scores.py ├── generate_minhashes.py ├── minhash_lsh_batching.py └── minhash_lsh_dedupe.py ├── data_analysis ├── __init__.py └── final_stats.py ├── mkdocs ├── docs │ ├── background.md │ ├── css │ │ └── extra.css │ ├── index.md │ ├── licence.md │ └── replication.md └── mkdocs.yml ├── pushshift ├── __init__.py ├── download_pushshift_dumps.py ├── generate_urls.py ├── models.py ├── process_dump_files_sqlite.py └── pushshift_to_sqlite.py ├── requirements.txt ├── scraping ├── __init__.py ├── filter.py ├── scrape_urls.py └── scrapers.py └── utils ├── __init__.py ├── archive_stream_readers.py ├── archiver.py ├── logger.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 
92 | #Pipfile.lock 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # SageMath parsed files 98 | *.sage.py 99 | 100 | # Environments 101 | .env 102 | .venv 103 | env/ 104 | venv/ 105 | ENV/ 106 | env.bak/ 107 | venv.bak/ 108 | 109 | # Spyder project settings 110 | .spyderproject 111 | .spyproject 112 | 113 | # Rope project settings 114 | .ropeproject 115 | 116 | # mkdocs documentation 117 | /site 118 | 119 | # mypy 120 | .mypy_cache/ 121 | .dmypy.json 122 | dmypy.json 123 | 124 | # Pyre type checker 125 | .pyre/ 126 | .vs* 127 | pushshift_dump_processing.sln 128 | pushshift_dump_processing.pyproj 129 | Untitled.ipynb 130 | playing_around* 131 | process_dump_files.txt 132 | url_count.pkl 133 | url_dedupe.txt 134 | duplication_stats.json 135 | csv/score_over_time.csv 136 | alembic.ini 137 | migration/find_missing_metadata_nplus1.py 138 | testing_archive.jsonl.zst 139 | migration* 140 | testa.py 141 | webtext2_colab_old.ipynb 142 | possibly_useful* 143 | data_analysis/document_count_by_stage.py 144 | data_analysis/score_over_time.py 145 | data_analysis/score_over_time_scrapes.py 146 | test_db.py 147 | cleaning/aggregate_archives_by_month.py 148 | mkdocs/site/* 149 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | version: 2 6 | 7 | mkdocs: 8 | configuration: mkdocs/mkdocs.yml 9 | fail_on_warning: false -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OpenWebText2 2 | 3 | This project is part of EleutherAI's quest to create a massive repository of high quality text data for training language models. 4 | 5 | Very briefly, OpenWebText2 is a large filtered dataset of text documents scraped from URLs found in Reddit submissions.
6 | 7 | The plug and play version of OpenWebText2 contains: 8 | - 17,103,059 documents 9 | - 65.86 GB uncompressed text 10 | 11 | ## Download Dataset / Documentation 12 | 13 | For further information, please visit our [documentation](https://openwebtext2.readthedocs.io/en/latest/). A minimal loading sketch is also included at the end of this README. 14 | 15 | ## Acknowledgements 16 | [researcher2](https://github.com/researcher2) wrote much of this code, with inspiration and some straight copying of the scraping code found [here](https://github.com/yet-another-account/openwebtext/).
17 | [sdtblck](https://github.com/sdtblck/) kindly put together the Colab notebook, and performed a chunk of the scraping.
18 | [leogao2](https://github.com/leogao2/) provided overall design guidance, lm_dataformat, and performed another chunk of scraping.
19 | [Colaboratory](https://colab.research.google.com/) VMs helped us with about 10% of our overall scraping.
20 | [The Eye](http://the-eye.eu/) hosts our processed datasets.
21 | [Read The Docs](https://readthedocs.org/) hosts our documentation.
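## Loading Example

A minimal sketch of reading the released archives, assuming the `*.jsonl.zst` files have been extracted into a local directory and this repository's `utils` package is importable (see the [documentation](https://openwebtext2.readthedocs.io/en/latest/) for a fuller example):

```python
import glob
import os

from utils.archiver import Reader  # bundled, lightly modified lm_dataformat reader

dataset_directory = "PATH_TO_FILES"  # wherever you extracted the *.jsonl.zst files

for file_path in glob.glob(os.path.join(dataset_directory, "*jsonl.zst")):
    reader = Reader()
    for document, metadata in reader.read_jsonl(file_path, get_meta=True):
        # metadata includes fields such as the original url and reddit scores
        print(metadata["url"], len(document))
        break  # peek at the first document of each file only
```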
22 | -------------------------------------------------------------------------------- /alembic.ini.template: -------------------------------------------------------------------------------- 1 | # A generic, single database configuration. 2 | 3 | [alembic] 4 | # path to migration scripts 5 | script_location = alembic 6 | 7 | # template used to generate migration files 8 | # file_template = %%(rev)s_%%(slug)s 9 | 10 | # timezone to use when rendering the date 11 | # within the migration file as well as the filename. 12 | # string value is passed to dateutil.tz.gettz() 13 | # leave blank for localtime 14 | # timezone = 15 | 16 | # max length of characters to apply to the 17 | # "slug" field 18 | # truncate_slug_length = 40 19 | 20 | # set to 'true' to run the environment during 21 | # the 'revision' command, regardless of autogenerate 22 | # revision_environment = false 23 | 24 | # set to 'true' to allow .pyc and .pyo files without 25 | # a source .py file to be detected as revisions in the 26 | # versions/ directory 27 | # sourceless = false 28 | 29 | # version location specification; this defaults 30 | # to alembic/versions. When using multiple version 31 | # directories, initial revisions must be specified with --version-path 32 | # version_locations = %(here)s/bar %(here)s/bat alembic/versions 33 | 34 | # the output encoding used when revision files 35 | # are written from script.py.mako 36 | # output_encoding = utf-8 37 | 38 | sqlalchemy.url = sqlite:///e:/Eleuther_AI/webtext2/dumps/submissions.sqlite 39 | 40 | [post_write_hooks] 41 | # post_write_hooks defines scripts or Python functions that are run 42 | # on newly generated revision scripts. See the documentation for further 43 | # detail and examples 44 | 45 | # format using "black" - use the console_scripts runner, against the "black" entrypoint 46 | # hooks=black 47 | # black.type=console_scripts 48 | # black.entrypoint=black 49 | # black.options=-l 79 50 | 51 | # Logging configuration 52 | [loggers] 53 | keys = root,sqlalchemy,alembic 54 | 55 | [handlers] 56 | keys = console 57 | 58 | [formatters] 59 | keys = generic 60 | 61 | [logger_root] 62 | level = WARN 63 | handlers = console 64 | qualname = 65 | 66 | [logger_sqlalchemy] 67 | level = WARN 68 | handlers = 69 | qualname = sqlalchemy.engine 70 | 71 | [logger_alembic] 72 | level = INFO 73 | handlers = 74 | qualname = alembic 75 | 76 | [handler_console] 77 | class = StreamHandler 78 | args = (sys.stderr,) 79 | level = NOTSET 80 | formatter = generic 81 | 82 | [formatter_generic] 83 | format = %(levelname)-5.5s [%(name)s] %(message)s 84 | datefmt = %H:%M:%S 85 | -------------------------------------------------------------------------------- /alembic/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. -------------------------------------------------------------------------------- /alembic/env.py: -------------------------------------------------------------------------------- 1 | from logging.config import fileConfig 2 | 3 | from sqlalchemy import engine_from_config 4 | from sqlalchemy import pool 5 | 6 | from alembic import context 7 | 8 | # this is the Alembic Config object, which provides 9 | # access to the values within the .ini file in use. 10 | config = context.config 11 | 12 | # Interpret the config file for Python logging. 13 | # This line sets up loggers basically. 
14 | fileConfig(config.config_file_name) 15 | 16 | # add your model's MetaData object here 17 | # for 'autogenerate' support 18 | # from myapp import mymodel 19 | # target_metadata = mymodel.Base.metadata 20 | import sys 21 | sys.path = ['', '..'] + sys.path[1:] # ew ew ew ew ew 22 | from pushshift.models import base 23 | target_metadata = base.metadata 24 | 25 | # other values from the config, defined by the needs of env.py, 26 | # can be acquired: 27 | # my_important_option = config.get_main_option("my_important_option") 28 | # ... etc. 29 | 30 | 31 | def run_migrations_offline(): 32 | """Run migrations in 'offline' mode. 33 | 34 | This configures the context with just a URL 35 | and not an Engine, though an Engine is acceptable 36 | here as well. By skipping the Engine creation 37 | we don't even need a DBAPI to be available. 38 | 39 | Calls to context.execute() here emit the given string to the 40 | script output. 41 | 42 | """ 43 | url = config.get_main_option("sqlalchemy.url") 44 | context.configure( 45 | url=url, 46 | target_metadata=target_metadata, 47 | literal_binds=True, 48 | dialect_opts={"paramstyle": "named"}, 49 | ) 50 | 51 | with context.begin_transaction(): 52 | context.run_migrations() 53 | 54 | 55 | def run_migrations_online(): 56 | """Run migrations in 'online' mode. 57 | 58 | In this scenario we need to create an Engine 59 | and associate a connection with the context. 60 | 61 | """ 62 | connectable = engine_from_config( 63 | config.get_section(config.config_ini_section), 64 | prefix="sqlalchemy.", 65 | poolclass=pool.NullPool, 66 | ) 67 | 68 | with connectable.connect() as connection: 69 | context.configure( 70 | connection=connection, target_metadata=target_metadata 71 | ) 72 | 73 | with context.begin_transaction(): 74 | context.run_migrations() 75 | 76 | 77 | if context.is_offline_mode(): 78 | run_migrations_offline() 79 | else: 80 | run_migrations_online() 81 | -------------------------------------------------------------------------------- /alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from alembic import op 9 | import sqlalchemy as sa 10 | ${imports if imports else ""} 11 | 12 | # revision identifiers, used by Alembic. 13 | revision = ${repr(up_revision)} 14 | down_revision = ${repr(down_revision)} 15 | branch_labels = ${repr(branch_labels)} 16 | depends_on = ${repr(depends_on)} 17 | 18 | 19 | def upgrade(): 20 | ${upgrades if upgrades else "pass"} 21 | 22 | 23 | def downgrade(): 24 | ${downgrades if downgrades else "pass"} 25 | -------------------------------------------------------------------------------- /cleaning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/cleaning/__init__.py -------------------------------------------------------------------------------- /cleaning/dedupe_from_indexes.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script builds a list of all duplicates by file_id & document_id, and then iterates 3 | through all ".minscored" files from the filename lookup, creating a new archive for each 4 | file in the original containing all documents that were not marked as duplicates during 5 | the previous step. 
6 | 7 | So for each original file, a "_final.jsonl.zst" files will be output in the original 8 | directory. 9 | 10 | Arguments 11 | ------ 12 | --batch_directory (-dir) 13 | Directory containing the "*duplicates.txt" files along with the "file_name_lookup.pkl" 14 | created during batch slicing. The "_final.jsonl.zst" files will be output in their 15 | original directories. 16 | """ 17 | 18 | import glob 19 | import os 20 | import pickle 21 | import argparse 22 | 23 | import tqdm 24 | 25 | from utils.archiver import Archive, Reader 26 | 27 | import logging 28 | from utils.logger import setup_logger_tqdm 29 | logger = logging.getLogger(__name__) 30 | 31 | def main(batch_directory): 32 | file_name_lookup_path = os.path.join(batch_directory, "file_name_lookup.pkl") 33 | file_name_lookup = pickle.load(open(file_name_lookup_path,"rb")) 34 | 35 | logger.info("Building duplicates dictionary...") 36 | duplicates_dict = {file_id : set() for file_id in range(len(file_name_lookup))} 37 | duplicate_files = glob.glob(os.path.join(batch_directory, "*_duplicates.txt")) 38 | for duplicate_file in duplicate_files: 39 | with open(duplicate_file, "r") as fh: 40 | duplicates = fh.read().splitlines() 41 | for duplicate in duplicates: 42 | file_id, document_id = tuple(map(int, duplicate.split(" "))) 43 | duplicates_dict[file_id].add(document_id) 44 | 45 | logger.info("De-duplicating files...") 46 | for file_id, original_file_name in enumerate(tqdm.tqdm(file_name_lookup)): 47 | final_file_name = original_file_name.replace("_default.jsonl.zst.deduped.merged.minscored", 48 | "_final.jsonl.zst") 49 | 50 | reader = Reader() 51 | count = 0 52 | archiver = Archive(final_file_name) 53 | for document, metadata in reader.read_jsonl(original_file_name, get_meta=True): 54 | if count not in duplicates_dict[file_id]: 55 | archiver.add_data(document, metadata) 56 | count += 1 57 | archiver.commit() 58 | 59 | parser = argparse.ArgumentParser(description='Dedupe from provided indexes.') 60 | parser.add_argument("-dir", "--batch_directory", default="") 61 | 62 | if __name__ == '__main__': 63 | logfile_path = "dedupe_from_index.log" 64 | setup_logger_tqdm(logfile_path) 65 | 66 | args = parser.parse_args() 67 | main(args.batch_directory) -------------------------------------------------------------------------------- /cleaning/filter_from_reddit_scores.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script filters all scrape files "scrapes_*.jsonl.zst" by minimum total Reddit score. 3 | Unlike the original WebText we aggregate scores for all submissions containing a given 4 | URL so the bar is slightly lower in some cases, but in others where a URL went negative 5 | in one submission it will balance out. 6 | 7 | The filtered scrapes file will have the original name and path of the scrape file with a 8 | ".minscored" extension. 9 | 10 | Arguments 11 | --------- 12 | --scrape_directory (-dir) 13 | Directory containing the scrapes. You could use the overall work directory if you 14 | want as we use glob.glob to search recursively. 
15 | """ 16 | 17 | import argparse 18 | import glob 19 | import os 20 | import sys 21 | import math 22 | from functools import reduce 23 | from operator import add 24 | 25 | import tqdm 26 | from tqdm_multiprocess import TqdmMultiProcessPool 27 | 28 | from utils.archiver import Reader, Archive 29 | 30 | import logging 31 | from utils.logger import setup_logger_tqdm 32 | logger = logging.getLogger(__name__) 33 | 34 | million = math.pow(10, 6) 35 | 36 | # Multiprocessed 37 | def process_file(file_path, tqdm_func, global_tqdm): 38 | reader = Reader() 39 | 40 | filtered_archive_path = file_path + ".minscored" 41 | archiver = Archive(filtered_archive_path) 42 | 43 | for document, metadata in reader.read_jsonl(file_path, get_meta=True): 44 | total_score = reduce(add, metadata["reddit_scores"]) 45 | if total_score >= 3: 46 | archiver.add_data(document, metadata) 47 | 48 | global_tqdm.update(os.path.getsize(file_path)) 49 | archiver.commit() 50 | 51 | def filter_from_reddit_scores(scrape_directory): 52 | files = glob.glob(os.path.join(scrape_directory, "**/scrapes_*.jsonl.zst"), recursive=True) 53 | total_file_size = reduce(add, map(os.path.getsize, files)) 54 | logger.info(f"Total File Size: {(total_file_size / million):.2f} MB") 55 | 56 | # [(file_name, [doc0_minhash, doc1_minhash, ...]), ....] 57 | with tqdm.tqdm(total=total_file_size, dynamic_ncols=True, unit_scale=1) as progress: 58 | pool = TqdmMultiProcessPool() 59 | process_count = 4 60 | tasks = [] 61 | for file_path in files: 62 | task = (process_file, (file_path,)) 63 | tasks.append(task) 64 | 65 | on_done = lambda _ : None 66 | on_error = on_done 67 | result = pool.map(process_count, progress, tasks, on_error, on_done) 68 | 69 | return result 70 | 71 | parser_description = 'Filter scrapes based on minimum reddit scores.' 72 | parser = argparse.ArgumentParser(description=parser_description) 73 | parser.add_argument("-dir", "--scrape_directory", default="") 74 | 75 | if __name__ == '__main__': 76 | args = parser.parse_args() 77 | if not os.path.isdir(args.scrape_directory): 78 | print("Scrape directory doesn't exist, exiting.") 79 | sys.exit(0) 80 | 81 | log_file = "filter_from_reddit_scores.log" 82 | setup_logger_tqdm(log_file) 83 | 84 | logger.info("Filtering scrapes based on minimum reddit scores.") 85 | filter_from_reddit_scores(args.scrape_directory) 86 | 87 | -------------------------------------------------------------------------------- /cleaning/generate_minhashes.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script calculates minhashes for all filtered scrape files found using a recursive 3 | search on "*.minscored". 4 | 5 | More explicity, we create a set of 5-grams for each document, and generate document 6 | level minhashes using 10 hash functions with the excellent datasketch library. 7 | 8 | A single file "minhashes.pkl" is created in the scrape directory storing a data 9 | structure in the following format: 10 | 11 | [(file_name1, [doc0_minhash, doc1_minhash, ...]), (file_name2, [....]), ....] 12 | 13 | Arguments 14 | --------- 15 | --scrape_directory (-dir) 16 | Directory containing the minscored scrapes. You could use the overall work directory if you 17 | want as we use glob.glob to search recursively. 18 | --process_count (-procs) 19 | Number of worker processes in the pool. Defaults to 4. 
20 | """ 21 | 22 | import argparse 23 | import glob 24 | import os 25 | import sys 26 | import math 27 | from functools import reduce 28 | from operator import add 29 | from contextlib import redirect_stdout 30 | 31 | import tqdm 32 | import nltk 33 | from nltk.util import ngrams 34 | from datasketch import MinHash, LeanMinHash 35 | from tqdm_multiprocess import TqdmMultiProcessPool 36 | 37 | from utils.utils import timed_pickle_dump 38 | from utils.archiver import Reader 39 | 40 | import logging 41 | from utils.logger import setup_logger_tqdm 42 | logger = logging.getLogger(__name__) 43 | 44 | million = math.pow(10, 6) 45 | 46 | def extract_ngrams(data, num): 47 | n_grams = ngrams(nltk.word_tokenize(data), num) 48 | return [ ' '.join(grams) for grams in n_grams] 49 | 50 | # Multiprocessed 51 | def process_file(file_path, tqdm_func, global_tqdm): 52 | reader = Reader() 53 | minhashes = [] 54 | previous_file_position = 0 55 | for document, metadata in reader.read_jsonl(file_path, get_meta=True): 56 | 57 | n_grams = extract_ngrams(document, 5) 58 | five_gram_set = set(n_grams) 59 | minhash = MinHash(num_perm=10) 60 | for five_gram in five_gram_set: 61 | minhash.update(five_gram.encode('utf8')) 62 | minhashes.append(LeanMinHash(minhash)) 63 | 64 | # Update Progress Bar 65 | current_file_position = reader.fh.tell() 66 | global_tqdm.update(current_file_position - previous_file_position) 67 | previous_file_position = current_file_position 68 | 69 | return file_path, minhashes 70 | 71 | def generate_minhashes(scrape_directory, process_count): 72 | files = glob.glob(os.path.join(scrape_directory, "**/*.minscored"), recursive=True) 73 | total_file_size = reduce(add, map(os.path.getsize, files)) 74 | logger.info(f"Total File Size: {(total_file_size / million):.2f} MB") 75 | 76 | # [(file_name1, [doc0_minhash, doc1_minhash, ...]), (file_name2, [....]), ....] 77 | with tqdm.tqdm(total=total_file_size, dynamic_ncols=True, unit_scale=1) as progress: 78 | pool = TqdmMultiProcessPool() 79 | tasks = [] 80 | for file_path in files: 81 | task = (process_file, (file_path,)) 82 | tasks.append(task) 83 | 84 | on_done = lambda _ : None 85 | on_error = on_done 86 | result = pool.map(process_count, progress, tasks, on_error, on_done) 87 | 88 | return result 89 | 90 | parser_description = 'Generate minhashes for all documents found.' 91 | parser = argparse.ArgumentParser(description=parser_description) 92 | parser.add_argument("-dir", "--scrape_directory", default="") 93 | parser.add_argument("-procs", "--process_count", type=int, default=4) 94 | 95 | if __name__ == '__main__': 96 | args = parser.parse_args() 97 | if not os.path.isdir(args.scrape_directory): 98 | print("Scrape directory doesn't exist, exiting.") 99 | sys.exit(0) 100 | 101 | with redirect_stdout(open(os.devnull, "w")): 102 | nltk.download('punkt') 103 | 104 | log_file = "generate_minhashes.log" 105 | setup_logger_tqdm(log_file) 106 | 107 | logger.info("Generating document level minhashes from 5 gram sets") 108 | minhashes_by_file = generate_minhashes(args.scrape_directory, args.process_count) 109 | 110 | output_pickle_path = os.path.join(args.scrape_directory, "minhashes.pkl") 111 | timed_pickle_dump(minhashes_by_file, output_pickle_path, "minhashes_by_file") 112 | 113 | -------------------------------------------------------------------------------- /cleaning/minhash_lsh_batching.py: -------------------------------------------------------------------------------- 1 | """ 2 | Splits minhashes.pkl into approximately the desired number of batches. 
3 | As we always split on file lines this won't always be exact unless all 4 | files have the same number of documents. 5 | 6 | The "directory" must contain a 'minhashes.pkl' file created with 'generate_minhashes.py'. 7 | 8 | Produces batch files named 'batch0.pkl, batch1.pkl ...'. They contain the following 9 | pickled data structure: 10 | [(file_id, [doc0_minhash, doc1_minhash, ...]), ....] 11 | 12 | Produces a file name lookup named 'file_name_lookup.pkl'. Contains the following 13 | pickled data structure: 14 | [file_name1, file_name2, file_name3, ...] 15 | 16 | Arguments 17 | ------ 18 | --directory (-dir) 19 | Directory containing the 'minhashes.pkl' file. Batch files and 20 | file name lookup will be saved here. 21 | --number_of_batches (-batches) 22 | Approximate number of batches to split minhashes into. 23 | """ 24 | 25 | import os 26 | import argparse 27 | import pickle 28 | 29 | import tqdm 30 | from utils.utils import timed_pickle_dump, timed_pickle_load 31 | 32 | import logging 33 | from utils.logger import setup_logger_tqdm 34 | logger = logging.getLogger(__name__) 35 | 36 | def main(number_of_batches, batch_directory): 37 | minhashes_pickle_path = os.path.join(batch_directory, "minhashes.pkl") 38 | 39 | # [(file_name, [doc0_minhash, doc1_minhash, ...]), ....] 40 | minhashes = timed_pickle_load(minhashes_pickle_path, "minhashes") 41 | 42 | logger.info("Splitting minhashes for batching...") 43 | total_documents = 0 44 | for _ , documents in minhashes: 45 | total_documents += len(documents) 46 | 47 | document_count = 0 48 | documents_per_batch = total_documents / number_of_batches 49 | current_batch = [] 50 | batch_count = 0 51 | for file_id, (file_name, documents) in tqdm.tqdm(enumerate(minhashes)): 52 | document_count += len(documents) 53 | current_batch.append((file_id, documents)) # Note we only store globally unique file_id here 54 | 55 | if document_count > (batch_count + 1) * documents_per_batch: 56 | batch_pickle_file_path = os.path.join(batch_directory, f"batch{batch_count}.pkl") 57 | timed_pickle_dump(current_batch, batch_pickle_file_path, f"batch {batch_count} minhashes") 58 | current_batch = [] 59 | batch_count += 1 60 | 61 | if current_batch: 62 | batch_pickle_file_path = os.path.join(batch_directory, f"batch{batch_count}.pkl") 63 | timed_pickle_dump(current_batch, batch_pickle_file_path, f"batch {batch_count} minhashes") 64 | current_batch = None 65 | 66 | file_name_lookup = [file_name for file_name, documents in minhashes] 67 | file_name_lookup_path = os.path.join(batch_directory, "file_name_lookup.pkl") 68 | timed_pickle_dump(file_name_lookup, file_name_lookup_path, "Filename lookup") 69 | 70 | document_count_path = os.path.join(batch_directory, "document_count.pkl") 71 | pickle.dump(total_documents, open(document_count_path,"wb")) 72 | 73 | 74 | parser = argparse.ArgumentParser(description='Generate batches of minhashes for cassandra lsh dedupe.') 75 | parser.add_argument("-dir", "--directory", default="") 76 | parser.add_argument("-batches", "--number_of_batches", type=int, required=True) 77 | 78 | if __name__ == '__main__': 79 | logfile_path = "minhash_lsh_batching.log" 80 | setup_logger_tqdm(logfile_path) 81 | 82 | args = parser.parse_args() 83 | 84 | main(args.number_of_batches, args.directory) -------------------------------------------------------------------------------- /cleaning/minhash_lsh_dedupe.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script uses MinHashLSH from the datasketch library to 
deduplicate 3 | across the whole set of document minhashes. 4 | 5 | See http://ekzhu.com/datasketch/lsh.html for more details. 6 | 7 | We use the Cassandra backend to keep memory usage low and a 0.5 threshold 8 | for duplicate detection. 9 | 10 | In it's current state, the script creates and pickles a MinHashLSH object 11 | containing the parameters for a local cassandra instance. If you want 12 | a remote instance you could either change the script or perform port 13 | redirection to a remote machine with SSH tunelling for example. This uses 14 | the default Cassandra port. 15 | 16 | lsh = MinHashLSH( 17 | threshold=0.5, num_perm=10, storage_config={ 18 | 'type': 'cassandra', 19 | 'cassandra': { 20 | 'seeds': ['127.0.0.1'], 21 | 'keyspace': 'minhash_lsh_keyspace', 22 | 'replication': { 23 | 'class': 'SimpleStrategy', 24 | 'replication_factor': '1', 25 | }, 26 | 'drop_keyspace': False, 27 | 'drop_tables': False, 28 | } 29 | } 30 | ) 31 | 32 | Importantly, once you run the script you must keep using the same lsh object 33 | if you want the same basenames for the created Cassandra tables. 34 | 35 | We are running into issues with loading this MinHashLSH object in multiple processes 36 | so did our run with just a single process. The first run will freeze when trying 37 | to unpickle the MinHashLSH object in the worker process. On running a second time it 38 | will work. 39 | 40 | We save file and document level checkpoints after each query to allow for easy resuming. 41 | Each batch file will have a corresponding "*_duplicates.txt" when done. 42 | 43 | Arguments 44 | ------ 45 | --batch_directory (-dir) 46 | Directory containing the "batch*.pkl" files. "lsh.pkl", duplicate lists and 47 | batch checkpoints will be saved here. 48 | --process_count (-procs) 49 | Number of processes in the pool. Defaults to 1. 50 | """ 51 | 52 | import os 53 | import glob 54 | import argparse 55 | import json 56 | import pickle 57 | 58 | import tqdm 59 | from datasketch import MinHashLSH 60 | from tqdm_multiprocess import TqdmMultiProcessPool 61 | 62 | from utils.utils import Timer, timed_pickle_dump, timed_pickle_load 63 | 64 | import logging 65 | from utils.logger import setup_logger_tqdm 66 | logger = logging.getLogger(__name__) 67 | 68 | def get_minhash_lsh_cassandra(): 69 | lsh = MinHashLSH( 70 | threshold=0.5, num_perm=10, storage_config={ 71 | 'type': 'cassandra', 72 | 'cassandra': { 73 | 'seeds': ['127.0.0.1'], 74 | 'keyspace': 'minhash_lsh_keyspace', 75 | 'replication': { 76 | 'class': 'SimpleStrategy', 77 | 'replication_factor': '1', 78 | }, 79 | 'drop_keyspace': False, 80 | 'drop_tables': False, 81 | } 82 | } 83 | ) 84 | return lsh 85 | 86 | def minhash_lsh_dedupe_cassandra(batch_minhashes_pickle_path, lsh_pickle_path, tqdm_func, global_tqdm): 87 | # [(file_id, [doc0_minhash, doc1_minhash, ...]), ....] 88 | batch_minhashes = timed_pickle_load(batch_minhashes_pickle_path, "batch minhashes") 89 | 90 | # For some reason this will freeze when loading on the first run. 
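# As noted in the module docstring, this load can hang on the first run when the
# MinHashLSH object is unpickled inside a worker process; re-running the script a
# second time has worked around it.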
91 | lsh = timed_pickle_load(lsh_pickle_path, "lsh") 92 | 93 | checkpoint_file = batch_minhashes_pickle_path.replace(".pkl","_ckpt.pkl") 94 | if os.path.exists(checkpoint_file): 95 | ckpt_file_id, ckpt_document_id = pickle.load(open(checkpoint_file,"rb")) 96 | else: 97 | ckpt_file_id = -1 98 | ckpt_document_id = -1 99 | 100 | logger.info("Detecting duplicates") 101 | timer = Timer().start() 102 | duplicate_file_path = batch_minhashes_pickle_path.replace(".pkl", "_duplicates.txt") 103 | with open(duplicate_file_path, "a") as fh: 104 | for file_id, documents in batch_minhashes: 105 | if file_id <= ckpt_file_id: 106 | global_tqdm.update(len(documents)) 107 | continue 108 | for document_id, minhash in enumerate(documents): 109 | if document_id <= ckpt_document_id: 110 | global_tqdm.update(ckpt_document_id + 1) 111 | ckpt_document_id = -1 112 | continue 113 | results = lsh.query(minhash) 114 | duplicate_found = True if results else False 115 | is_self = False 116 | for json_results in results: 117 | found_file_id, found_document_id = json.loads(json_results) 118 | # This check is needed in case you re-run things 119 | if file_id == found_file_id and document_id == found_document_id: 120 | duplicate_found = False 121 | is_self = True 122 | break 123 | 124 | if duplicate_found: 125 | fh.write(f"{file_id} {document_id}\n") 126 | else: 127 | if not is_self: 128 | lsh.insert(json.dumps((file_id, document_id)), minhash) 129 | 130 | global_tqdm.update() 131 | pickle.dump((file_id, document_id), open(checkpoint_file,"wb")) 132 | 133 | logger.info(timer.stop_string()) 134 | 135 | return True 136 | 137 | def main(process_count, batch_directory): 138 | 139 | # Ensure LSH object containing cassandra connection info exists 140 | lsh_pickle_path = os.path.join(batch_directory, "lsh.pkl") 141 | if not os.path.exists(lsh_pickle_path): 142 | logger.info("Getting cassandra minhash lsh") 143 | lsh = get_minhash_lsh_cassandra() 144 | timed_pickle_dump(lsh, lsh_pickle_path, "lsh") 145 | 146 | files = glob.glob(os.path.join(batch_directory, "batch*.pkl"), recursive=True) 147 | 148 | pool = TqdmMultiProcessPool() 149 | tasks = [] 150 | 151 | document_count_path = os.path.join(batch_directory, "document_count.pkl") 152 | total_documents = pickle.load(open(document_count_path,"rb")) 153 | 154 | for batch_file in files: 155 | arguments = (batch_file, lsh_pickle_path) 156 | task = (minhash_lsh_dedupe_cassandra, arguments) 157 | tasks.append(task) 158 | 159 | on_done = lambda _ : logger.info("done") 160 | on_error = lambda _ : logger.info("error") 161 | with tqdm.tqdm(total=total_documents, dynamic_ncols=True) as progress: 162 | result = pool.map(process_count, progress, tasks, on_error, on_done) 163 | logger.info(result) 164 | 165 | parser = argparse.ArgumentParser(description='Minhash LSH dedupe with cassandra backend.') 166 | parser.add_argument("-dir", "--batch_directory", default="") 167 | parser.add_argument("-procs", "--process_count", type=int, default=1) 168 | 169 | if __name__ == '__main__': 170 | logfile_path = "minhash_lsh_dedupe.log" 171 | setup_logger_tqdm(logfile_path) 172 | 173 | args = parser.parse_args() 174 | 175 | main(args.process_count, args.batch_directory) -------------------------------------------------------------------------------- /data_analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/data_analysis/__init__.py 
-------------------------------------------------------------------------------- /data_analysis/final_stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | import pickle 5 | import math 6 | 7 | import tqdm 8 | 9 | from utils.archiver import Reader 10 | 11 | import logging 12 | from utils.logger import setup_logger_tqdm 13 | logger = logging.getLogger(__name__) 14 | 15 | def get_stats_old(): 16 | batch_directory = "/home/researcher2/webtext2/test" 17 | files = glob.glob(os.path.join(batch_directory, "*_duplicates.txt")) 18 | duplicate_count = 0 19 | for file_path in files: 20 | with open(file_path, "r") as fh: 21 | duplicate_count += len(fh.readlines()) 22 | 23 | document_count_path = os.path.join(batch_directory, "document_count.pkl") 24 | document_count = pickle.load(open(document_count_path, "rb")) 25 | 26 | print("Total Duplicates: ", duplicate_count) 27 | print("Original Documents: ", document_count) 28 | 29 | useful_percentage = (1 - duplicate_count / document_count) * 100 30 | print(f"Useful Data: {useful_percentage:0.2f}%") 31 | 32 | def get_stats(final_directory): 33 | 34 | reader = Reader() 35 | files = glob.glob(os.path.join(final_directory, "*jsonl.zst")) 36 | 37 | document_count = 0 38 | total_text_size = 0 39 | logger.info("Getting final document count and total uncompressed text size.") 40 | for file_path in tqdm.tqdm(files, dynamic_ncols=True): 41 | for document, metadata in reader.read_jsonl(file_path, get_meta=True): 42 | document_count += 1 43 | total_text_size += len(document) 44 | 45 | return document_count, total_text_size 46 | 47 | parser = argparse.ArgumentParser(description='Final statistics') 48 | parser.add_argument("-dir", "--final_directory", default="") 49 | 50 | if __name__ == '__main__': 51 | logfile_path = "final_statistics.log" 52 | setup_logger_tqdm(logfile_path) 53 | 54 | args = parser.parse_args() 55 | 56 | final_count, total_text_size = get_stats(args.final_directory) 57 | billion = math.pow(10, 9) 58 | logger.info(f"Final Document Count: {final_count:,}") 59 | print(f"Total uncompressed text size: {(total_text_size / billion):.2f} GB") 60 | pickle.dump((final_count, total_text_size), open("final_stats.pkl", "wb")) -------------------------------------------------------------------------------- /mkdocs/docs/background.md: -------------------------------------------------------------------------------- 1 | # WebText Background 2 | 3 | OpenAI required around 40gb of high quality text corpus for training GPT2. While Common Crawl provides the scale necessary for modern language models, the quality is unreliable. Manual curation of Common Crawl is always an option, albeit an expensive one. Thankfully Reddit provides decentralized curation by design, and this became the key innovation for the WebText dataset. 4 | 5 | The generation of WebText can be summarized as: 6 | 7 | 1. Scrape URLs from all Reddit submissions up to December 2017 with 3 or higher score. 8 | 2. Deduplicate scraped content based on URL 9 | 3. Exclude wikipedia - OpenAI already had a separate Wikipedia dataset 10 | 4. Deduplicate remaining content using undisclosed "heuristic based cleaning". This includes removal of non-english web pages. 11 | 12 | Neither the resulting corpus or generation source code was made public, inspiring Aaron Gokaslan and Vanya Cohen to create the OpenWebTextCorpus. 
13 | 14 | OpenWebTextCorpus is an open source reproduction of WebText, reifying the "heuristic based cleaning" stage with fuzzy deduplication and enforcing a minimum token length. For content-based deduplication they used locality-sensitive hashing (LSH) with MinHash on sets of 5-grams at the document level. Documents were then tokenized and any with fewer than 128 tokens were removed. After all processing there remained 40 GB of text across 8,013,769 documents. 15 | 16 | The original code for OpenWebTextCorpus is unavailable at this time, but there are several popular repositories that cover the pipeline to various degrees. 17 | 18 | ## OpenWebText2 Motivation 19 | 20 | Our primary goals for the corpus are: 21 | 22 | 1. More data! Coverage of the original OpenWebTextCorpus ended at December 2017. 23 | 2. Include all languages, providing metadata for easy filtering. 24 | 3. Provide several versions of the generated corpus for differing user requirements. Both versions will be broken up by month and frozen, with future months available once PushShift submission dumps become available. 25 | * Raw version containing all scraped pages with associated Reddit submission metadata 26 | * Plug and play version based on submissions of minimum 3 score with content-based fuzzy de-duplication 27 | 4. Provide full source code for all stages of the pipeline including deduplication. 28 | 29 | We decided on a rewrite, taking inspiration from both 1 and 2. -------------------------------------------------------------------------------- /mkdocs/docs/css/extra.css: -------------------------------------------------------------------------------- 1 | table { 2 | margin-bottom: 16px; 3 | } 4 | 5 | table th { 6 | padding: 8px; 7 | } 8 | 9 | table td { 10 | padding: 8px; 11 | } 12 | 13 | .rst-content table code { 14 | white-space: nowrap; 15 | } 16 | 17 | p { 18 | margin-bottom: 16px; 19 | } 20 | 21 | .download-button 22 | { 23 | 24 | } 25 | 26 | .download-button svg { 27 | margin-left: 4px; 28 | border: none; 29 | } 30 | 31 | .darken { 32 | filter: brightness(75%); 33 | } -------------------------------------------------------------------------------- /mkdocs/docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome! 2 | 3 | OpenWebText2 is an enhanced version of the original OpenWebTextCorpus covering all Reddit submissions from 2005 up until April 2020, with further months becoming available after the corresponding PushShift dump files are released. 4 | 5 | In case you haven't heard of WebText, the core principle is extracting URLs from Reddit submissions, scraping the URLs, then performing filtering & deduplication. See [Background](background) for more information. 6 | 7 |
8 | 9 | ## Download Plug and Play Version 10 | This version has already been cleaned for you: 11 | 12 | - Deduplicated by URL 13 | - Filtered by a minimum combined Reddit score of 3 14 | - Deduplicated at the document level with MinHashLSH 15 | 16 | **Stats**
17 | 17,103,059 documents
18 | 65.86 GB uncompressed text
19 | 28 GB compressed including text and metadata 20 | 21 | 22 | 23 | 30 | 31 | 32 |
33 | 34 | ## Download Raw Scrapes Version 35 | Only deduplicated by URL. 36 | 37 | **Stats**
38 | 69,547,149 documents
39 | 193.89 GB uncompressed text
40 | 79 GB compressed including text and metadata 41 | 42 | 43 | 50 | 51 | 52 |
53 | 54 | ## Using The Data 55 | 56 | The data is stored using lm_dataformat. We use a slightly modified version to allow file peeking for tqdm progress bars: utils/archiver.py. Be sure to call *read_jsonl* with `get_meta=True` as both versions contain useful metadata for each document, including several original Reddit fields. 57 | 58 | ```python 59 | import glob 60 | import os 61 | import math 62 | 63 | import tqdm 64 | 65 | from utils.archiver import Reader 66 | 67 | document_count = 0 68 | total_text_size = 0 69 | dataset_directory = "PATH_TO_FILES" 70 | files = glob.glob(os.path.join(dataset_directory, "*jsonl.zst")) 71 | for file_path in tqdm.tqdm(files, dynamic_ncols=True): 72 | reader = Reader() 73 | for document, metadata in reader.read_jsonl(file_path, get_meta=True): 74 | document_count += 1 75 | total_text_size += len(document) 76 | 77 | billion = math.pow(10, 9) 78 | print(f"Total Document Count: {document_count:,}") 79 | print(f"Total Uncompressed Text Size: {(total_text_size / billion):.2f} GB") 80 | ``` 81 | 82 | Alternatively checkout The-Pile, which acts as an aggregator/dataloader for multiple text datasets. It allows you to configure your total data size requirement, along with the desired weighting for each subset. Once configured, you get a randomized stream of documents, allowing easy feeding to your language model. 83 | 84 | ## Cite as 85 | 86 |
87 | @article{pile,
88 |     title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
89 |     author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
90 |     journal={arXiv preprint arXiv:2101.00027},
91 |     year={2020}
92 | }
93 | 
94 | -------------------------------------------------------------------------------- /mkdocs/docs/licence.md: -------------------------------------------------------------------------------- 1 | # MIT License 2 | 3 | Copyright (c) 2020 EleutherAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /mkdocs/docs/replication.md: -------------------------------------------------------------------------------- 1 | # Dataset Replication 2 | 3 | This area of the documentation provides instructions for building the full dataset from scratch. If you just want the dataset, please see [Welcome](/). 4 | 5 | PushShift provides dumps of all reddit posts and submissions, however they are normally a few months behind. While this would be problematic for certain use cases, we didn't require up to the minute data for training GPTNeo. In the future we may look into getting recent data either by scraping Reddit directly or using one of the existing APIs. 6 | 7 | ## Pipeline Overview 8 | 9 | At a high level the pipeline works as follows: 10 | 11 | 1. Download and process the PushShift submission dumps to extract unique URLs & Metadata. 12 | 2. Scrape the URLs using Newspaper3k, saving both text and metadata with lm_dataformat. 13 | 5. Filter the scraped documents by minimum Reddit score 3. 14 | 4. Perform fuzzy deduplication using MinHashLSH. 15 | 5. Package up the various dataset releases. 16 | 6. Produce some useful size stats for the releases. 17 | 18 | ## Environment Setup 19 | We tested everything on Ubuntu 18/20 & Windows 10 with miniconda. You could use virtualenv, venv or even the global python environment if you wish. 20 | 21 | ### Miniconda Install For Linux 22 | 23 | Follow the below steps, or read the conda instructions:
24 | https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html 25 | 26 | ```bash 27 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 28 | sha256sum Miniconda3-latest-Linux-x86_64.sh 29 | bash Miniconda3-latest-Linux-x86_64.sh 30 | ``` 31 | Select yes on the init step. 32 | 33 | Restart your shell to refresh the path. 34 | 35 | ### Create and activate conda environment 36 | 37 | Environments are saved in a central store on the local disk, no need to create folders like with venv. 38 | ``` 39 | conda create --name pushshift python=3.8 40 | conda activate pushshift 41 | ``` 42 | 43 | ### Install Repo and Requirements 44 | ```bash 45 | git clone https://github.com/EleutherAI/openwebtext2.git 46 | cd openwebtext2 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | ### General Recommendations 51 | 52 | Use the screen command if running a remote terminal, many scripts below take a long time to run and while they often support resuming it's better not to rely on it. 53 | 54 | Create a working directory on a drive with at least 500gb of space. 55 | 56 | ## Stage 1 - Processing PushShift Submission Dumps 57 | 58 | This stage consists of the following steps: 59 | 60 | 1. Sqlite database setup. 61 | 2. Download and verify the PushShift submission dumps, extracting and storing urls and metadata 62 | from relevant submissions into the sqlite database. Performed by "pushshift/pushshift_to_sqlite.py". 63 | 3. Query the populated sqlite database and build a list of URLs with metadata for all related submissions. 64 | Performed by "pushshift/generate_urls_from_sqlite.py" 65 | 66 | All scripts for this stage can be found within the "pushshift" package. 67 | 68 | ### Sqlite Database Setup 69 | We use alembic to manage the sqlite database in case you want to add extra indexes or fields easily later. 70 | 71 | Inside the cloned repo: 72 | ```bash 73 | cp alembic.ini.template alembic.ini 74 | ``` 75 | Modify the following line within the alembic.ini file. For example on windows: 76 | ```python 77 | sqlalchemy.url = sqlite:///e:/Eleuther_AI/openwebtext2/submissions.sqlite 78 | ``` 79 | Or on Linux (notice the extra forward slash before "mnt" to indicate system root): 80 | ```python 81 | sqlalchemy.url = sqlite:////mnt/data/openwebtext2/submissions.sqlite 82 | ``` 83 | 84 | ### Insert PushShift Submission Data Into Sqlite 85 | 86 | This step is performed by *pushshift/pushshift_to_sqlite.py*. 87 | 88 | | Script Argument | Description | 89 | | -----------: | ----------- | 90 | | `--start_period (-s)` | Month and Year of first pushshift dump. Default: 6,2005. | 91 | | `--finish_period (-f)` | Month and Year of final pushshift dump. Defaults to current month. | 92 | | `--output_directory (-dir)` | Will contain the dumps subdirectory created as part of the process. | 93 | | `--keep_dumps (-kd)` | If specified the dumps won't be deleted after successful processing. | 94 | 95 | Notice the database location is not specified here, this is always sourced from the alembic.ini file. 96 | 97 | For example on Linux, to download and process all dumps, leaving the downloaded dumps afterwards: 98 | ```bash 99 | python -m pushshift.pushshift_to_sqlite -dir /mnt/data/openwebtext2 -kd 100 | ``` 101 | 102 | Test run on 2006 only, deleting dumps when done: 103 | ```bash 104 | python -m pushshift.pushshift_to_sqlite -s 1,2006 -f 12,2006 -dir /mnt/data/openwebtext2 105 | ``` 106 | 107 | This step uses checkpointing, saving a .dbdone file for each dump once processing is complete. 
So if you need to stop and come back later you can. 108 | 109 | ### Processing the PushShift Dumps With jq and sort instead 110 | 111 | TO DO. 112 | 113 | 114 | ### Extract Unique URLs with Reddit Metadata 115 | 116 | This step is performed by *pushshift/generate_urls.py*. Currently only supports sqlite mode, working on the json->sorted tsv version now. 117 | 118 | | Script Argument | Description | 119 | | -----------: | ----------- | 120 | | `--start_period (-s)` | Month and Year of first URLs. Defaults to None (query all URLs). | 121 | | `--finish_period (-f)` | Month and Year of final URLs. Defaults to None (query all URLs). | 122 | | `--output_directory (-dir)` | Will contain the urls subdirectory created as part of the process. | 123 | | `--urls_per_file` | Maximum number of urls per file. Defaults to 100,000. | 124 | | `--min_score (-score)` | Minimum aggregate submissions score to include url. Defaults to 3. | 125 | | `--data_source (-source)` | Where to find sorted URLs: "db" or "tsv". tsv doesn't support date ranges. Defaults "to db". | 126 | 127 | 128 | 129 | Notice the database location is not specified here, this is always sourced from the alembic.ini file. 130 | 131 | For example on Linux, to extract all urls 132 | ```bash 133 | python -m pushshift.generate_urls -dir /mnt/data/openwebtext2 134 | ``` 135 | 136 | Test run on 2006 only: 137 | ```bash 138 | python -m pushshift.generate_urls -s 1,2006 -f 12,2006 -dir /mnt/data/openwebtext2 139 | ``` 140 | 141 | ## Stage 2 - Scraping From Sourced URLs 142 | 143 | This stage is performed by *scraping/scrape_urls.py* and took several weeks compute time. To decrease this you can run on multiple servers by passing out the URL files. 144 | 145 | | Script Argument | Description | 146 | | -----------: | ----------- | 147 | | `--job_directory (-dir)` | Base directory containing the urls subdirectory and location where the scrapes subdirectory will be created. | 148 | | `--process_count (-procs)` | Number of worker processes in the pool. Defaults to 60. Don't go above this on Windows. | 149 | | `--request_timeout (-timeout)` | Scraping timeout for each URL. Defaults to 30 seconds. | 150 | 151 | The script iterates through URL files generated in step 2 above. For each file its hands out the URLs 152 | to a multiprocessing pool for scraping. Once all URLs in the batch are scraped, the successful results are 153 | archived using a slightly modified version of lm_dataformat. For each document (URL), the following metadata fields are saved in the metadata dict offered by lm_dataformat: 154 | 155 | | Meta Field | Description | 156 | | -----------: | ----------- | 157 | | title | Web Page Title. | 158 | | lang | Language detected by Newspaper scraper. | 159 | | url | Original URL. | 160 | | word_count | Total words outputted by Newspaper. | 161 | | elapsed | Scraping time. | 162 | | scraper | Always "newspaper". | 163 | | domain | Top level domain for the original URL. | 164 | | reddit_id | List of submission IDs containing URL - converted from base36. | 165 | | subreddit | List of subreddits for the corresponding submissions. | 166 | | reddit_score | List of reddit scores for the corresponding submissions. | 167 | | reddit_title | List of submissions titles for the corresponding submissions. | 168 | | reddit_created_utc | List of submissions created times for the corresponding submissions. | 169 | 170 | The program will look for URL files within "job_directory/urls". 
All scrapes will be stored in "job_directory/scrapes" 171 | 172 | For example on Linux, this will scrape using 90 processes and 30 second timeout: 173 | ```bash 174 | python -m scraping.scrape_urls -dir /mnt/data/openwebtext2 -procs 90 175 | ``` 176 | 177 | On a dedicated 2012 i7 Linux machine we used between 90 and 120 processes successfully. 178 | 179 | We do some limited URL filtering in *scraping/filter.py*. This is mainly to speed up the process by avoiding timeouts or files that obviously won't contain text. 180 | 181 | Once each URL file is scraped, the program saves a ".done" file so you can resume later without rescraping. That file contains a count of successfully scraped URLs if you are interested. 182 | 183 | ## Stage 3 - Filtering scraped documents by minimum total Reddit score 184 | 185 | This stage is performed by *cleaning/filter_from_reddit_score.py*. 186 | 187 | | Script Argument | Description | 188 | | -----------: | ----------- | 189 | | `--scrape_directory (-dir)` | Directory containing the scrapes. You could use the overall work directory if you want as we use glob.glob to search recursively. | 190 | 191 | The script filters all scrape files "scrapes_*.jsonl.zst" by minimum total Reddit score. 192 | Unlike the original WebText we aggregate scores for all submissions containing a given 193 | URL so the bar is slightly lower in some cases, but in others where a URL went negative 194 | in some submission it will balance out. 195 | 196 | The filtered scrapes file will have the original name and path of the scrape file with a 197 | ".minscored" extension. 198 | 199 | For example on Linux: 200 | ```bash 201 | python -m cleaning.filter_from_reddit_score -dir /mnt/data/openwebtext2/scrapes 202 | ``` 203 | 204 | ## Stage 4 - Deduplicate Filtered Documents using MinHashLSH with Cassandra 205 | 206 | There are several sub-stages here: 207 | 208 | 1. Setup Cassandra 209 | 2. Generate minhashes for every document 210 | 3. Batch up the minhashes for running parallel dedupe 211 | 4. Using MinHashLSH With Cassandra - Generate lists of duplicates 212 | 5. Deduplicating our documents using the lists from step 3. 213 | 214 | All scripts for this stage can be found within the "cleaning" package. 215 | 216 | ### Setup Cassandra 217 | 218 | We used a local Cassandra install, simplifying the setup process. Some good cassandra guides: 219 | 220 | Installation Guide For Ubuntu 20
221 | Introduction To Cassandra + Connecting With Python API 222 | 223 | Summarized Quick Install For Ubuntu 20: 224 | ```bash 225 | sudo apt install openjdk-8-jdk 226 | sudo apt install apt-transport-https 227 | wget -q -O - https://www.apache.org/dist/cassandra/KEYS | sudo apt-key add - 228 | sudo sh -c 'echo "deb http://www.apache.org/dist/cassandra/debian 311x main" > /etc/apt/sources.list.d/cassandra.list' 229 | sudo apt update 230 | sudo apt install cassandra 231 | sudo systemctl status cassandra 232 | ``` 233 | 234 | To test your installation was successful, run the cqlsh CLI: 235 | ```bash 236 | cqlsh 237 | ``` 238 | 239 | Once inside: 240 | ```describe keyspaces``` 241 | 242 | If you want multiple nodes or remote connection you need to set the following in your /etc/cassandra/cassandra.yaml: 243 | 244 | seeds: "your_server_external_ip, other nodes in cluster"
245 | listen_address: your_server_external_ip
246 | start_rpc: true
247 | rpc_address: 0.0.0.0 (this will bind to same address as listen_address) 248 | 249 | For some reason they recommend not to make this available on the internet despite supporting various forms of authentication. So either use a tunnel or fancy networking to get around this. 250 | 251 | ### Generate Minhashes For Every Document 252 | 253 | This step is performed by *cleaning/generate_minhashes.py* and took about 1.5 days 254 | on a 2012 i7 Linux machine. 255 | 256 | | Script Argument | Description | 257 | | -----------: | ----------- | 258 | | `scrape_directory (-dir)` | Directory containing the minscored scrapes. You could use the overall work directory if you want as we use glob.glob to search recursively. | 259 | | `process_count (-procs)` | Number of worker processes in the pool. Defaults to 4. | 260 | 261 | This script calculates minhashes for all filtered scrape files found using a recursive 262 | search on "\*.minscored". 263 | 264 | More explicity, we create a set of 5-grams for each document, and generate document 265 | level minhashes using 10 hash functions with the excellent datasketch library. 266 | 267 | A single file "minhashes.pkl" is created in the scrape directory storing a data 268 | structure in the following format: 269 | 270 | ```python 271 | [(file_name1, [doc0_minhash, doc1_minhash, ...]), (file_name2, [....]), ....] 272 | ``` 273 | 274 | For example on Linux: 275 | ```bash 276 | python -m cleaning.generate_minhashes -dir /mnt/data/openwebtext2/scrapes 277 | ``` 278 | 279 | ### Slice The Minhashes For Batching 280 | 281 | This step is performed by *cleaning/minhash_lsh_batching.py*. 282 | 283 | | Script Argument | Description | 284 | | -----------: | ----------- | 285 | | `directory (-dir) ` | Directory containing the 'minhashes.pkl' file. Batch files and file name lookup will be saved here. | 286 | | `number_of_batches (-batches) ` | Approximate number of batches to split minhashes into. | 287 | 288 | The "directory" must contain a 'minhashes.pkl' file created with *cleaning/generate_minhashes.py*. 289 | 290 | This splits "minhashes.pkl" into approximately the desired number of batches, producing batch files named 'batch0.pkl, batch1.pkl, etc'. They contain the following pickled data structure: 291 | ```python 292 | [(file_id, [doc0_minhash, doc1_minhash, ...]), ....] 293 | ``` 294 | 295 | It also creates a file name lookup named 'file_name_lookup.pkl' containing the following pickled datastructure: 296 | ```python 297 | [file_name1, file_name2, file_name3, ...] 298 | ``` 299 | 300 | For example on Linux with 16 batches: 301 | ```bash 302 | python -m cleaning.minhash_lsh_batching -dir /mnt/data/openwebtext2/scrapes -batches 16 303 | ``` 304 | 305 | ### Find Duplicates Using MinHashLSH with Cassandra 306 | 307 | This step is performed by *cleaning/minhash_lsh_dedupe.py*. 308 | 309 | | Script Argument | Description | 310 | | -----------: | ----------- | 311 | | `batch_directory (-dir)` | Directory containing the "batch\*.pkl" files. Duplicate lists and batch checkpoints will be saved here. | 312 | | `process_count (-procs)` | Number of processes in the pool. Defaults to 4. | 313 | 314 | The script generates a list of detected duplicates for files/documents located in the various "batch\*.pkl" files. 315 | 316 | We save file and document level checkpoints after each query to allow for easy resuming. 317 | Each batch file will have a corresponding "\*duplicates.txt" when done. 
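Each line of a "\*duplicates.txt" file is simply `file_id document_id`, where `file_id` indexes into the "file_name_lookup.pkl" created during batch slicing. Below is a rough sketch of inspecting those lists, assuming the batch directory used in the examples; the next stage (*cleaning/dedupe_from_indexes.py*) consumes these files for you:

```python
import glob
import os
import pickle

# Hypothetical batch directory from the earlier steps.
batch_directory = "/mnt/data/openwebtext2/scrapes"

# file_id -> original "*.minscored" file name, created during batch slicing.
lookup_path = os.path.join(batch_directory, "file_name_lookup.pkl")
file_name_lookup = pickle.load(open(lookup_path, "rb"))

# Collect duplicate document ids per file across all duplicate lists.
duplicates_per_file = {}
for duplicates_path in glob.glob(os.path.join(batch_directory, "*_duplicates.txt")):
    with open(duplicates_path, "r") as fh:
        for line in fh.read().splitlines():
            file_id, document_id = map(int, line.split(" "))
            duplicates_per_file.setdefault(file_id, set()).add(document_id)

for file_id, document_ids in sorted(duplicates_per_file.items()):
    print(f"{file_name_lookup[file_id]}: {len(document_ids)} duplicate documents")
```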
318 | 
319 | For example on Linux with the default 4 processes:
320 | 
321 | ```bash
322 | python -m cleaning.minhash_lsh_dedupe -dir /mnt/data/openwebtext2/scrapes
323 | ```
324 | 
325 | ### De-Duplicating Using Generated Duplicate Lists
326 | 
327 | This step is performed by *cleaning/dedupe_from_indexes.py*.
328 | 
329 | | Script Argument | Description |
330 | | -----------: | ----------- |
331 | | `batch_directory (-dir)` | Directory containing the "\*duplicates.txt" files along with the "file_name_lookup.pkl" created during batch slicing. The "\*final.jsonl.zst" files will be output in their original directories. |
332 | 
333 | This script builds a list of all duplicates by file_id & document_id, and then iterates
334 | through all ".minscored" files from the filename lookup, creating for each original file
335 | a new archive containing all documents that were not marked as duplicates during
336 | the previous step.
337 | 
338 | For each original file, a "\*final.jsonl.zst" file will be output in the same directory.
339 | 
340 | For example on Linux:
341 | ```bash
342 | python -m cleaning.dedupe_from_indexes -dir /mnt/data/openwebtext2/scrapes
343 | ```
344 | 
345 | ## Stage 5 - Packaging The Dataset Releases
346 | 
347 | ### Plug And Play Release
348 | 
349 | We originally did the processing by month, but now simply package the scrape files
350 | corresponding to the original URL files.
351 | 
352 | Simply tar all the "\*final.jsonl.zst" files.
353 | 
354 | ### Raw Scrapes Release
355 | 
356 | Similarly, just tar all the "scrapes_\*.jsonl.zst" files.
357 | 
358 | ## Stage 6 - Produce Release Stats
359 | 
360 | If you move the files from each release into their own subdirectory, you can run "data_analysis/final_stats.py"
361 | to get the total document count and text size for all "jsonl.zst" files in each directory.
362 | 
363 | For example on Linux:
364 | ```bash
365 | python -m data_analysis.final_stats -dir /mnt/data/openwebtext2/final
366 | python -m data_analysis.final_stats -dir /mnt/data/openwebtext2/raw_release
367 | ```
368 | 
-------------------------------------------------------------------------------- /mkdocs/mkdocs.yml: --------------------------------------------------------------------------------
1 | site_name: OpenWebText2
2 | theme: readthedocs
3 | docs_dir: 'docs'
4 | nav:
5 | - Welcome: index.md
6 | - WebText Background: background.md
7 | - Dataset Replication: replication.md
8 | - Licence: licence.md
9 | extra_css:
10 | - css/extra.css
11 | 
-------------------------------------------------------------------------------- /pushshift/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/pushshift/__init__.py
-------------------------------------------------------------------------------- /pushshift/download_pushshift_dumps.py: --------------------------------------------------------------------------------
1 | """
2 | This module is responsible for downloading the PushShift submission dump
3 | files, and contains the following functions:
4 | 
5 | Functions
6 | ---------
7 | get_url_content_length(url)
8 | Attempts to retrieve the Content-Length header when performing
9 | a get request on the provided 'url'. Returns the number of bytes
10 | if available, or None.
11 | 12 | build_file_list(start_date, end_date): 13 | Builds a list of PushShift submission dump files located at 14 | "https://files.pushshift.io/reddit/submissions" within the desired date 15 | range. 16 | 17 | get_sha256sums 18 | Downloads the sha256sum file for the PushShift submission dumps from 19 | "https://files.pushshift.io/reddit/submissions/sha256sums.txt". Builds 20 | and returns a dictionary with file name as key and sha256 sum as value. 21 | """ 22 | 23 | from dateutil.relativedelta import * 24 | import math 25 | 26 | import requests 27 | 28 | import logging 29 | logger = logging.getLogger(__name__) 30 | 31 | million = math.pow(10, 6) 32 | possible_archive_formats = ["zst", "xz", "bz2"] 33 | 34 | def get_url_content_length(url): 35 | response = requests.head(url) 36 | response.raise_for_status() 37 | 38 | if "Content-Length" in response.headers: 39 | return int(response.headers['Content-length']) 40 | else: 41 | return None 42 | 43 | def build_file_list(start_date, end_date): 44 | base_url = "https://files.pushshift.io/reddit/submissions" 45 | url_list = [] 46 | date = start_date 47 | current_year = None 48 | while date <= end_date: 49 | year = date.strftime("%Y") 50 | month = date.strftime("%m") 51 | 52 | if year != current_year: 53 | current_year = year 54 | logger.info(f"Scanning Year {current_year}") 55 | 56 | if year < "2011": 57 | url = f"{base_url}/RS_v2_{year}-{month}.xz" 58 | url_list.append(url) 59 | else: 60 | for extension in possible_archive_formats: 61 | url = f"{base_url}/RS_{year}-{month}.{extension}" 62 | try: 63 | get_url_content_length(url) # If this fails there's no file 64 | url_list.append(url) 65 | break 66 | except: 67 | pass 68 | 69 | date = date + relativedelta(months=+1) 70 | 71 | return url_list 72 | 73 | def get_sha256sums(): 74 | sha256sum_url = "https://files.pushshift.io/reddit/submissions/sha256sums.txt" 75 | 76 | sha256sum_lookup = {} 77 | with requests.get(sha256sum_url) as response: 78 | response.raise_for_status() 79 | for line in response.text.splitlines(): 80 | if line.strip(): 81 | sha256sum, file_name = tuple(line.strip().split(" ")) 82 | sha256sum_lookup[file_name] = sha256sum 83 | 84 | return sha256sum_lookup -------------------------------------------------------------------------------- /pushshift/generate_urls.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script produces files containing the urls and associated Reddit metadata 3 | for a given period, defaulting to 1,000,000 urls per file. Using lm_dataformat, 4 | one record is stored for each URL, along with the metadata of all submissions for 5 | that particular URL. 6 | 7 | No separate URL based deduplication is required unless you run multiple iterations of 8 | the script across different time periods (unimplemented but very simple). 9 | 10 | Note that we don't filter by score at this stage as the full pipeline scrapes all urls 11 | and leaves the filtering to be done by the user if they don't want the plug and play version. 12 | 13 | Arguments 14 | --------- 15 | --start_period (-s) 16 | Month and Year of first URLs. Defaults to None (query all URLs). 17 | --finish_period (-f) 18 | Month and Year of final URLs. Defaults to None (query all URLs). 19 | --output_directory (-dir) 20 | Base directory that will contain the urls subdirectory created as part of the process. 21 | --urls_per_file 22 | Maximum number of urls per file. Defaults to 100,000. 23 | --min_score 24 | Minimum aggregate submissions score to include url. 
25 | --data_source 26 | Where to find sorted URLs: "db" or "tsv". tsv doesn't support date ranges. 27 | 28 | If both start_period and finish_period are blank then we can use a faster query on the reddit_submission 29 | table. 30 | """ 31 | 32 | import datetime 33 | from dateutil.relativedelta import * 34 | import argparse 35 | import os 36 | import json 37 | from functools import partial 38 | import sys 39 | 40 | from utils.archiver import Archive 41 | from .models import RedditSubmission, get_db_session 42 | 43 | import logging 44 | from utils.logger import setup_logger_tqdm 45 | logger = logging.getLogger(__name__) 46 | 47 | def get_from_db(start_date, end_date): 48 | db_session = get_db_session() 49 | 50 | # SELECT id, url, score, title, subreddit, created_utc 51 | # FROM reddit_submission 52 | # WHERE created_utc >= start_date and created_utc <= end_date 53 | # ORDER BY url 54 | select_fields = (RedditSubmission.id, RedditSubmission.url, RedditSubmission.score, 55 | RedditSubmission.title, RedditSubmission.subreddit, RedditSubmission.created_utc) 56 | 57 | if start_date or end_date: 58 | month_end = end_date + relativedelta(months=+1) 59 | query = db_session.query(*select_fields) \ 60 | .filter(RedditSubmission.created_utc >= start_date) \ 61 | .filter(RedditSubmission.created_utc < month_end) \ 62 | .order_by(RedditSubmission.url) \ 63 | .yield_per(1000) 64 | else: 65 | query = db_session.query(*select_fields) \ 66 | .order_by(RedditSubmission.url) \ 67 | .yield_per(1000) 68 | 69 | logger.info("Querying sqlite database for submissions") 70 | logger.info(query) 71 | return query 72 | 73 | def get_from_tsv(): 74 | pass 75 | 76 | 77 | def generate_urls(url_directory, urls_per_file, min_score, source): 78 | 79 | url_batch = 0 80 | url_file_path = os.path.join(url_directory, f"urls_{url_batch}.jsonl.zst") 81 | archiver = Archive(url_file_path) 82 | 83 | current_url = "" 84 | current_meta = {} 85 | current_meta["id"] = [] 86 | current_meta["score"] = [] 87 | current_meta["title"] = [] 88 | current_meta["subreddit"] = [] 89 | current_meta["created_utc"] = [] 90 | 91 | total_url_count = 0 92 | url_count = 0 93 | logger.info("Generating now...") 94 | for submission_id, url, score, title, subreddit, created_utc in source(): 95 | if not current_url: 96 | current_url = url 97 | elif url != current_url: 98 | # New URL - Add Old URL and meta to archive if score is high enough 99 | total_score = sum(current_meta["score"]) 100 | if (total_score >= min_score): 101 | archiver.add_data(current_url, current_meta) 102 | url_count += 1 103 | total_url_count += 1 104 | 105 | # Commit and Init New Archive if full 106 | if url_count == urls_per_file: 107 | archiver.commit() 108 | url_batch += 1 109 | url_file_path = os.path.join(url_directory, f"urls_{url_batch}.jsonl.zst") 110 | archiver = Archive(url_file_path) 111 | url_count = 0 112 | 113 | current_url = url 114 | current_meta = {} 115 | current_meta["id"] = [] 116 | current_meta["score"] = [] 117 | current_meta["title"] = [] 118 | current_meta["subreddit"] = [] 119 | current_meta["created_utc"] = [] 120 | 121 | current_meta["id"].append(submission_id) 122 | current_meta["score"].append(score) 123 | current_meta["title"].append(title) 124 | current_meta["subreddit"].append(subreddit) 125 | current_meta["created_utc"].append(created_utc) 126 | 127 | if url_count > 0: 128 | archiver.add_data(current_url, current_meta) 129 | total_url_count += 1 130 | archiver.commit() 131 | 132 | url_count_path = os.path.join(url_directory, "url_count.json") 133 | 
json.dump(total_url_count, open(url_count_path, "w")) 134 | 135 | parser_description = 'Generate URL files from sqlite database containing URLs and reddit metadata.' 136 | parser = argparse.ArgumentParser(description=parser_description) 137 | parser.add_argument("-s", "--start_period", default=None) 138 | parser.add_argument("-f", "--finish_period", default=None) 139 | parser.add_argument("-dir", "--output_directory", default="") 140 | parser.add_argument("--urls_per_file", type=int, default=100000) 141 | parser.add_argument("-score", "--min_score", type=int, default=3) 142 | parser.add_argument("-source", "--data_source", default="db") 143 | 144 | if __name__ == '__main__': 145 | args = parser.parse_args() 146 | 147 | logfile_path = "generate_urls.log" 148 | setup_logger_tqdm(logfile_path) 149 | 150 | # If both none we can just query all and use the index on url field 151 | # Otherwise we use the index on created_utc and have to do a costly url sort 152 | if args.start_period or args.finish_period: 153 | if args.start_period: 154 | start_month, start_year = tuple(map(int,args.start_period.split(","))) 155 | start_date = datetime.datetime(start_year, start_month, 1) 156 | else: 157 | start_date = datetime.datetime(2005, 6, 1) 158 | 159 | if args.finish_period: 160 | finish_month, finish_year = tuple(map(int,args.finish_period.split(","))) 161 | end_date = datetime.datetime(finish_year, finish_month, 1) 162 | else: 163 | end_date = datetime.datetime.now() 164 | 165 | logger.info(f"Finding URLs between {start_date.strftime('%Y-%m')} and {end_date.strftime('%Y-%m')}") 166 | else: 167 | logger.info(f"Finding all URLs.") 168 | start_date = None 169 | end_date = None 170 | 171 | urls_directory = os.path.join(args.output_directory, "urls") 172 | 173 | logger.info(f"Urls output directory: {urls_directory}") 174 | logger.info(f"Minimum score: {args.min_score}") 175 | logger.info(f"URLs per file: {args.urls_per_file}") 176 | 177 | if args.data_source == "db": 178 | source = partial(get_from_db, start_date, end_date) 179 | elif args.data_source == "tsv": 180 | source = get_from_tsv 181 | else: 182 | logger.info(f"Invalid source {args.data_source}") 183 | sys.exit(-1) 184 | 185 | logger.info(f"Data source: {args.data_source}") 186 | 187 | generate_urls(urls_directory, args.urls_per_file, args.min_score, source) 188 | 189 | -------------------------------------------------------------------------------- /pushshift/models.py: -------------------------------------------------------------------------------- 1 | """ 2 | SqlAlchemy model for RedditSubmission along with some database helper functions. 3 | To populate the database see pushshift_to_sqlite.py. 
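Example usage (an illustrative sketch only — it assumes alembic.ini has been copied
from alembic.ini.template and that sqlalchemy.url points at a populated database):

    from pushshift.models import RedditSubmission, get_db_session

    session = get_db_session()
    total = session.query(RedditSubmission).count()
    top = (session.query(RedditSubmission)
                  .order_by(RedditSubmission.score.desc())
                  .first())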
4 | """ 5 | 6 | import os 7 | import configparser 8 | 9 | from sqlalchemy.ext.declarative import declarative_base 10 | from sqlalchemy import Column, Integer, Text, DateTime 11 | from sqlalchemy import create_engine 12 | from sqlalchemy.schema import Index 13 | from sqlalchemy.orm import sessionmaker 14 | 15 | base = declarative_base() 16 | 17 | class RedditSubmission(base): 18 | __tablename__ = "reddit_submission" 19 | id = Column(Integer, primary_key=True) # Converted from Base36 20 | url = Column(Text, index=True) 21 | score = Column(Integer) 22 | title = Column(Text) 23 | subreddit = Column(Text) 24 | created_utc = Column(DateTime, index=True) 25 | 26 | # Index('idx_url_created', RedditSubmission.url, RedditSubmission.created_utc) 27 | 28 | def get_db_session(): 29 | config = configparser.ConfigParser() 30 | config.read('alembic.ini') 31 | db_url = config["alembic"]["sqlalchemy.url"] 32 | 33 | fresh_db = False 34 | db_file_path = db_url.replace("sqlite:///","") 35 | if not os.path.exists(db_file_path): 36 | fresh_db = True 37 | 38 | engine = create_engine(db_url) 39 | if fresh_db: 40 | base.metadata.create_all(engine) 41 | 42 | Session = sessionmaker(bind=engine) 43 | db_session = Session() 44 | return db_session 45 | 46 | def recreate_db(): 47 | config = configparser.ConfigParser() 48 | config.read('alembic.ini') 49 | db_url = config["alembic"]["sqlalchemy.url"] 50 | 51 | db_file_path = db_url.replace("sqlite:///","") 52 | if os.path.exists(db_file_path): 53 | os.remove(db_file_path) # sqlite doesn't truncate on drop_all 54 | 55 | engine = create_engine(db_url) 56 | base.metadata.create_all(engine) 57 | 58 | if __name__ == '__main__': 59 | recreate_db() -------------------------------------------------------------------------------- /pushshift/process_dump_files_sqlite.py: -------------------------------------------------------------------------------- 1 | """ 2 | Called from pushshift_to_sqlite.py. 3 | 4 | Processes a PushShift submission dump file, storing the url and relevant Reddit metadata 5 | into the sqlite database specified in alembic.ini. Note the reddit submission 6 | id is converted from base36 first and the created_utc is stored as a datetime 7 | object. 8 | 9 | process_dump_file is the entry point, requiring you to specify 'dump_file_path' 10 | and 'output_directory'. Supports tqdm-multiprocess. 
11 | 12 | metadata = {} 13 | metadata["id"] = base36.loads(post["id"]) 14 | metadata["subreddit"] = post.get("subreddit") 15 | metadata["title"] = post.get("title") 16 | metadata["score"] = post.get("score") 17 | metadata["created_utc"] = datetime.datetime.fromtimestamp(int(post["created_utc"])) 18 | """ 19 | 20 | import json 21 | import os 22 | import math 23 | import datetime 24 | 25 | import base36 26 | from sqlalchemy import exc 27 | 28 | from .models import RedditSubmission 29 | from utils.archive_stream_readers import get_archive_stream_reader 30 | 31 | import logging 32 | logger = logging.getLogger() 33 | 34 | million = math.pow(10, 6) 35 | 36 | chunk_size = int(16 * million) # Must be larger then the size of any post 37 | 38 | def process_reddit_post(post): 39 | is_self = post.get("is_self") 40 | if is_self is None or is_self: 41 | return None 42 | 43 | url = post.get("url") 44 | if url is None or url == "": 45 | return None 46 | 47 | reddit_submission = RedditSubmission() 48 | reddit_submission.id = base36.loads(post["id"]) 49 | reddit_submission.subreddit = post.get("subreddit") 50 | reddit_submission.title = post.get("title") 51 | reddit_submission.score = post.get("score", 0) 52 | reddit_submission.created_utc = datetime.datetime.fromtimestamp(int(post["created_utc"])) 53 | reddit_submission.url = url 54 | 55 | return reddit_submission 56 | 57 | def process_dump_file(dump_file_path, db_session, tqdm_func): 58 | logging.info(f"Processing dump file '{dump_file_path}'") 59 | dump_file_size = os.path.getsize(dump_file_path) 60 | 61 | previous_file_position = 0 62 | count = 0 63 | insert_batch_size = 100000 64 | with get_archive_stream_reader(dump_file_path) as reader, \ 65 | tqdm_func(total=dump_file_size, unit="byte", unit_scale=1) as progress: 66 | 67 | progress.set_description(f"Processing {os.path.basename(dump_file_path)}") 68 | 69 | previous_line = "" 70 | while True: 71 | chunk = reader.read(chunk_size) 72 | if not chunk: 73 | break 74 | 75 | # Update Progress Bar 76 | current_file_position = reader.tell() 77 | progress.update(current_file_position - previous_file_position) 78 | previous_file_position = current_file_position 79 | 80 | # Process chunk + leftover, ignore possibly incomplete last line 81 | try: 82 | string_data = chunk.decode("utf-8") 83 | except UnicodeDecodeError as ex: 84 | logger.info(f"Error in position {current_file_position} in file {dump_file_path}") 85 | logger.info(ex) 86 | continue 87 | lines = string_data.split("\n") 88 | for i, line in enumerate(lines[:-1]): 89 | if i == 0: 90 | line = previous_line + line 91 | 92 | reddit_post = None 93 | try: 94 | reddit_post = json.loads(line) 95 | except Exception as ex: 96 | logger.info(f"JSON decoding failed: {ex}") 97 | continue 98 | 99 | reddit_submission = process_reddit_post(reddit_post) 100 | if reddit_submission: 101 | db_session.add(reddit_submission) 102 | count += 1 103 | 104 | if count == insert_batch_size: 105 | logging.info(f"Committing {count} records to db.") 106 | try: 107 | db_session.commit() 108 | except exc.IntegrityError: 109 | logger.info(f"Duplicate INSERT, ignoring.") 110 | db_session.rollback() 111 | count = 0 112 | 113 | previous_line = lines[-1] 114 | 115 | if count > 0: 116 | logging.info(f"Committing {count} records to db.") 117 | try: 118 | db_session.commit() 119 | except exc.IntegrityError: 120 | logger.info(f"Duplicate INSERT, ignoring.") 121 | db_session.rollback() 122 | count = 0 123 | 124 | logging.info("Done with file.") 125 | 
-------------------------------------------------------------------------------- /pushshift/pushshift_to_sqlite.py: -------------------------------------------------------------------------------- 1 | """ 2 | Builds a list of PushShift submission dump files located in "https://files.pushshift.io/reddit/submissions" 3 | within the desired date range, and then performs the following steps for each file. Note this can't be done 4 | with a multiprocessing pool due to locking issues with sqlite. 5 | 6 | 1. Download and verify the file using the available sha256 sums 7 | 2. Process the file, storing the url and relevant Reddit metadata into the sqlite database 8 | specified in alembic.ini (copy alembic.ini.template and set sqlalchemy.url). 9 | 3. Create a .dbdone file to mark the particular file as being processed, allowing script resume. 10 | 4. Delete the PushShift dump file to save storage space if --keep_dumps not specified. 11 | 12 | Arguments 13 | --------- 14 | --start_period (-s) 15 | Month and Year of first pushshift dump. Default: 6,2005 16 | --finish_period (-f) 17 | Month and Year of final pushshift dump. Defaults to current month, ignoring any missing months. 18 | --output_directory (-dir) 19 | Base directory that will contain the dumps subdirectory created as part of the process. 20 | --keep_dumps (-kd) 21 | If specified the dump won't be deleted after successful processing. 22 | """ 23 | 24 | import datetime 25 | import os 26 | import argparse 27 | import sys 28 | 29 | from best_download import download_file 30 | import cutie 31 | import tqdm 32 | 33 | from .download_pushshift_dumps import build_file_list, get_sha256sums 34 | from .process_dump_files_sqlite import process_dump_file 35 | from .models import get_db_session 36 | 37 | import logging 38 | from utils.logger import setup_logger_tqdm 39 | logger = logging.getLogger(__name__) 40 | 41 | def reddit_processing(url, sha256sums, dumps_directory, keep_dumps): 42 | 43 | base_name = url.split('/')[-1] 44 | dump_file_path = os.path.join(dumps_directory, base_name) 45 | db_done_file = dump_file_path + ".dbdone" 46 | 47 | if os.path.exists(db_done_file): 48 | return True 49 | 50 | try: 51 | download_file(url, dump_file_path, sha256sums.get(base_name)) 52 | except Exception as ex: 53 | logger.info(f"Download failed {ex}, skipping processing.") 54 | return False 55 | 56 | db_session = get_db_session() 57 | process_dump_file(dump_file_path, db_session, tqdm.tqdm) 58 | 59 | with open(db_done_file, "w") as fh: 60 | fh.write("Done!") 61 | 62 | if not keep_dumps: 63 | os.remove(dump_file_path) 64 | 65 | return True 66 | 67 | parser = argparse.ArgumentParser(description='Download PushShift submission dumps, extra urls') 68 | parser.add_argument("-s", "--start_period", default="6,2005") 69 | parser.add_argument("-f", "--finish_period", default=None) 70 | parser.add_argument("-dir", "--output_directory", default="") 71 | parser.add_argument("-kd", "--keep_dumps", action='store_true') 72 | 73 | # First available file: https://files.pushshift.io/reddit/submissions/RS_v2_2005-06.xz 74 | def main(): 75 | logfile_path = "download_pushshift_dumps.log" 76 | setup_logger_tqdm(logfile_path) # Logger will write messages using tqdm.write 77 | 78 | args = parser.parse_args() 79 | 80 | start_month, start_year = tuple(map(int,args.start_period.split(","))) 81 | start_date = datetime.datetime(start_year, start_month, 1) 82 | 83 | if args.finish_period: 84 | finish_month, finish_year = tuple(map(int,args.finish_period.split(","))) 85 | end_date = 
datetime.datetime(finish_year, finish_month, 1) 86 | else: 87 | end_date = datetime.datetime.now() 88 | 89 | logger.info("Running Script - PushShift submission dumps to sqlite") 90 | logger.info("Downloading and processing dumps in the following range:") 91 | logger.info(start_date.strftime("Start Period: %m-%Y")) 92 | logger.info(end_date.strftime("End Period: %m-%Y")) 93 | 94 | dumps_directory = os.path.join(args.output_directory, "dumps") 95 | 96 | if os.path.isdir(dumps_directory): 97 | message = f"Directory '{dumps_directory}' already exists, if there are done files" \ 98 | " in the directory then these particular months will be skipped. Delete" \ 99 | " these files or the directory to avoid this." 100 | logger.info(message) 101 | if not cutie.prompt_yes_or_no('Do you want to continue?'): 102 | sys.exit(0) 103 | 104 | os.makedirs(dumps_directory, exist_ok=True) 105 | 106 | logger.info("Building PushShift submission dump file list...") 107 | url_list = build_file_list(start_date, end_date) 108 | 109 | logger.info("Getting sha256sums") 110 | sha256sums = get_sha256sums() 111 | 112 | # Download and Process 113 | logger.info("Commencing download and processing into sqlite.") 114 | results = [] 115 | for url in url_list: 116 | result = reddit_processing(url, sha256sums, dumps_directory, args.keep_dumps) 117 | results.append(result) 118 | 119 | if __name__ == '__main__': 120 | main() 121 | 122 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | zstandard 2 | requests 3 | python-dateutil 4 | tqdm 5 | tldextract 6 | beautifulsoup4 7 | newspaper3k 8 | htmlmin 9 | lm_dataformat 10 | jsonlines 11 | datasketch 12 | colorama 13 | cutie 14 | sqlalchemy 15 | tqdm-multiprocess 16 | base36 17 | alembic 18 | cassandra-driver 19 | best-download 20 | 21 | mkdocs -------------------------------------------------------------------------------- /scraping/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/scraping/__init__.py -------------------------------------------------------------------------------- /scraping/filter.py: -------------------------------------------------------------------------------- 1 | import tldextract 2 | import re 3 | 4 | import logging 5 | logger = logging.getLogger("filelock") 6 | logger.setLevel(logging.WARNING) 7 | 8 | # https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not 9 | url_regex = re.compile( 10 | r'^(?:http)s?://' # http:// or https:// 11 | r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain... 12 | r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip 13 | r'(?::\d+)?' # optional port 14 | r'(?:/?|[/?]\S+)$', re.IGNORECASE) 15 | 16 | # domains that aren't scraper friendly. do not include subdomains! 
17 | exclude_domains = set([ 18 | # image & video hosting sites 19 | 'imgur.com', 20 | 'redd.it', 21 | 'instagram.com', 22 | 'discord.gg', 23 | 'gfycat.com', 24 | 'giphy.com', 25 | 'reddituploads.com', 26 | 'redditmedia.com', 27 | 'twimg.com', 28 | 'sli.mg', 29 | 'magaimg.net', 30 | 'flickr.com', 31 | 'imgflip.com', 32 | 'youtube.com', 33 | 'youtu.be', 34 | 'youtubedoubler.com', 35 | 'vimeo.com', 36 | 'twitch.tv', 37 | 'streamable.com', 38 | 'bandcamp.com', 39 | 'soundcloud.com', 40 | 41 | # not scraper friendly 42 | 'reddit.com', 43 | 'gyazo.com', 44 | 'github.com', 45 | 'xkcd.com', 46 | 'twitter.com', 47 | 'spotify.com', 48 | 'itunes.apple.com', 49 | 'facebook.com', 50 | 'gunprime.com', 51 | 'strawpoll.me', 52 | 'voyagefusion.com', 53 | 'rollingstone.com', 54 | 'google.com', 55 | 'timeanddate.com', 56 | 'walmart.com', 57 | 'roanoke.com', 58 | 'spotrac.com', 59 | 60 | # original paper excluded wikipedia 61 | 'wikipedia.org', 62 | 63 | # lots of top posts for this one 64 | 'battleforthenet.com', 65 | ]) 66 | 67 | exclude_extensions = ( 68 | '.png', 69 | '.jpg', 70 | '.jpeg', 71 | '.gif', 72 | '.gifv', 73 | '.pdf', 74 | '.mp4', 75 | '.mp3', 76 | '.ogv', 77 | '.webm', 78 | '.doc', 79 | '.docx', 80 | '.log', 81 | '.csv', 82 | '.dat', 83 | '.iso', 84 | '.bin', 85 | '.exe', 86 | '.apk', 87 | '.jar', 88 | '.app', 89 | '.ppt', 90 | '.pps', 91 | '.pptx', 92 | '.xml', 93 | '.gz', 94 | '.xz', 95 | '.bz2', 96 | '.tgz', 97 | '.tar', 98 | '.zip', 99 | '.wma', 100 | '.mov', 101 | '.wmv', 102 | '.3gp', 103 | '.svg', 104 | '.rar', 105 | '.wav', 106 | '.avi', 107 | '.7z' 108 | ) 109 | 110 | def should_exclude(url): 111 | 112 | ext = tldextract.extract(url) 113 | domain = '.'.join([x for x in ext if x]) 114 | basedomain = '.'.join(ext[-2:]) 115 | 116 | # Ignore non-URLs 117 | if len(url) <= 8 or ' ' in url or re.match(url_regex, url) is None: 118 | return True 119 | 120 | # Ignore excluded domains 121 | if basedomain in exclude_domains or domain in exclude_domains: 122 | return True 123 | 124 | # Ignore case-insensitive matches for excluded extensions 125 | if url.lower().split('?')[0].endswith(exclude_extensions): 126 | return True 127 | 128 | return False -------------------------------------------------------------------------------- /scraping/scrape_urls.py: -------------------------------------------------------------------------------- 1 | """ 2 | This program iterates through URL files generated in step 2 above. For each file its hands out the URLs 3 | to a multiprocessing pool for scraping. Once all URLs in the batch are scraped, the successful results are 4 | archived using a slightly modified version of [lm_dataformat](https://github.com/leogao2/lm_dataformat) 5 | (thanks @bmk). The following metadata fields are saved in the metadata dict offered by lm_dataformat: 6 | 7 | title: Web Page Title 8 | lang: Language detected by Newspaper scraper. 9 | url: Original URL. 10 | word_count: Total words outputted by Newspaper. 11 | elapsed: Scraping time. 12 | scraper: Always "newspaper". 13 | domain: Top level domain for the original URL. 14 | reddit_id: List of submission IDs containing URL - converted from base36. 15 | subreddit: List of subreddits for the corresponding submissions. 16 | reddit_score: List of reddit scores for the corresponding submissions. 17 | reddit_title: List of submissions titles for the corresponding submissions. 18 | reddit_created_utc: List of submissions created times for the corresponding submissions. 
19 | 20 | Arguments 21 | --------- 22 | --job_directory (-dir) 23 | Base directory containing the urls subdirectory and location where the scrapes subdirectory 24 | will be created. 25 | --process_count (-procs) 26 | Number of worker processes in the pool. Defaults to 60. Don't go above this on Windows. 27 | --request_timeout (-timeout) 28 | Scraping timeout for each URL. Defaults to 30 seconds. 29 | """ 30 | 31 | import os 32 | import sys 33 | import glob 34 | import json 35 | import argparse 36 | 37 | import tqdm 38 | from tqdm_multiprocess import TqdmMultiProcessPool 39 | 40 | from scraping.scrapers import newspaper_scraper 41 | from utils.archiver import Reader, Archive 42 | from utils.utils import Timer 43 | 44 | import logging 45 | from utils.logger import setup_logger_tqdm 46 | logger = logging.getLogger(__name__) 47 | 48 | # Multiprocessed 49 | def download(url_entry, request_timeout, scraper, 50 | memoize, tqdm_func, global_tqdm): 51 | 52 | url, reddit_meta = url_entry 53 | text, meta, success = scraper(url, memoize, request_timeout=request_timeout) 54 | 55 | if not success or text is None or text.strip() == "": 56 | if global_tqdm: 57 | global_tqdm.update() 58 | return (text, meta, False) 59 | 60 | # Add extra meta 61 | meta["reddit_id"] = reddit_meta["id"] 62 | meta["subreddit"] = reddit_meta["subreddit"] 63 | meta["reddit_score"] = reddit_meta["score"] 64 | meta["reddit_title"] = reddit_meta["title"] 65 | meta["reddit_created_utc"] = reddit_meta["created_utc"] 66 | 67 | if global_tqdm: 68 | global_tqdm.update() 69 | 70 | return (text, meta, success) 71 | 72 | def scrape_urls(urls_directory, scrapes_directory, process_count, request_timeout): 73 | 74 | url_files = glob.glob(os.path.join(urls_directory, "urls_*.jsonl.zst")) 75 | 76 | # Get Total URL count 77 | url_count_path = os.path.join(urls_directory, "url_count.json") 78 | if os.path.exists(url_count_path): 79 | total_url_count = json.load(open(url_count_path, "r")) 80 | else: 81 | logger.info("Getting total URL count...") 82 | total_url_count = 0 83 | for url_file_path in tqdm.tqdm(url_files, dynamic_ncols=True): 84 | reader = Reader() 85 | url_data = [] 86 | for url in reader.read_jsonl(url_file_path, get_meta=False): 87 | total_url_count += 1 88 | json.dump(total_url_count, open(url_count_path, "w")) 89 | 90 | # overall progress bar 91 | progress = tqdm.tqdm(total=total_url_count, dynamic_ncols=True) 92 | progress.set_description("Total URLs") 93 | 94 | for url_file_path in url_files: 95 | # Skip if previously done 96 | done_file_path = url_file_path + ".done" 97 | if os.path.exists(done_file_path): 98 | batch_url_count = json.load(open(done_file_path, "r")) 99 | progress.update(batch_url_count) 100 | logger.info(f"'{os.path.basename(url_file_path)}' already scraped, skipping.") 101 | continue 102 | 103 | logger.info(f"Scraping URLs from '{os.path.basename(url_file_path)}'.") 104 | 105 | reader = Reader() 106 | url_data = [] 107 | for url, reddit_meta in reader.read_jsonl(url_file_path, get_meta=True): 108 | url_data.append((url, reddit_meta)) 109 | 110 | timer = Timer().start() 111 | 112 | # Download and Process With Pool 113 | pool = TqdmMultiProcessPool(process_count) 114 | tasks = [] 115 | for url_entry in url_data: 116 | arguments = (url_entry, request_timeout, newspaper_scraper, False) 117 | task = (download, arguments) 118 | tasks.append(task) 119 | 120 | # tqdm-multiprocess doesn't support multiple global tqdms, use on_done as well 121 | on_done = lambda _ : progress.update() 122 | on_error = lambda _ : None 123 | 
124 | with tqdm.tqdm(total=len(url_data), dynamic_ncols=True) as batch_progress: 125 | batch_progress.set_description(f"{os.path.basename(url_file_path)}") 126 | results = pool.map(batch_progress, tasks, on_error, on_done) 127 | 128 | logger.info("Archiving chunk with lm_dataformat...") 129 | # urls_*.jsonl.zst -> scrapes_*.jsonl.zst 130 | output_archive_name = os.path.basename(url_file_path).replace("urls", "scrapes") 131 | output_archive_path = os.path.join(scrapes_directory, output_archive_name) 132 | archiver = Archive(output_archive_path) 133 | batch_error_count = 0 134 | for text, meta, status in results: 135 | if not status: 136 | batch_error_count += 1 137 | else: 138 | archiver.add_data(text, meta) 139 | archiver.commit() 140 | 141 | error_percentage = batch_error_count / len(url_data) * 100 142 | logger.info(f"Errors: {batch_error_count} / {len(url_data)} ({error_percentage:0.2f}%)") 143 | logger.info(f"Batch time: {timer.stop():0.2f} seconds") 144 | 145 | json.dump(len(url_data), open(done_file_path, "w")) 146 | 147 | progress.close() 148 | logger.info("Done!") 149 | 150 | parser_description = 'Scrape urls extracted from Reddit.' 151 | parser = argparse.ArgumentParser(description=parser_description) 152 | parser.add_argument("-dir", "--job_directory", default="") 153 | parser.add_argument("-procs", "--process_count", type=int, default=60) 154 | parser.add_argument("-timeout", "--request_timeout", type=int, default=30) 155 | 156 | if __name__ == "__main__": 157 | logfile_name = "scrape_urls.log" 158 | setup_logger_tqdm(logfile_name) 159 | 160 | args = parser.parse_args() 161 | 162 | urls_directory = os.path.join(args.job_directory, "urls") 163 | if not os.path.exists(urls_directory): 164 | logger.info(f"No 'urls' directory found in '{args.job_directory}', aborting") 165 | sys.exit(0) 166 | 167 | scrapes_directory = os.path.join(args.job_directory, "scrapes") 168 | os.makedirs(scrapes_directory, exist_ok=True) 169 | 170 | logger.info(f"Scrapes outputting to: '{scrapes_directory}'") 171 | 172 | scrape_urls(urls_directory, scrapes_directory, args.process_count, args.request_timeout) 173 | -------------------------------------------------------------------------------- /scraping/scrapers.py: -------------------------------------------------------------------------------- 1 | # Code taken in large part from https://github.com/jcpeterson/openwebtext 2 | 3 | import time 4 | import unicodedata 5 | 6 | import bs4 7 | import newspaper 8 | 9 | from lxml.html.clean import Cleaner 10 | from htmlmin import minify 11 | from scraping.filter import should_exclude 12 | 13 | 14 | def find_and_filter_tag(tag, soup): 15 | """tag specific filter logic""" 16 | 17 | candidates = soup.find_all(tag) 18 | candidates = [ 19 | unicodedata.normalize("NFKD", x.string) 20 | for x in candidates 21 | if x.string is not None 22 | ] 23 | 24 | if tag == "p": 25 | candidates = [y.strip() for y in candidates if len(y.split(" ")) >= 4] 26 | count = sum(len(y.split(" ")) for y in candidates) 27 | else: 28 | raise NotImplementedError 29 | 30 | return (candidates, count) 31 | 32 | 33 | def raw_scraper(url, memoize): 34 | t1 = time.time() 35 | if should_exclude(url): 36 | # heuristic to make downloading faster 37 | return None, { 38 | "url": url, 39 | "scraper": "raw", 40 | } 41 | 42 | try: 43 | cleaner = Cleaner() 44 | cleaner.javascript = True 45 | cleaner.style = True 46 | article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize) 47 | article.download() 48 | html = minify(article.html) 49 | html = 
cleaner.clean_html(html) 50 | article.parse() 51 | except: 52 | return None, { 53 | "url": url, 54 | "scraper": "raw", 55 | } 56 | if article.text == "": 57 | return None, { 58 | "url": url, 59 | "scraper": "raw", 60 | } 61 | 62 | metadata = {"url": url, "elapsed": time.time() - t1, "scraper": "raw"} 63 | return html, metadata 64 | 65 | 66 | def newspaper_scraper(url, memoize, request_timeout): 67 | t1 = time.time() 68 | 69 | if should_exclude(url): 70 | # heuristic to make downloading faster 71 | return None, { 72 | "url": url, 73 | "scraper": "newspaper", 74 | }, False 75 | 76 | try: 77 | article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize, 78 | request_timeout=request_timeout) 79 | article.download() 80 | article.parse() 81 | except Exception as ex: 82 | return None, ex, False 83 | 84 | #print(article.__dict__.keys()) 85 | #dict_keys(['config', 'extractor', 'source_url', 'url', 'title', 'top_img', 'top_image', 'meta_img', 'imgs', 'images', 86 | # 'movies', 'text', 'keywords', 'meta_keywords', 'tags', 'authors', 'publish_date', 'summary', 'html', 'article_html', 87 | # 'is_parsed', 'download_state', 'download_exception_msg', 'meta_description', 'meta_lang', 'meta_favicon', 'meta_data', 88 | # 'canonical_link', 'top_node', 'clean_top_node', 'doc', 'clean_doc', 'additional_data', 'link_hash']) 89 | 90 | #print(article.title) 91 | # print(article.meta_lang) 92 | # if article.meta_lang != "en" and article.meta_lang: 93 | # print(article.text) 94 | 95 | text = article.text 96 | count = len(text.split()) 97 | 98 | metadata = { 99 | "title": article.title, 100 | "lang": article.meta_lang, 101 | "url": url, 102 | "word_count": count, 103 | "elapsed": time.time() - t1, 104 | "scraper": "newspaper", 105 | } 106 | return text, metadata, True 107 | 108 | def bs4_scraper(url, memoize): 109 | t1 = time.time() 110 | if should_exclude(url): 111 | # heuristic to make downloading faster 112 | return None, { 113 | "url": url, 114 | "scraper": "bs4", 115 | } 116 | 117 | try: 118 | article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize) 119 | article.download() 120 | html = article.html 121 | soup = bs4.BeautifulSoup(html, "lxml") 122 | text, count = find_and_filter_tag("p", soup) 123 | # DDB: keep text as a single string for consistency with 124 | # newspaper_scraper 125 | text = " ".join(text) 126 | except: 127 | return None, { 128 | "url": url, 129 | "scraper": "bs4", 130 | } 131 | 132 | metadata = { 133 | "url": url, 134 | "word_count": count, 135 | "elapsed": time.time() - t1, 136 | "scraper": "bs4", 137 | } 138 | return text, metadata 139 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/utils/__init__.py -------------------------------------------------------------------------------- /utils/archive_stream_readers.py: -------------------------------------------------------------------------------- 1 | import lzma 2 | import zstandard as zstd 3 | import bz2 4 | 5 | # Lets you access the underyling file's tell function for updating pqdm 6 | class ArchiveStreamReader(object): 7 | def __init__(self, file_path, decompressor): 8 | self.file_path = file_path 9 | self.file_handle = None 10 | self.decompressor = decompressor 11 | self.stream_reader = None 12 | 13 | def __enter__(self): 14 | self.file_handle = open(self.file_path, 
'rb') 15 | self.stream_reader = self.decompressor(self.file_handle) 16 | return self 17 | 18 | def __exit__(self, exc_type, exc_value, traceback): 19 | self.stream_reader.close() 20 | self.file_handle.close() 21 | 22 | def tell(self): 23 | return self.file_handle.tell() 24 | 25 | def read(self, size): 26 | return self.stream_reader.read(size) 27 | 28 | def get_archive_stream_reader(file_path): 29 | extension = file_path.split(".")[-1] 30 | 31 | if extension == "zst": 32 | return ArchiveStreamReader(file_path, zstd.ZstdDecompressor().stream_reader) 33 | elif extension == "bz2": 34 | return ArchiveStreamReader(file_path, bz2.BZ2File) 35 | elif extension == "xz": 36 | return ArchiveStreamReader(file_path, lzma.open) -------------------------------------------------------------------------------- /utils/archiver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zstandard 3 | import json 4 | import jsonlines 5 | import io 6 | import datetime 7 | 8 | def json_serial(obj): 9 | """JSON serializer for objects not serializable by default json code""" 10 | 11 | if isinstance(obj, (datetime.datetime,)): 12 | return obj.isoformat() 13 | raise TypeError ("Type %s not serializable" % type(obj)) 14 | 15 | # Modified version of lm_dataformat Archive for single file. 16 | class Archive: 17 | def __init__(self, file_path, compression_level=3): 18 | self.file_path = file_path 19 | dir_name = os.path.dirname(file_path) 20 | if dir_name: 21 | os.makedirs(dir_name, exist_ok=True) 22 | self.fh = open(self.file_path, 'wb') 23 | self.cctx = zstandard.ZstdCompressor(level=compression_level) 24 | self.compressor = self.cctx.stream_writer(self.fh) 25 | 26 | def add_data(self, data, meta={}): 27 | self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n') 28 | 29 | def commit(self): 30 | self.compressor.flush(zstandard.FLUSH_FRAME) 31 | self.fh.flush() 32 | self.fh.close() 33 | 34 | # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm. 35 | class Reader: 36 | def __init__(self): 37 | pass 38 | 39 | def read_jsonl(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'): 40 | with open(file, 'rb') as fh: 41 | self.fh = fh 42 | cctx = zstandard.ZstdDecompressor() 43 | reader = io.BufferedReader(cctx.stream_reader(fh)) 44 | rdr = jsonlines.Reader(reader) 45 | for ob in rdr: 46 | # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility. 
47 | if isinstance(ob, str): 48 | assert not get_meta 49 | yield ob 50 | continue 51 | 52 | text = ob['text'] 53 | 54 | if autojoin_paragraphs and isinstance(text, list): 55 | text = para_joiner.join(text) 56 | 57 | if get_meta: 58 | yield text, (ob['meta'] if 'meta' in ob else {}) 59 | else: 60 | yield text -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from datetime import timedelta 4 | from tqdm import tqdm 5 | 6 | class LogFormatter(): 7 | 8 | def __init__(self): 9 | self.start_time = time.time() 10 | 11 | def format(self, record): 12 | elapsed_seconds = round(record.created - self.start_time) 13 | 14 | prefix = "%s - %s - %s" % ( 15 | record.levelname, 16 | time.strftime('%x %X'), 17 | timedelta(seconds=elapsed_seconds) 18 | ) 19 | message = record.getMessage() 20 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3)) 21 | return "%s - %s" % (prefix, message) if message else '' 22 | 23 | def reset_time(self): 24 | self.start_time = time.time() 25 | 26 | def setup_logger(filepath=None, to_console=True, formatter=LogFormatter()): 27 | 28 | # create logger 29 | logger = logging.getLogger() 30 | logger.setLevel(logging.DEBUG) 31 | logger.propagate = False 32 | 33 | logger.handlers = [] 34 | 35 | # create file handler 36 | if filepath is not None: 37 | file_handler = logging.FileHandler(filepath, "a") 38 | file_handler.setLevel(logging.DEBUG) 39 | file_handler.setFormatter(formatter) 40 | logger.addHandler(file_handler) 41 | 42 | # create console handler 43 | if to_console: 44 | console_handler = logging.StreamHandler() 45 | console_handler.setLevel(logging.INFO) 46 | console_handler.setFormatter(formatter) 47 | logger.addHandler(console_handler) 48 | 49 | class ChildProcessHandler(logging.StreamHandler): 50 | def __init__(self, message_queue): 51 | self.message_queue = message_queue 52 | logging.StreamHandler.__init__(self) 53 | 54 | def emit(self, record): 55 | self.message_queue.put(record) 56 | 57 | def setup_logger_child_process(message_queue): 58 | # create logger 59 | logger = logging.getLogger() 60 | logger.setLevel(logging.DEBUG) 61 | logger.propagate = False 62 | 63 | logger.handlers = [] 64 | 65 | # create queue handler 66 | child_process_handler = ChildProcessHandler(message_queue) 67 | child_process_handler.setLevel(logging.INFO) 68 | logger.addHandler(child_process_handler) 69 | 70 | class TqdmHandler(logging.StreamHandler): 71 | def __init__(self): 72 | logging.StreamHandler.__init__(self) 73 | 74 | def emit(self, record): 75 | msg = self.format(record) 76 | tqdm.write(msg) 77 | 78 | def setup_logger_tqdm(filepath=None, formatter=LogFormatter()): 79 | 80 | # create logger 81 | logger = logging.getLogger() 82 | logger.setLevel(logging.DEBUG) 83 | logger.propagate = False 84 | 85 | logger.handlers = [] 86 | 87 | # create file handler 88 | if filepath is not None: 89 | file_handler = logging.FileHandler(filepath, "a") 90 | file_handler.setLevel(logging.DEBUG) 91 | file_handler.setFormatter(formatter) 92 | logger.addHandler(file_handler) 93 | 94 | # create tqdm handler 95 | tqdm_handler = TqdmHandler() 96 | tqdm_handler.setLevel(logging.INFO) 97 | tqdm_handler.setFormatter(formatter) 98 | logger.addHandler(tqdm_handler) -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 
import collections 2 | import os 3 | import time 4 | import pickle 5 | 6 | import logging 7 | logger = logging.getLogger(__name__) 8 | 9 | def timed_pickle_load(file_name, pickle_description): 10 | logger.info(f"Unpickling {pickle_description}...") 11 | timer = Timer().start() 12 | unpickled = pickle.load(open(file_name, "rb")) 13 | logger.info(timer.stop_string()) 14 | return unpickled 15 | 16 | def timed_pickle_dump(the_object, file_name, pickle_description): 17 | logger.info(f"Pickling {pickle_description}...") 18 | timer = Timer().start() 19 | pickle.dump(the_object, open(file_name, "wb")) 20 | logger.info(timer.stop_string()) 21 | 22 | class TimerError(Exception): 23 | """A custom exception used to report errors in use of Timer class""" 24 | 25 | class Timer: 26 | def __init__(self): 27 | self._start_time = None 28 | 29 | def start(self): 30 | """Start a new timer""" 31 | if self._start_time is not None: 32 | raise TimerError(f"Timer is running. Use .stop() to stop it") 33 | 34 | self._start_time = time.perf_counter() 35 | 36 | return self 37 | 38 | def stop(self): 39 | """Stop the timer, and report the elapsed time""" 40 | if self._start_time is None: 41 | raise TimerError(f"Timer is not running. Use .start() to start it") 42 | 43 | elapsed_time = time.perf_counter() - self._start_time 44 | self._start_time = None 45 | return elapsed_time 46 | 47 | def stop_string(self): 48 | elapsed = self.stop() 49 | return f"Took {elapsed:0.2f}s" 50 | 51 | def linecount(filename): 52 | f = open(filename, 'rb') 53 | lines = 0 54 | buf_size = 1024 * 1024 55 | read_f = f.raw.read 56 | 57 | buf = read_f(buf_size) 58 | while buf: 59 | lines += buf.count(b'\n') 60 | buf = read_f(buf_size) 61 | 62 | return lines 63 | 64 | def chunker(l, n, s=0): 65 | """Yield successive n-sized chunks from l, skipping the first s chunks.""" 66 | if isinstance(l, collections.Iterable): 67 | chnk = [] 68 | for i, elem in enumerate(l): 69 | if i < s: 70 | continue 71 | 72 | chnk.append(elem) 73 | if len(chnk) == n: 74 | yield chnk 75 | chnk = [] 76 | if len(chnk) != 0: 77 | yield chnk 78 | 79 | else: 80 | for i in range(s, len(l), n): 81 | yield l[i : i + n] 82 | 83 | def mkdir(fp): 84 | try: 85 | os.makedirs(fp) 86 | except FileExistsError: 87 | pass 88 | return fp --------------------------------------------------------------------------------