├── .gitattributes
├── .gitignore
├── .readthedocs.yml
├── LICENSE
├── README.md
├── alembic.ini.template
├── alembic
│   ├── README
│   ├── env.py
│   └── script.py.mako
├── cleaning
│   ├── __init__.py
│   ├── dedupe_from_indexes.py
│   ├── filter_from_reddit_scores.py
│   ├── generate_minhashes.py
│   ├── minhash_lsh_batching.py
│   └── minhash_lsh_dedupe.py
├── data_analysis
│   ├── __init__.py
│   └── final_stats.py
├── mkdocs
│   ├── docs
│   │   ├── background.md
│   │   ├── css
│   │   │   └── extra.css
│   │   ├── index.md
│   │   ├── licence.md
│   │   └── replication.md
│   └── mkdocs.yml
├── pushshift
│   ├── __init__.py
│   ├── download_pushshift_dumps.py
│   ├── generate_urls.py
│   ├── models.py
│   ├── process_dump_files_sqlite.py
│   └── pushshift_to_sqlite.py
├── requirements.txt
├── scraping
│   ├── __init__.py
│   ├── filter.py
│   ├── scrape_urls.py
│   └── scrapers.py
└── utils
    ├── __init__.py
    ├── archive_stream_readers.py
    ├── archiver.py
    ├── logger.py
    └── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # SageMath parsed files
98 | *.sage.py
99 |
100 | # Environments
101 | .env
102 | .venv
103 | env/
104 | venv/
105 | ENV/
106 | env.bak/
107 | venv.bak/
108 |
109 | # Spyder project settings
110 | .spyderproject
111 | .spyproject
112 |
113 | # Rope project settings
114 | .ropeproject
115 |
116 | # mkdocs documentation
117 | /site
118 |
119 | # mypy
120 | .mypy_cache/
121 | .dmypy.json
122 | dmypy.json
123 |
124 | # Pyre type checker
125 | .pyre/
126 | .vs*
127 | pushshift_dump_processing.sln
128 | pushshift_dump_processing.pyproj
129 | Untitled.ipynb
130 | playing_around*
131 | process_dump_files.txt
132 | url_count.pkl
133 | url_dedupe.txt
134 | duplication_stats.json
135 | csv/score_over_time.csv
136 | alembic.ini
137 | migration/find_missing_metadata_nplus1.py
138 | testing_archive.jsonl.zst
139 | migration*
140 | testa.py
141 | webtext2_colab_old.ipynb
142 | possibly_useful*
143 | data_analysis/document_count_by_stage.py
144 | data_analysis/score_over_time.py
145 | data_analysis/score_over_time_scrapes.py
146 | test_db.py
147 | cleaning/aggregate_archives_by_month.py
148 | mkdocs/site/*
149 |
--------------------------------------------------------------------------------
/.readthedocs.yml:
--------------------------------------------------------------------------------
1 | # .readthedocs.yml
2 | # Read the Docs configuration file
3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
4 |
5 | version: 2
6 |
7 | mkdocs:
8 | configuration: mkdocs/mkdocs.yml
9 | fail_on_warning: false
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 EleutherAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenWebText2
2 |
3 | This project is part of EleutherAI's quest to create a massive repository of high quality text data for training language models.
4 |
5 | Very briefly, OpenWebText2 is a large filtered dataset of text documents scraped from URLs found in Reddit submissions.
6 |
7 | The plug and play version of OpenWebText2 contains:
8 | - 17,103,059 documents
9 | - 65.86GB uncompressed text
10 |
11 | ## Download Dataset / Documentation
12 |
13 | For further information please visit our [documentation](https://openwebtext2.readthedocs.io/en/latest/).
14 |
15 | ## Acknowledgements
16 | [researcher2](https://github.com/researcher2) wrote much of this code, with inspiration and some straight copying of the scraping code found [here](https://github.com/yet-another-account/openwebtext/).
17 | [sdtblck](https://github.com/sdtblck/) kindly put together the Colab notebook, and performed a chunk of the scraping.
18 | [leogao2](https://github.com/leogao2/) provided overall design guidance, lm_dataformat, and performed another chunk of scraping.
19 | [Colaboratory](https://colab.research.google.com/) VMs helped us with about 10% of our overall scraping.
20 | [The Eye](http://the-eye.eu/) hosts our processed datasets.
21 | [Read The Docs](https://readthedocs.org/) hosts our documentation.
22 |
--------------------------------------------------------------------------------
/alembic.ini.template:
--------------------------------------------------------------------------------
1 | # A generic, single database configuration.
2 |
3 | [alembic]
4 | # path to migration scripts
5 | script_location = alembic
6 |
7 | # template used to generate migration files
8 | # file_template = %%(rev)s_%%(slug)s
9 |
10 | # timezone to use when rendering the date
11 | # within the migration file as well as the filename.
12 | # string value is passed to dateutil.tz.gettz()
13 | # leave blank for localtime
14 | # timezone =
15 |
16 | # max length of characters to apply to the
17 | # "slug" field
18 | # truncate_slug_length = 40
19 |
20 | # set to 'true' to run the environment during
21 | # the 'revision' command, regardless of autogenerate
22 | # revision_environment = false
23 |
24 | # set to 'true' to allow .pyc and .pyo files without
25 | # a source .py file to be detected as revisions in the
26 | # versions/ directory
27 | # sourceless = false
28 |
29 | # version location specification; this defaults
30 | # to alembic/versions. When using multiple version
31 | # directories, initial revisions must be specified with --version-path
32 | # version_locations = %(here)s/bar %(here)s/bat alembic/versions
33 |
34 | # the output encoding used when revision files
35 | # are written from script.py.mako
36 | # output_encoding = utf-8
37 |
38 | sqlalchemy.url = sqlite:///e:/Eleuther_AI/webtext2/dumps/submissions.sqlite
39 |
40 | [post_write_hooks]
41 | # post_write_hooks defines scripts or Python functions that are run
42 | # on newly generated revision scripts. See the documentation for further
43 | # detail and examples
44 |
45 | # format using "black" - use the console_scripts runner, against the "black" entrypoint
46 | # hooks=black
47 | # black.type=console_scripts
48 | # black.entrypoint=black
49 | # black.options=-l 79
50 |
51 | # Logging configuration
52 | [loggers]
53 | keys = root,sqlalchemy,alembic
54 |
55 | [handlers]
56 | keys = console
57 |
58 | [formatters]
59 | keys = generic
60 |
61 | [logger_root]
62 | level = WARN
63 | handlers = console
64 | qualname =
65 |
66 | [logger_sqlalchemy]
67 | level = WARN
68 | handlers =
69 | qualname = sqlalchemy.engine
70 |
71 | [logger_alembic]
72 | level = INFO
73 | handlers =
74 | qualname = alembic
75 |
76 | [handler_console]
77 | class = StreamHandler
78 | args = (sys.stderr,)
79 | level = NOTSET
80 | formatter = generic
81 |
82 | [formatter_generic]
83 | format = %(levelname)-5.5s [%(name)s] %(message)s
84 | datefmt = %H:%M:%S
85 |
--------------------------------------------------------------------------------
/alembic/README:
--------------------------------------------------------------------------------
1 | Generic single-database configuration.
--------------------------------------------------------------------------------
/alembic/env.py:
--------------------------------------------------------------------------------
1 | from logging.config import fileConfig
2 |
3 | from sqlalchemy import engine_from_config
4 | from sqlalchemy import pool
5 |
6 | from alembic import context
7 |
8 | # this is the Alembic Config object, which provides
9 | # access to the values within the .ini file in use.
10 | config = context.config
11 |
12 | # Interpret the config file for Python logging.
13 | # This line sets up loggers basically.
14 | fileConfig(config.config_file_name)
15 |
16 | # add your model's MetaData object here
17 | # for 'autogenerate' support
18 | # from myapp import mymodel
19 | # target_metadata = mymodel.Base.metadata
20 | import sys
21 | sys.path = ['', '..'] + sys.path[1:] # ew ew ew ew ew
22 | from pushshift.models import base
23 | target_metadata = base.metadata
24 |
25 | # other values from the config, defined by the needs of env.py,
26 | # can be acquired:
27 | # my_important_option = config.get_main_option("my_important_option")
28 | # ... etc.
29 |
30 |
31 | def run_migrations_offline():
32 | """Run migrations in 'offline' mode.
33 |
34 | This configures the context with just a URL
35 | and not an Engine, though an Engine is acceptable
36 | here as well. By skipping the Engine creation
37 | we don't even need a DBAPI to be available.
38 |
39 | Calls to context.execute() here emit the given string to the
40 | script output.
41 |
42 | """
43 | url = config.get_main_option("sqlalchemy.url")
44 | context.configure(
45 | url=url,
46 | target_metadata=target_metadata,
47 | literal_binds=True,
48 | dialect_opts={"paramstyle": "named"},
49 | )
50 |
51 | with context.begin_transaction():
52 | context.run_migrations()
53 |
54 |
55 | def run_migrations_online():
56 | """Run migrations in 'online' mode.
57 |
58 | In this scenario we need to create an Engine
59 | and associate a connection with the context.
60 |
61 | """
62 | connectable = engine_from_config(
63 | config.get_section(config.config_ini_section),
64 | prefix="sqlalchemy.",
65 | poolclass=pool.NullPool,
66 | )
67 |
68 | with connectable.connect() as connection:
69 | context.configure(
70 | connection=connection, target_metadata=target_metadata
71 | )
72 |
73 | with context.begin_transaction():
74 | context.run_migrations()
75 |
76 |
77 | if context.is_offline_mode():
78 | run_migrations_offline()
79 | else:
80 | run_migrations_online()
81 |
--------------------------------------------------------------------------------
/alembic/script.py.mako:
--------------------------------------------------------------------------------
1 | """${message}
2 |
3 | Revision ID: ${up_revision}
4 | Revises: ${down_revision | comma,n}
5 | Create Date: ${create_date}
6 |
7 | """
8 | from alembic import op
9 | import sqlalchemy as sa
10 | ${imports if imports else ""}
11 |
12 | # revision identifiers, used by Alembic.
13 | revision = ${repr(up_revision)}
14 | down_revision = ${repr(down_revision)}
15 | branch_labels = ${repr(branch_labels)}
16 | depends_on = ${repr(depends_on)}
17 |
18 |
19 | def upgrade():
20 | ${upgrades if upgrades else "pass"}
21 |
22 |
23 | def downgrade():
24 | ${downgrades if downgrades else "pass"}
25 |
--------------------------------------------------------------------------------
/cleaning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/cleaning/__init__.py
--------------------------------------------------------------------------------
/cleaning/dedupe_from_indexes.py:
--------------------------------------------------------------------------------
1 | """
2 | This script builds a list of all duplicates by file_id & document_id, and then iterates
3 | through all ".minscored" files from the filename lookup, creating a new archive for each
4 | file in the original containing all documents that were not marked as duplicates during
5 | the previous step.
6 |
7 | So for each original file, a "_final.jsonl.zst" file will be output in the original
8 | directory.
9 |
10 | Arguments
11 | ------
12 | --batch_directory (-dir)
13 | Directory containing the "*duplicates.txt" files along with the "file_name_lookup.pkl"
14 | created during batch slicing. The "_final.jsonl.zst" files will be output in their
15 | original directories.
16 | """
17 |
18 | import glob
19 | import os
20 | import pickle
21 | import argparse
22 |
23 | import tqdm
24 |
25 | from utils.archiver import Archive, Reader
26 |
27 | import logging
28 | from utils.logger import setup_logger_tqdm
29 | logger = logging.getLogger(__name__)
30 |
31 | def main(batch_directory):
32 | file_name_lookup_path = os.path.join(batch_directory, "file_name_lookup.pkl")
33 | file_name_lookup = pickle.load(open(file_name_lookup_path,"rb"))
34 |
35 | logger.info("Building duplicates dictionary...")
36 | duplicates_dict = {file_id : set() for file_id in range(len(file_name_lookup))}
37 | duplicate_files = glob.glob(os.path.join(batch_directory, "*_duplicates.txt"))
38 | for duplicate_file in duplicate_files:
39 | with open(duplicate_file, "r") as fh:
40 | duplicates = fh.read().splitlines()
41 | for duplicate in duplicates:
42 | file_id, document_id = tuple(map(int, duplicate.split(" ")))
43 | duplicates_dict[file_id].add(document_id)
44 |
45 | logger.info("De-duplicating files...")
46 | for file_id, original_file_name in enumerate(tqdm.tqdm(file_name_lookup)):
47 | final_file_name = original_file_name.replace("_default.jsonl.zst.deduped.merged.minscored",
48 | "_final.jsonl.zst")
49 |
50 | reader = Reader()
51 | count = 0
52 | archiver = Archive(final_file_name)
53 | for document, metadata in reader.read_jsonl(original_file_name, get_meta=True):
54 | if count not in duplicates_dict[file_id]:
55 | archiver.add_data(document, metadata)
56 | count += 1
57 | archiver.commit()
58 |
59 | parser = argparse.ArgumentParser(description='Dedupe from provided indexes.')
60 | parser.add_argument("-dir", "--batch_directory", default="")
61 |
62 | if __name__ == '__main__':
63 | logfile_path = "dedupe_from_index.log"
64 | setup_logger_tqdm(logfile_path)
65 |
66 | args = parser.parse_args()
67 | main(args.batch_directory)
--------------------------------------------------------------------------------
/cleaning/filter_from_reddit_scores.py:
--------------------------------------------------------------------------------
1 | """
2 | This script filters all scrape files "scrapes_*.jsonl.zst" by minimum total Reddit score.
3 | Unlike the original WebText we aggregate the scores of all submissions containing a given
4 | URL, so in some cases the bar is slightly lower, while in others a URL that scored negatively
5 | in one submission can be balanced out by its other submissions.
6 |
7 | The filtered scrapes file will have the original name and path of the scrape file with a
8 | ".minscored" extension.
9 |
10 | Arguments
11 | ---------
12 | --scrape_directory (-dir)
13 | Directory containing the scrapes. You could use the overall work directory if you
14 | want as we use glob.glob to search recursively.
15 | """
16 |
17 | import argparse
18 | import glob
19 | import os
20 | import sys
21 | import math
22 | from functools import reduce
23 | from operator import add
24 |
25 | import tqdm
26 | from tqdm_multiprocess import TqdmMultiProcessPool
27 |
28 | from utils.archiver import Reader, Archive
29 |
30 | import logging
31 | from utils.logger import setup_logger_tqdm
32 | logger = logging.getLogger(__name__)
33 |
34 | million = math.pow(10, 6)
35 |
36 | # Multiprocessed
37 | def process_file(file_path, tqdm_func, global_tqdm):
38 | reader = Reader()
39 |
40 | filtered_archive_path = file_path + ".minscored"
41 | archiver = Archive(filtered_archive_path)
42 |
43 | for document, metadata in reader.read_jsonl(file_path, get_meta=True):
44 | total_score = reduce(add, metadata["reddit_scores"])
45 | if total_score >= 3:
46 | archiver.add_data(document, metadata)
47 |
48 | global_tqdm.update(os.path.getsize(file_path))
49 | archiver.commit()
50 |
51 | def filter_from_reddit_scores(scrape_directory):
52 | files = glob.glob(os.path.join(scrape_directory, "**/scrapes_*.jsonl.zst"), recursive=True)
53 | total_file_size = reduce(add, map(os.path.getsize, files))
54 | logger.info(f"Total File Size: {(total_file_size / million):.2f} MB")
55 |
56 | # [(file_name, [doc0_minhash, doc1_minhash, ...]), ....]
57 | with tqdm.tqdm(total=total_file_size, dynamic_ncols=True, unit_scale=1) as progress:
58 | pool = TqdmMultiProcessPool()
59 | process_count = 4
60 | tasks = []
61 | for file_path in files:
62 | task = (process_file, (file_path,))
63 | tasks.append(task)
64 |
65 | on_done = lambda _ : None
66 | on_error = on_done
67 | result = pool.map(process_count, progress, tasks, on_error, on_done)
68 |
69 | return result
70 |
71 | parser_description = 'Filter scrapes based on minimum reddit scores.'
72 | parser = argparse.ArgumentParser(description=parser_description)
73 | parser.add_argument("-dir", "--scrape_directory", default="")
74 |
75 | if __name__ == '__main__':
76 | args = parser.parse_args()
77 | if not os.path.isdir(args.scrape_directory):
78 | print("Scrape directory doesn't exist, exiting.")
79 | sys.exit(0)
80 |
81 | log_file = "filter_from_reddit_scores.log"
82 | setup_logger_tqdm(log_file)
83 |
84 | logger.info("Filtering scrapes based on minimum reddit scores.")
85 | filter_from_reddit_scores(args.scrape_directory)
86 |
87 |
--------------------------------------------------------------------------------
/cleaning/generate_minhashes.py:
--------------------------------------------------------------------------------
1 | """
2 | This script calculates minhashes for all filtered scrape files found using a recursive
3 | search on "*.minscored".
4 |
5 | More explicitly, we create a set of 5-grams for each document, and generate document
6 | level minhashes using 10 hash functions with the excellent datasketch library.
7 |
8 | A single file "minhashes.pkl" is created in the scrape directory storing a data
9 | structure in the following format:
10 |
11 | [(file_name1, [doc0_minhash, doc1_minhash, ...]), (file_name2, [....]), ....]
12 |
13 | Arguments
14 | ---------
15 | --scrape_directory (-dir)
16 | Directory containing the minscored scrapes. You could use the overall work directory if you
17 | want as we use glob.glob to search recursively.
18 | --process_count (-procs)
19 | Number of worker processes in the pool. Defaults to 4.
20 | """
21 |
22 | import argparse
23 | import glob
24 | import os
25 | import sys
26 | import math
27 | from functools import reduce
28 | from operator import add
29 | from contextlib import redirect_stdout
30 |
31 | import tqdm
32 | import nltk
33 | from nltk.util import ngrams
34 | from datasketch import MinHash, LeanMinHash
35 | from tqdm_multiprocess import TqdmMultiProcessPool
36 |
37 | from utils.utils import timed_pickle_dump
38 | from utils.archiver import Reader
39 |
40 | import logging
41 | from utils.logger import setup_logger_tqdm
42 | logger = logging.getLogger(__name__)
43 |
44 | million = math.pow(10, 6)
45 |
46 | def extract_ngrams(data, num):
47 | n_grams = ngrams(nltk.word_tokenize(data), num)
48 | return [ ' '.join(grams) for grams in n_grams]
49 |
50 | # Multiprocessed
51 | def process_file(file_path, tqdm_func, global_tqdm):
52 | reader = Reader()
53 | minhashes = []
54 | previous_file_position = 0
55 | for document, metadata in reader.read_jsonl(file_path, get_meta=True):
56 |
57 | n_grams = extract_ngrams(document, 5)
58 | five_gram_set = set(n_grams)
59 | minhash = MinHash(num_perm=10)
60 | for five_gram in five_gram_set:
61 | minhash.update(five_gram.encode('utf8'))
62 | minhashes.append(LeanMinHash(minhash))
63 |
64 | # Update Progress Bar
65 | current_file_position = reader.fh.tell()
66 | global_tqdm.update(current_file_position - previous_file_position)
67 | previous_file_position = current_file_position
68 |
69 | return file_path, minhashes
70 |
71 | def generate_minhashes(scrape_directory, process_count):
72 | files = glob.glob(os.path.join(scrape_directory, "**/*.minscored"), recursive=True)
73 | total_file_size = reduce(add, map(os.path.getsize, files))
74 | logger.info(f"Total File Size: {(total_file_size / million):.2f} MB")
75 |
76 | # [(file_name1, [doc0_minhash, doc1_minhash, ...]), (file_name2, [....]), ....]
77 | with tqdm.tqdm(total=total_file_size, dynamic_ncols=True, unit_scale=1) as progress:
78 | pool = TqdmMultiProcessPool()
79 | tasks = []
80 | for file_path in files:
81 | task = (process_file, (file_path,))
82 | tasks.append(task)
83 |
84 | on_done = lambda _ : None
85 | on_error = on_done
86 | result = pool.map(process_count, progress, tasks, on_error, on_done)
87 |
88 | return result
89 |
90 | parser_description = 'Generate minhashes for all documents found.'
91 | parser = argparse.ArgumentParser(description=parser_description)
92 | parser.add_argument("-dir", "--scrape_directory", default="")
93 | parser.add_argument("-procs", "--process_count", type=int, default=4)
94 |
95 | if __name__ == '__main__':
96 | args = parser.parse_args()
97 | if not os.path.isdir(args.scrape_directory):
98 | print("Scrape directory doesn't exist, exiting.")
99 | sys.exit(0)
100 |
101 | with redirect_stdout(open(os.devnull, "w")):
102 | nltk.download('punkt')
103 |
104 | log_file = "generate_minhashes.log"
105 | setup_logger_tqdm(log_file)
106 |
107 | logger.info("Generating document level minhashes from 5 gram sets")
108 | minhashes_by_file = generate_minhashes(args.scrape_directory, args.process_count)
109 |
110 | output_pickle_path = os.path.join(args.scrape_directory, "minhashes.pkl")
111 | timed_pickle_dump(minhashes_by_file, output_pickle_path, "minhashes_by_file")
112 |
113 |
--------------------------------------------------------------------------------
/cleaning/minhash_lsh_batching.py:
--------------------------------------------------------------------------------
1 | """
2 | Splits minhashes.pkl into approximately the desired number of batches.
3 | As we always split on file boundaries this won't always be exact unless all
4 | files have the same number of documents.
5 |
6 | The "directory" must contain a 'minhashes.pkl' file created with 'generate_minhashes.py'.
7 |
8 | Produces batch files named 'batch0.pkl, batch1.pkl ...'. They contain the following
9 | pickled data structure:
10 | [(file_id, [doc0_minhash, doc1_minhash, ...]), ....]
11 |
12 | Produces a file name lookup named 'file_name_lookup.pkl'. Contains the following
13 | pickled data structure:
14 | [file_name1, file_name2, file_name3, ...]
15 |
16 | Arguments
17 | ------
18 | --directory (-dir)
19 | Directory containing the 'minhashes.pkl' file. Batch files and
20 | file name lookup will be saved here.
21 | --number_of_batches (-batches)
22 | Approximate number of batches to split minhashes into.
23 | """
24 |
25 | import os
26 | import argparse
27 | import pickle
28 |
29 | import tqdm
30 | from utils.utils import timed_pickle_dump, timed_pickle_load
31 |
32 | import logging
33 | from utils.logger import setup_logger_tqdm
34 | logger = logging.getLogger(__name__)
35 |
36 | def main(number_of_batches, batch_directory):
37 | minhashes_pickle_path = os.path.join(batch_directory, "minhashes.pkl")
38 |
39 | # [(file_name, [doc0_minhash, doc1_minhash, ...]), ....]
40 | minhashes = timed_pickle_load(minhashes_pickle_path, "minhashes")
41 |
42 | logger.info("Splitting minhashes for batching...")
43 | total_documents = 0
44 | for _ , documents in minhashes:
45 | total_documents += len(documents)
46 |
47 | document_count = 0
48 | documents_per_batch = total_documents / number_of_batches
49 | current_batch = []
50 | batch_count = 0
51 | for file_id, (file_name, documents) in tqdm.tqdm(enumerate(minhashes)):
52 | document_count += len(documents)
53 | current_batch.append((file_id, documents)) # Note we only store globally unique file_id here
54 |
55 | if document_count > (batch_count + 1) * documents_per_batch:
56 | batch_pickle_file_path = os.path.join(batch_directory, f"batch{batch_count}.pkl")
57 | timed_pickle_dump(current_batch, batch_pickle_file_path, f"batch {batch_count} minhashes")
58 | current_batch = []
59 | batch_count += 1
60 |
61 | if current_batch:
62 | batch_pickle_file_path = os.path.join(batch_directory, f"batch{batch_count}.pkl")
63 | timed_pickle_dump(current_batch, batch_pickle_file_path, f"batch {batch_count} minhashes")
64 | current_batch = None
65 |
66 | file_name_lookup = [file_name for file_name, documents in minhashes]
67 | file_name_lookup_path = os.path.join(batch_directory, "file_name_lookup.pkl")
68 | timed_pickle_dump(file_name_lookup, file_name_lookup_path, "Filename lookup")
69 |
70 | document_count_path = os.path.join(batch_directory, "document_count.pkl")
71 | pickle.dump(total_documents, open(document_count_path,"wb"))
72 |
73 |
74 | parser = argparse.ArgumentParser(description='Generate batches of minhashes for cassandra lsh dedupe.')
75 | parser.add_argument("-dir", "--directory", default="")
76 | parser.add_argument("-batches", "--number_of_batches", type=int, required=True)
77 |
78 | if __name__ == '__main__':
79 | logfile_path = "minhash_lsh_batching.log"
80 | setup_logger_tqdm(logfile_path)
81 |
82 | args = parser.parse_args()
83 |
84 | main(args.number_of_batches, args.directory)
--------------------------------------------------------------------------------
/cleaning/minhash_lsh_dedupe.py:
--------------------------------------------------------------------------------
1 | """
2 | This script uses MinHashLSH from the datasketch library to deduplicate
3 | across the whole set of document minhashes.
4 |
5 | See http://ekzhu.com/datasketch/lsh.html for more details.
6 |
7 | We use the Cassandra backend to keep memory usage low and a 0.5 threshold
8 | for duplicate detection.
9 |
10 | In its current state, the script creates and pickles a MinHashLSH object
11 | containing the parameters for a local cassandra instance. If you want
12 | a remote instance you could either change the script or perform port
13 | redirection to a remote machine with SSH tunnelling, for example. This uses
14 | the default Cassandra port.
15 |
16 | lsh = MinHashLSH(
17 | threshold=0.5, num_perm=10, storage_config={
18 | 'type': 'cassandra',
19 | 'cassandra': {
20 | 'seeds': ['127.0.0.1'],
21 | 'keyspace': 'minhash_lsh_keyspace',
22 | 'replication': {
23 | 'class': 'SimpleStrategy',
24 | 'replication_factor': '1',
25 | },
26 | 'drop_keyspace': False,
27 | 'drop_tables': False,
28 | }
29 | }
30 | )
31 |
32 | Importantly, once you run the script you must keep using the same lsh object
33 | if you want the same basenames for the created Cassandra tables.
34 |
35 | We are running into issues with loading this MinHashLSH object in multiple processes
36 | so we did our run with just a single process. The first run will freeze when trying
37 | to unpickle the MinHashLSH object in the worker process. On running a second time it
38 | will work.
39 |
40 | We save file and document level checkpoints after each query to allow for easy resuming.
41 | Each batch file will have a corresponding "*_duplicates.txt" when done.
42 |
43 | Arguments
44 | ------
45 | --batch_directory (-dir)
46 | Directory containing the "batch*.pkl" files. "lsh.pkl", duplicate lists and
47 | batch checkpoints will be saved here.
48 | --process_count (-procs)
49 | Number of processes in the pool. Defaults to 1.
50 | """
51 |
52 | import os
53 | import glob
54 | import argparse
55 | import json
56 | import pickle
57 |
58 | import tqdm
59 | from datasketch import MinHashLSH
60 | from tqdm_multiprocess import TqdmMultiProcessPool
61 |
62 | from utils.utils import Timer, timed_pickle_dump, timed_pickle_load
63 |
64 | import logging
65 | from utils.logger import setup_logger_tqdm
66 | logger = logging.getLogger(__name__)
67 |
68 | def get_minhash_lsh_cassandra():
69 | lsh = MinHashLSH(
70 | threshold=0.5, num_perm=10, storage_config={
71 | 'type': 'cassandra',
72 | 'cassandra': {
73 | 'seeds': ['127.0.0.1'],
74 | 'keyspace': 'minhash_lsh_keyspace',
75 | 'replication': {
76 | 'class': 'SimpleStrategy',
77 | 'replication_factor': '1',
78 | },
79 | 'drop_keyspace': False,
80 | 'drop_tables': False,
81 | }
82 | }
83 | )
84 | return lsh
85 |
86 | def minhash_lsh_dedupe_cassandra(batch_minhashes_pickle_path, lsh_pickle_path, tqdm_func, global_tqdm):
87 | # [(file_id, [doc0_minhash, doc1_minhash, ...]), ....]
88 | batch_minhashes = timed_pickle_load(batch_minhashes_pickle_path, "batch minhashes")
89 |
90 | # For some reason this will freeze when loading on the first run.
91 | lsh = timed_pickle_load(lsh_pickle_path, "lsh")
92 |
93 | checkpoint_file = batch_minhashes_pickle_path.replace(".pkl","_ckpt.pkl")
94 | if os.path.exists(checkpoint_file):
95 | ckpt_file_id, ckpt_document_id = pickle.load(open(checkpoint_file,"rb"))
96 | else:
97 | ckpt_file_id = -1
98 | ckpt_document_id = -1
99 |
100 | logger.info("Detecting duplicates")
101 | timer = Timer().start()
102 | duplicate_file_path = batch_minhashes_pickle_path.replace(".pkl", "_duplicates.txt")
103 | with open(duplicate_file_path, "a") as fh:
104 | for file_id, documents in batch_minhashes:
105 | if file_id <= ckpt_file_id:
106 | global_tqdm.update(len(documents))
107 | continue
108 | for document_id, minhash in enumerate(documents):
109 | if document_id <= ckpt_document_id:
110 | global_tqdm.update(ckpt_document_id + 1)
111 | ckpt_document_id = -1
112 | continue
113 | results = lsh.query(minhash)
114 | duplicate_found = True if results else False
115 | is_self = False
116 | for json_results in results:
117 | found_file_id, found_document_id = json.loads(json_results)
118 | # This check is needed in case you re-run things
119 | if file_id == found_file_id and document_id == found_document_id:
120 | duplicate_found = False
121 | is_self = True
122 | break
123 |
124 | if duplicate_found:
125 | fh.write(f"{file_id} {document_id}\n")
126 | else:
127 | if not is_self:
128 | lsh.insert(json.dumps((file_id, document_id)), minhash)
129 |
130 | global_tqdm.update()
131 | pickle.dump((file_id, document_id), open(checkpoint_file,"wb"))
132 |
133 | logger.info(timer.stop_string())
134 |
135 | return True
136 |
137 | def main(process_count, batch_directory):
138 |
139 | # Ensure LSH object containing cassandra connection info exists
140 | lsh_pickle_path = os.path.join(batch_directory, "lsh.pkl")
141 | if not os.path.exists(lsh_pickle_path):
142 | logger.info("Getting cassandra minhash lsh")
143 | lsh = get_minhash_lsh_cassandra()
144 | timed_pickle_dump(lsh, lsh_pickle_path, "lsh")
145 |
146 | files = glob.glob(os.path.join(batch_directory, "batch*.pkl"), recursive=True)
147 |
148 | pool = TqdmMultiProcessPool()
149 | tasks = []
150 |
151 | document_count_path = os.path.join(batch_directory, "document_count.pkl")
152 | total_documents = pickle.load(open(document_count_path,"rb"))
153 |
154 | for batch_file in files:
155 | arguments = (batch_file, lsh_pickle_path)
156 | task = (minhash_lsh_dedupe_cassandra, arguments)
157 | tasks.append(task)
158 |
159 | on_done = lambda _ : logger.info("done")
160 | on_error = lambda _ : logger.info("error")
161 | with tqdm.tqdm(total=total_documents, dynamic_ncols=True) as progress:
162 | result = pool.map(process_count, progress, tasks, on_error, on_done)
163 | logger.info(result)
164 |
165 | parser = argparse.ArgumentParser(description='Minhash LSH dedupe with cassandra backend.')
166 | parser.add_argument("-dir", "--batch_directory", default="")
167 | parser.add_argument("-procs", "--process_count", type=int, default=1)
168 |
169 | if __name__ == '__main__':
170 | logfile_path = "minhash_lsh_dedupe.log"
171 | setup_logger_tqdm(logfile_path)
172 |
173 | args = parser.parse_args()
174 |
175 | main(args.process_count, args.batch_directory)
--------------------------------------------------------------------------------
/data_analysis/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/data_analysis/__init__.py
--------------------------------------------------------------------------------
/data_analysis/final_stats.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import argparse
4 | import pickle
5 | import math
6 |
7 | import tqdm
8 |
9 | from utils.archiver import Reader
10 |
11 | import logging
12 | from utils.logger import setup_logger_tqdm
13 | logger = logging.getLogger(__name__)
14 |
15 | def get_stats_old():
16 | batch_directory = "/home/researcher2/webtext2/test"
17 | files = glob.glob(os.path.join(batch_directory, "*_duplicates.txt"))
18 | duplicate_count = 0
19 | for file_path in files:
20 | with open(file_path, "r") as fh:
21 | duplicate_count += len(fh.readlines())
22 |
23 | document_count_path = os.path.join(batch_directory, "document_count.pkl")
24 | document_count = pickle.load(open(document_count_path, "rb"))
25 |
26 | print("Total Duplicates: ", duplicate_count)
27 | print("Original Documents: ", document_count)
28 |
29 | useful_percentage = (1 - duplicate_count / document_count) * 100
30 | print(f"Useful Data: {useful_percentage:0.2f}%")
31 |
32 | def get_stats(final_directory):
33 |
34 | reader = Reader()
35 | files = glob.glob(os.path.join(final_directory, "*jsonl.zst"))
36 |
37 | document_count = 0
38 | total_text_size = 0
39 | logger.info("Getting final document count and total uncompressed text size.")
40 | for file_path in tqdm.tqdm(files, dynamic_ncols=True):
41 | for document, metadata in reader.read_jsonl(file_path, get_meta=True):
42 | document_count += 1
43 | total_text_size += len(document)
44 |
45 | return document_count, total_text_size
46 |
47 | parser = argparse.ArgumentParser(description='Final statistics')
48 | parser.add_argument("-dir", "--final_directory", default="")
49 |
50 | if __name__ == '__main__':
51 | logfile_path = "final_statistics.log"
52 | setup_logger_tqdm(logfile_path)
53 |
54 | args = parser.parse_args()
55 |
56 | final_count, total_text_size = get_stats(args.final_directory)
57 | billion = math.pow(10, 9)
58 | logger.info(f"Final Document Count: {final_count:,}")
59 | print(f"Total uncompressed text size: {(total_text_size / billion):.2f} GB")
60 | pickle.dump((final_count, total_text_size), open("final_stats.pkl", "wb"))
--------------------------------------------------------------------------------
/mkdocs/docs/background.md:
--------------------------------------------------------------------------------
1 | # WebText Background
2 |
3 | OpenAI required around 40GB of high quality text for training GPT2. While Common Crawl provides the scale necessary for modern language models, the quality is unreliable. Manual curation of Common Crawl is always an option, albeit an expensive one. Thankfully Reddit provides decentralized curation by design, and this became the key innovation for the WebText dataset.
4 |
5 | The generation of WebText can be summarized as:
6 |
7 | 1. Scrape URLs from all Reddit submissions up to December 2017 with a score of 3 or higher.
8 | 2. Deduplicate scraped content based on URL.
9 | 3. Exclude Wikipedia - OpenAI already had a separate Wikipedia dataset.
10 | 4. Deduplicate remaining content using undisclosed "heuristic based cleaning". This includes removal of non-English web pages.
11 |
12 | Neither the resulting corpus nor the generation source code was made public, inspiring Aaron Gokaslan and Vanya Cohen to create the OpenWebTextCorpus.
13 |
14 | OpenWebTextCorpus is an open source reproduction of WebText, reifying the "heuristic based cleaning" stage with fuzzy deduplication and enforcing a minimum token length. For content based de-duplication they used locality-sensitive hashing (LSH) with minhash on sets of 5-grams at the document level. Documents were then tokenized and any with fewer than 128 tokens were removed. After all processing there remained 40GB of text across 8,013,769 documents.
15 |
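As a rough sketch of what that looks like in practice, here is a minimal example using the datasketch library; the helper name and the choice of `num_perm=10` are illustrative (they mirror this repository's `cleaning/generate_minhashes.py` rather than the original OpenWebTextCorpus code):

```python
# Minimal sketch of document-level minhashing over 5-gram sets (illustrative only).
import nltk
from nltk.util import ngrams
from datasketch import MinHash

nltk.download("punkt", quiet=True)  # tokenizer data used by word_tokenize

def document_minhash(text, n=5, num_perm=10):
    # Build the set of word-level 5-grams, then fold each one into the MinHash.
    five_gram_set = {" ".join(gram) for gram in ngrams(nltk.word_tokenize(text), n)}
    minhash = MinHash(num_perm=num_perm)
    for five_gram in five_gram_set:
        minhash.update(five_gram.encode("utf8"))
    return minhash

# Near-duplicate documents yield a high estimated Jaccard similarity, which is
# what the LSH index keys on when flagging likely duplicates.
a = document_minhash("the quick brown fox jumps over the lazy dog near the river bank")
b = document_minhash("the quick brown fox jumps over the lazy dog near the river today")
print(a.jaccard(b))
```
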
16 | The original code for OpenWebTextCorpus is unavailable at this time, but there are several popular repositories that cover the pipeline to various degrees.
17 |
18 | ## OpenWebText2 Motivation
19 |
20 | Our primary goals for the corpus are:
21 |
22 | 1. More data! Coverage of the original OpenWebTextCorpus ended at December 2017.
23 | 2. Include all languages, providing metadata for easy filtering
24 | 3. Provide several versions of the generated corpus for differing user requirements. Both versions will be broken up by month and frozen, with future months added once the corresponding PushShift submission dumps become available.
25 | * Raw version containing all scraped pages with associated Reddit submission metadata
26 | * Plug and play version based on submissions of minimum 3 score with content based fuzzy de-duplication
27 | 4. Provide full source code for all stages of the pipeline including deduplication.
28 |
29 | We decided on a rewrite taking inspiration from both 1 and 2.
--------------------------------------------------------------------------------
/mkdocs/docs/css/extra.css:
--------------------------------------------------------------------------------
1 | table {
2 | margin-bottom: 16px;
3 | }
4 |
5 | table th {
6 | padding: 8px;
7 | }
8 |
9 | table td {
10 | padding: 8px;
11 | }
12 |
13 | .rst-content table code {
14 | white-space: nowrap;
15 | }
16 |
17 | p {
18 | margin-bottom: 16px;
19 | }
20 |
21 | .download-button
22 | {
23 |
24 | }
25 |
26 | .download-button svg {
27 | margin-left: 4px;
28 | border: none;
29 | }
30 |
31 | .darken {
32 | filter: brightness(75%);
33 | }
--------------------------------------------------------------------------------
/mkdocs/docs/index.md:
--------------------------------------------------------------------------------
1 | # Welcome!
2 |
3 | OpenWebText2 is an enhanced version of the original OpenWebTextCorpus covering all Reddit submissions from 2005 up until April 2020, with further months becoming available after the corresponding PushShift dump files are released.
4 |
5 | In case you haven't heard of WebText, the core principle is extracting URLs from reddit submissions, scraping the URLs, then performing filtering & deduplication. See [Background](background) for more information.
6 |
7 |
8 |
9 | ## Download Plug and Play Version
10 | This version has already been cleaned for you:
11 |
12 | - Deduplicated by URL
13 | - Filtered by minimum combined reddit score 3
14 | - Deduplicated at document level with MinHashLSH.
15 |
16 | **Stats**
17 | 17,103,059 documents
18 | 65.86 GB uncompressed text
19 | 28 GB compressed including text and metadata
20 |
21 |
22 |
23 |
30 |
31 |
32 |
33 |
34 | ## Download Raw Scrapes Version
35 | Only deduplicated by URL.
36 |
37 | **Stats**
38 | 69,547,149 documents
39 | 193.89 GB uncompressed text
40 | 79 GB compressed including text and metadata
41 |
42 |
43 |
50 |
51 |
52 |
53 |
54 | ## Using The Data
55 |
56 | The data is stored using lm_dataformat. We use a slightly modified version to allow file peeking for tqdm progress bars: utils/archiver.py. Be sure to call *read_jsonl* with `get_meta=True` as both versions contain useful metadata for each document, including several original Reddit fields.
57 |
58 | ```python
59 | import glob
60 | import os
61 | import math
62 |
63 | import tqdm
64 |
65 | from utils.archiver import Reader
66 |
67 | document_count = 0
68 | total_text_size = 0
69 | dataset_directory = "PATH_TO_FILES"
70 | files = glob.glob(os.path.join(dataset_directory, "*jsonl.zst"))
71 | for file_path in tqdm.tqdm(files, dynamic_ncols=True):
72 | reader = Reader()
73 | for document, metadata in reader.read_jsonl(file_path, get_meta=True):
74 | document_count += 1
75 | total_text_size += len(document)
76 |
77 | billion = math.pow(10, 9)
78 | print(f"Total Document Count: {document_count:,}")
79 | print(f"Total Uncompressed Text Size: {(total_text_size / billion):.2f} GB")
80 | ```
81 |
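If you just want to peek at the per-document metadata, a small variation of the above works. This is a sketch only; the exact set of keys can differ between the raw and plug and play versions, so treat the printed field names as examples rather than a definitive schema.

```python
from utils.archiver import Reader

reader = Reader()
file_path = "PATH_TO_A_FILE.jsonl.zst"  # any single archive from either version

for index, (document, metadata) in enumerate(reader.read_jsonl(file_path, get_meta=True)):
    # Print a few of the stored fields for the first handful of documents.
    print(metadata.get("title"), metadata.get("lang"), metadata.get("url"))
    if index >= 4:
        break
```
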
82 | Alternatively, check out The-Pile, which acts as an aggregator/dataloader for multiple text datasets. It allows you to configure your total data size requirement, along with the desired weighting for each subset. Once configured, you get a randomized stream of documents, allowing easy feeding to your language model.
83 |
84 | ## Cite as
85 |
86 |
87 | @article{pile,
88 | title={The {P}ile: An 800GB Dataset of Diverse Text for Language Modeling},
89 | author={Gao, Leo and Biderman, Stella and Black, Sid and Golding, Laurence and Hoppe, Travis and Foster, Charles and Phang, Jason and He, Horace and Thite, Anish and Nabeshima, Noa and Presser, Shawn and Leahy, Connor},
90 | journal={arXiv preprint arXiv:2101.00027},
91 | year={2020}
92 | }
93 |
94 |
--------------------------------------------------------------------------------
/mkdocs/docs/licence.md:
--------------------------------------------------------------------------------
1 | # MIT License
2 |
3 | Copyright (c) 2020 EleutherAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/mkdocs/docs/replication.md:
--------------------------------------------------------------------------------
1 | # Dataset Replication
2 |
3 | This area of the documentation provides instructions for building the full dataset from scratch. If you just want the dataset, please see [Welcome](/).
4 |
5 | PushShift provides dumps of all Reddit posts and submissions; however, they are normally a few months behind. While this would be problematic for certain use cases, we didn't require up-to-the-minute data for training GPTNeo. In the future we may look into getting recent data either by scraping Reddit directly or using one of the existing APIs.
6 |
7 | ## Pipeline Overview
8 |
9 | At a high level the pipeline works as follows:
10 |
11 | 1. Download and process the PushShift submission dumps to extract unique URLs & Metadata.
12 | 2. Scrape the URLs using Newspaper3k, saving both text and metadata with lm_dataformat.
13 | 3. Filter the scraped documents by minimum Reddit score 3.
14 | 4. Perform fuzzy deduplication using MinHashLSH.
15 | 5. Package up the various dataset releases.
16 | 6. Produce some useful size stats for the releases.
17 |
18 | ## Environment Setup
19 | We tested everything on Ubuntu 18/20 & Windows 10 with miniconda. You could use virtualenv, venv or even the global python environment if you wish.
20 |
21 | ### Miniconda Install For Linux
22 |
23 | Follow the below steps, or read the conda instructions:
24 | https://docs.conda.io/projects/conda/en/latest/user-guide/install/linux.html
25 |
26 | ```bash
27 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
28 | sha256sum Miniconda3-latest-Linux-x86_64.sh
29 | bash Miniconda3-latest-Linux-x86_64.sh
30 | ```
31 | Select yes on the init step.
32 |
33 | Restart your shell to refresh the path.
34 |
35 | ### Create and activate conda environment
36 |
37 | Environments are saved in a central store on the local disk; no need to create folders like with venv.
38 | ```
39 | conda create --name pushshift python=3.8
40 | conda activate pushshift
41 | ```
42 |
43 | ### Install Repo and Requirements
44 | ```bash
45 | git clone https://github.com/EleutherAI/openwebtext2.git
46 | cd openwebtext2
47 | pip install -r requirements.txt
48 | ```
49 |
50 | ### General Recommendations
51 |
52 | Use the screen command if running a remote terminal; many scripts below take a long time to run, and while they often support resuming it's better not to rely on it.
53 |
54 | Create a working directory on a drive with at least 500GB of space.
55 |
56 | ## Stage 1 - Processing PushShift Submission Dumps
57 |
58 | This stage consists of the following steps:
59 |
60 | 1. Sqlite database setup.
61 | 2. Download and verify the PushShift submission dumps, extracting and storing urls and metadata
62 | from relevant submissions into the sqlite database. Performed by "pushshift/pushshift_to_sqlite.py".
63 | 3. Query the populated sqlite database and build a list of URLs with metadata for all related submissions.
64 | Performed by "pushshift/generate_urls.py".
65 |
66 | All scripts for this stage can be found within the "pushshift" package.
67 |
68 | ### Sqlite Database Setup
69 | We use alembic to manage the sqlite database in case you want to add extra indexes or fields easily later.
70 |
71 | Inside the cloned repo:
72 | ```bash
73 | cp alembic.ini.template alembic.ini
74 | ```
75 | Modify the following line within the alembic.ini file. For example on Windows:
76 | ```ini
77 | sqlalchemy.url = sqlite:///e:/Eleuther_AI/openwebtext2/submissions.sqlite
78 | ```
79 | Or on Linux (notice the extra forward slash before "mnt" to indicate system root):
80 | ```ini
81 | sqlalchemy.url = sqlite:////mnt/data/openwebtext2/submissions.sqlite
82 | ```
83 |
84 | ### Insert PushShift Submission Data Into Sqlite
85 |
86 | This step is performed by *pushshift/pushshift_to_sqlite.py*.
87 |
88 | | Script Argument | Description |
89 | | -----------: | ----------- |
90 | | `--start_period (-s)` | Month and Year of first pushshift dump. Default: 6,2005. |
91 | | `--finish_period (-f)` | Month and Year of final pushshift dump. Defaults to current month. |
92 | | `--output_directory (-dir)` | Will contain the dumps subdirectory created as part of the process. |
93 | | `--keep_dumps (-kd)` | If specified the dumps won't be deleted after successful processing. |
94 |
95 | Notice the database location is not specified here; it is always sourced from the alembic.ini file.
96 |
97 | For example on Linux, to download and process all dumps, leaving the downloaded dumps afterwards:
98 | ```bash
99 | python -m pushshift.pushshift_to_sqlite -dir /mnt/data/openwebtext2 -kd
100 | ```
101 |
102 | Test run on 2006 only, deleting dumps when done:
103 | ```bash
104 | python -m pushshift.pushshift_to_sqlite -s 1,2006 -f 12,2006 -dir /mnt/data/openwebtext2
105 | ```
106 |
107 | This step uses checkpointing, saving a .dbdone file for each dump once processing is complete. So if you need to stop and come back later you can.
108 |
109 | ### Processing the PushShift Dumps With jq and sort instead
110 |
111 | TO DO.
112 |
113 |
114 | ### Extract Unique URLs with Reddit Metadata
115 |
116 | This step is performed by *pushshift/generate_urls.py*. It currently only supports sqlite mode; the json->sorted tsv version is still in progress.
117 |
118 | | Script Argument | Description |
119 | | -----------: | ----------- |
120 | | `--start_period (-s)` | Month and Year of first URLs. Defaults to None (query all URLs). |
121 | | `--finish_period (-f)` | Month and Year of final URLs. Defaults to None (query all URLs). |
122 | | `--output_directory (-dir)` | Will contain the urls subdirectory created as part of the process. |
123 | | `--urls_per_file` | Maximum number of urls per file. Defaults to 100,000. |
124 | | `--min_score (-score)` | Minimum aggregate submissions score to include url. Defaults to 3. |
125 | | `--data_source (-source)` | Where to find sorted URLs: "db" or "tsv". tsv doesn't support date ranges. Defaults to "db". |
126 |
127 |
128 |
129 | Notice the database location is not specified here; it is always sourced from the alembic.ini file.
130 |
131 | For example on Linux, to extract all URLs:
132 | ```bash
133 | python -m pushshift.generate_urls -dir /mnt/data/openwebtext2
134 | ```
135 |
136 | Test run on 2006 only:
137 | ```bash
138 | python -m pushshift.generate_urls -s 1,2006 -f 12,2006 -dir /mnt/data/openwebtext2
139 | ```
140 |
141 | ## Stage 2 - Scraping From Sourced URLs
142 |
143 | This stage is performed by *scraping/scrape_urls.py* and took several weeks of compute time. To decrease this you can run on multiple servers by distributing the URL files between them.
144 |
145 | | Script Argument | Description |
146 | | -----------: | ----------- |
147 | | `--job_directory (-dir)` | Base directory containing the urls subdirectory and location where the scrapes subdirectory will be created. |
148 | | `--process_count (-procs)` | Number of worker processes in the pool. Defaults to 60. Don't go above this on Windows. |
149 | | `--request_timeout (-timeout)` | Scraping timeout for each URL. Defaults to 30 seconds. |
150 |
151 | The script iterates through URL files generated in step 2 above. For each file it hands out the URLs
152 | to a multiprocessing pool for scraping. Once all URLs in the batch are scraped, the successful results are
153 | archived using a slightly modified version of lm_dataformat. For each document (URL), the following metadata fields are saved in the metadata dict offered by lm_dataformat:
154 |
155 | | Meta Field | Description |
156 | | -----------: | ----------- |
157 | | title | Web Page Title. |
158 | | lang | Language detected by Newspaper scraper. |
159 | | url | Original URL. |
160 | | word_count | Total words outputted by Newspaper. |
161 | | elapsed | Scraping time. |
162 | | scraper | Always "newspaper". |
163 | | domain | Top level domain for the original URL. |
164 | | reddit_id | List of submission IDs containing URL - converted from base36. |
165 | | subreddit | List of subreddits for the corresponding submissions. |
166 | | reddit_score | List of reddit scores for the corresponding submissions. |
167 | | reddit_title | List of submissions titles for the corresponding submissions. |
168 | | reddit_created_utc | List of submissions created times for the corresponding submissions. |
169 |
170 | The program will look for URL files within "job_directory/urls". All scrapes will be stored in "job_directory/scrapes"
171 |
172 | For example on Linux, this will scrape using 90 processes and 30 second timeout:
173 | ```bash
174 | python -m scraping.scrape_urls -dir /mnt/data/openwebtext2 -procs 90
175 | ```
176 |
177 | On a dedicated 2012 i7 Linux machine we used between 90 and 120 processes successfully.
178 |
179 | We do some limited URL filtering in *scraping/filter.py*. This is mainly to speed up the process by avoiding timeouts or files that obviously won't contain text.
180 |
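As a rough illustration of the kind of check involved (the extension list below is invented for the example; the real rules are in *scraping/filter.py*):

```python
# Simplified illustration only - the real rules live in scraping/filter.py.
from urllib.parse import urlparse

# File types that are unlikely to yield article text from Newspaper.
EXCLUDED_EXTENSIONS = (".jpg", ".jpeg", ".png", ".gif", ".mp4", ".pdf", ".zip", ".gz")

def looks_scrapable(url):
    path = urlparse(url).path.lower()
    return not path.endswith(EXCLUDED_EXTENSIONS)

print(looks_scrapable("https://example.com/story/some-article"))  # True
print(looks_scrapable("https://example.com/holiday-photo.jpg"))   # False
```
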
181 | Once each URL file is scraped, the program saves a ".done" file so you can resume later without rescraping. That file contains a count of successfully scraped URLs if you are interested.
182 |
183 | ## Stage 3 - Filtering scraped documents by minimum total Reddit score
184 |
185 | This stage is performed by *cleaning/filter_from_reddit_scores.py*.
186 |
187 | | Script Argument | Description |
188 | | -----------: | ----------- |
189 | | `--scrape_directory (-dir)` | Directory containing the scrapes. You could use the overall work directory if you want as we use glob.glob to search recursively. |
190 |
191 | The script filters all scrape files "scrapes_*.jsonl.zst" by minimum total Reddit score.
192 | Unlike the original WebText we aggregate the scores of all submissions containing a given
193 | URL, so in some cases the bar is slightly lower, while in others a URL that scored negatively
194 | in one submission can be balanced out by its other submissions.
195 |
196 | The filtered scrapes file will have the original name and path of the scrape file with a
197 | ".minscored" extension.
198 |
199 | For example on Linux:
200 | ```bash
201 | python -m cleaning.filter_from_reddit_scores -dir /mnt/data/openwebtext2/scrapes
202 | ```
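
To make the aggregation rule concrete, here is a toy illustration of the threshold the script applies (the numbers are invented):

```python
# A URL that appeared in three submissions with these scores:
reddit_scores = [5, -2, 1]

# The document is kept because the combined score (5 - 2 + 1 = 4) is at least 3,
# even though one of the submissions was voted below zero.
keep = sum(reddit_scores) >= 3
print(keep)  # True
```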
203 |
204 | ## Stage 4 - Deduplicate Filtered Documents using MinHashLSH with Cassandra
205 |
206 | There are several sub-stages here:
207 |
208 | 1. Setup Cassandra
209 | 2. Generate minhashes for every document
210 | 3. Batch up the minhashes for running parallel dedupe
211 | 4. Using MinHashLSH With Cassandra - Generate lists of duplicates
212 | 5. Deduplicating our documents using the lists from step 4.
213 |
214 | All scripts for this stage can be found within the "cleaning" package.
215 |
216 | ### Setup Cassandra
217 |
218 | We used a local Cassandra install, simplifying the setup process. Some good cassandra guides:
219 |
220 | Installation Guide For Ubuntu 20
221 | Introduction To Cassandra + Connecting With Python API
222 |
223 | Summarized Quick Install For Ubuntu 20:
224 | ```bash
225 | sudo apt install openjdk-8-jdk
226 | sudo apt install apt-transport-https
227 | wget -q -O - https://www.apache.org/dist/cassandra/KEYS | sudo apt-key add -
228 | sudo sh -c 'echo "deb http://www.apache.org/dist/cassandra/debian 311x main" > /etc/apt/sources.list.d/cassandra.list'
229 | sudo apt update
230 | sudo apt install cassandra
231 | sudo systemctl status cassandra
232 | ```
233 |
234 | To test your installation was successful, run the cqlsh CLI:
235 | ```bash
236 | cqlsh
237 | ```
238 |
239 | Once inside:
240 | ```describe keyspaces```
241 |
242 | If you want multiple nodes or a remote connection you need to set the following in your /etc/cassandra/cassandra.yaml:
243 |
244 | seeds: "your_server_external_ip, other nodes in cluster"
245 | listen_address: your_server_external_ip
246 | start_rpc: true
247 | rpc_address: 0.0.0.0 (this will bind to same address as listen_address)
248 |
249 | For some reason they recommend not to make this available on the internet despite supporting various forms of authentication. So either use a tunnel or fancy networking to get around this.
250 |
251 | ### Generate Minhashes For Every Document
252 |
253 | This step is performed by *cleaning/generate_minhashes.py* and took about 1.5 days
254 | on a 2012 i7 Linux machine.
255 |
256 | | Script Argument | Description |
257 | | -----------: | ----------- |
258 | | `--scrape_directory (-dir)` | Directory containing the minscored scrapes. You could use the overall work directory if you want as we use glob.glob to search recursively. |
259 | | `--process_count (-procs)` | Number of worker processes in the pool. Defaults to 4. |
260 |
261 | This script calculates minhashes for all filtered scrape files found using a recursive
262 | search on "\*.minscored".
263 |
264 | More explicitly, we create a set of 5-grams for each document and generate document-level
265 | minhashes using 10 hash functions with the excellent datasketch library.
266 |
267 | A single file "minhashes.pkl" is created in the scrape directory storing a data
268 | structure in the following format:
269 |
270 | ```python
271 | [(file_name1, [doc0_minhash, doc1_minhash, ...]), (file_name2, [....]), ....]
272 | ```
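
A minimal sketch of the per-document minhashing, assuming word-level 5-grams (the actual tokenization in *cleaning/generate_minhashes.py* may differ):

```python
from datasketch import MinHash

def document_minhash(text, num_perm=10):
    """Build a MinHash over the document's set of word 5-grams."""
    words = text.split()
    five_grams = {" ".join(words[i:i + 5]) for i in range(max(len(words) - 4, 0))}
    minhash = MinHash(num_perm=num_perm)
    for gram in five_grams:
        minhash.update(gram.encode("utf-8"))
    return minhash
```

A list of `(file_name, [minhash, ...])` tuples built this way is what gets pickled as "minhashes.pkl".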
273 |
274 | For example on Linux:
275 | ```bash
276 | python -m cleaning.generate_minhashes -dir /mnt/data/openwebtext2/scrapes
277 | ```
278 |
279 | ### Slice The Minhashes For Batching
280 |
281 | This step is performed by *cleaning/minhash_lsh_batching.py*.
282 |
283 | | Script Argument | Description |
284 | | -----------: | ----------- |
285 | | `--directory (-dir)` | Directory containing the 'minhashes.pkl' file. Batch files and file name lookup will be saved here. |
286 | | `--number_of_batches (-batches)` | Approximate number of batches to split minhashes into. |
287 |
288 | The "directory" must contain a 'minhashes.pkl' file created with *cleaning/generate_minhashes.py*.
289 |
290 | This splits "minhashes.pkl" into approximately the desired number of batches, producing batch files named 'batch0.pkl', 'batch1.pkl', etc. They contain the following pickled data structure:
291 | ```python
292 | [(file_id, [doc0_minhash, doc1_minhash, ...]), ....]
293 | ```
294 |
295 | It also creates a file name lookup named 'file_name_lookup.pkl' containing the following pickled data structure:
296 | ```python
297 | [file_name1, file_name2, file_name3, ...]
298 | ```
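
A rough sketch of the slicing, assuming batches are contiguous runs of files and that `file_id` is simply the index into the file name lookup (the helper name is ours):

```python
import os
import pickle

def slice_minhashes(minhashes_path, output_directory, number_of_batches):
    """Split minhashes.pkl into batch*.pkl files plus a file name lookup."""
    minhashes = pickle.load(open(minhashes_path, "rb"))

    file_name_lookup = [file_name for file_name, _ in minhashes]
    pickle.dump(file_name_lookup,
                open(os.path.join(output_directory, "file_name_lookup.pkl"), "wb"))

    batch_size = max(1, len(minhashes) // number_of_batches)
    for batch_number, start in enumerate(range(0, len(minhashes), batch_size)):
        batch = [(file_id, doc_minhashes)
                 for file_id, (_, doc_minhashes)
                 in enumerate(minhashes[start:start + batch_size], start=start)]
        pickle.dump(batch, open(os.path.join(output_directory, f"batch{batch_number}.pkl"), "wb"))
```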
299 |
300 | For example on Linux with 16 batches:
301 | ```bash
302 | python -m cleaning.minhash_lsh_batching -dir /mnt/data/openwebtext2/scrapes -batches 16
303 | ```
304 |
305 | ### Find Duplicates Using MinHashLSH with Cassandra
306 |
307 | This step is performed by *cleaning/minhash_lsh_dedupe.py*.
308 |
309 | | Script Argument | Description |
310 | | -----------: | ----------- |
311 | | `--batch_directory (-dir)` | Directory containing the "batch\*.pkl" files. Duplicate lists and batch checkpoints will be saved here. |
312 | | `--process_count (-procs)` | Number of processes in the pool. Defaults to 4. |
313 |
314 | The script generates a list of detected duplicates for files/documents located in the various "batch\*.pkl" files.
315 |
316 | We save file- and document-level checkpoints after each query to allow for easy resuming.
317 | Each batch file will have a corresponding "\*duplicates.txt" file when done.
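
For reference, a minimal sketch of duplicate detection with datasketch's MinHashLSH backed by Cassandra. The threshold and keyspace name are placeholders, the `storage_config` keys should be checked against the datasketch documentation, and `num_perm` must match the 10 permutations used when generating the minhashes:

```python
from datasketch import MinHashLSH

def find_duplicates_in_batch(batch, threshold=0.5):
    """Query then insert each document's minhash, reporting near-duplicate hits."""
    lsh = MinHashLSH(threshold=threshold, num_perm=10, storage_config={
        "type": "cassandra",
        "cassandra": {
            "seeds": ["127.0.0.1"],
            "keyspace": "minhash_lsh",  # hypothetical keyspace name
            "replication": {"class": "SimpleStrategy", "replication_factor": "1"},
            "drop_keyspace": False,
            "drop_tables": False,
        },
    })

    duplicates = []
    for file_id, doc_minhashes in batch:
        for doc_id, minhash in enumerate(doc_minhashes):
            if lsh.query(minhash):  # a near-duplicate was already inserted
                duplicates.append((file_id, doc_id))
            else:
                lsh.insert(f"{file_id},{doc_id}", minhash)
    return duplicates
```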
318 |
319 | For example on Linux with the default 4 processes:
320 |
321 | ```bash
322 | python -m cleaning.minhash_lsh_dedupe -dir /mnt/data/openwebtext2/scrapes
323 | ```
324 |
325 | ### De-Duplicating Using Generated Duplicate Lists
326 |
327 | This step is performed by *cleaning/dedupe_from_indexes.py*.
328 |
329 | | Script Argument | Description |
330 | | -----------: | ----------- |
331 | | `--batch_directory (-dir)` | Directory containing the "\*duplicates.txt" files along with the "file_name_lookup.pkl" created during batch slicing. The "\*final.jsonl.zst" files will be output in their original directories. |
332 |
333 | This script builds a list of all duplicates by file_id and document_id, then iterates
334 | through all ".minscored" files from the file name lookup, creating for each original
335 | file a new archive containing all documents that were not marked as duplicates during
336 | the previous step.
337 |
338 | For each original file, a "\*final.jsonl.zst" file will be output in the same directory.
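
A rough sketch of that pass (the output file naming here is only illustrative):

```python
from utils.archiver import Archive, Reader

def write_final_file(minscored_path, file_id, duplicate_pairs):
    """Re-archive a .minscored file, skipping (file_id, doc_id) pairs marked as duplicates."""
    reader = Reader()
    archiver = Archive(minscored_path.replace(".minscored", ".final.jsonl.zst"))
    for doc_id, (text, meta) in enumerate(reader.read_jsonl(minscored_path, get_meta=True)):
        if (file_id, doc_id) not in duplicate_pairs:
            archiver.add_data(text, meta)
    archiver.commit()
```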
339 |
340 | For example on Linux:
341 | ```bash
342 | python -m cleaning.dedupe_from_indexes -dir /mnt/data/openwebtext2/scrapes
343 | ```
344 |
345 | ## Stage 5 - Packaging The Dataset Releases
346 |
347 | ### Plug And Play Release
348 |
349 | We originally processed by month, but now simply package the scrape files corresponding to
350 | the original URL files.
351 |
352 | Simply tar all the "\*final.jsonl.zst" files.
353 |
354 | ### Raw Scrapes Release
355 |
356 | Similarly just tar all the "scrapes_\*.jsonl.zst" files.
357 |
358 | ## Stage 6 - Produce Release Stats
359 |
360 | If you move the files from each release into their own subdirectory, you can run "data_analysis/final_stats.py"
361 | to get the total document count and text size for all "jsonl.zst" files in each directory.
362 |
363 | For example on Linux:
364 | ```bash
365 | python -m data_analysis.final_stats -dir /mnt/data/openwebtext2/final
366 | python -m data_analysis.final_stats -dir /mnt/data/openwebtext2/raw_release
367 | ```
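
Roughly, the per-directory stats amount to the following sketch (text size is measured here in UTF-8 bytes; the real script may report it differently):

```python
import glob
import os

from utils.archiver import Reader

def directory_stats(directory):
    """Count documents and total text bytes across every *.jsonl.zst file."""
    document_count, total_text_bytes = 0, 0
    for file_path in glob.glob(os.path.join(directory, "**", "*.jsonl.zst"), recursive=True):
        for text in Reader().read_jsonl(file_path):
            document_count += 1
            total_text_bytes += len(text.encode("utf-8"))
    return document_count, total_text_bytes
```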
368 |
--------------------------------------------------------------------------------
/mkdocs/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: OpenWebText2
2 | theme: readthedocs
3 | docs_dir: 'docs'
4 | nav:
5 | - Welcome: index.md
6 | - WebText Background: background.md
7 | - Dataset Replication: replication.md
8 | - Licence: licence.md
9 | extra_css:
10 | - css/extra.css
11 |
--------------------------------------------------------------------------------
/pushshift/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/pushshift/__init__.py
--------------------------------------------------------------------------------
/pushshift/download_pushshift_dumps.py:
--------------------------------------------------------------------------------
1 | """
2 | This module is responsible for downloading the PushShift submission dump
3 | files, and contains the following functions:
4 |
5 | Functions
6 | ---------
7 | get_url_content_length(url)
8 | Attempts to retrieve the Content-Length header when performing
9 | a get request on the provided 'url'. Returns the number of bytes
10 | if available, or None.
11 |
12 | build_file_list(start_date, end_date):
13 | Builds a list of PushShift submission dump files located at
14 | "https://files.pushshift.io/reddit/submissions" within the desired date
15 | range.
16 |
17 | get_sha256sums
18 | Downloads the sha256sum file for the PushShift submission dumps from
19 | "https://files.pushshift.io/reddit/submissions/sha256sums.txt". Builds
20 | and returns a dictionary with file name as key and sha256 sum as value.
21 | """
22 |
23 | from dateutil.relativedelta import *
24 | import math
25 |
26 | import requests
27 |
28 | import logging
29 | logger = logging.getLogger(__name__)
30 |
31 | million = math.pow(10, 6)
32 | possible_archive_formats = ["zst", "xz", "bz2"]
33 |
34 | def get_url_content_length(url):
35 | response = requests.head(url)
36 | response.raise_for_status()
37 |
38 | if "Content-Length" in response.headers:
39 | return int(response.headers['Content-length'])
40 | else:
41 | return None
42 |
43 | def build_file_list(start_date, end_date):
44 | base_url = "https://files.pushshift.io/reddit/submissions"
45 | url_list = []
46 | date = start_date
47 | current_year = None
48 | while date <= end_date:
49 | year = date.strftime("%Y")
50 | month = date.strftime("%m")
51 |
52 | if year != current_year:
53 | current_year = year
54 | logger.info(f"Scanning Year {current_year}")
55 |
56 | if year < "2011":
57 | url = f"{base_url}/RS_v2_{year}-{month}.xz"
58 | url_list.append(url)
59 | else:
60 | for extension in possible_archive_formats:
61 | url = f"{base_url}/RS_{year}-{month}.{extension}"
62 | try:
63 | get_url_content_length(url) # If this fails there's no file
64 | url_list.append(url)
65 | break
66 | except:
67 | pass
68 |
69 | date = date + relativedelta(months=+1)
70 |
71 | return url_list
72 |
73 | def get_sha256sums():
74 | sha256sum_url = "https://files.pushshift.io/reddit/submissions/sha256sums.txt"
75 |
76 | sha256sum_lookup = {}
77 | with requests.get(sha256sum_url) as response:
78 | response.raise_for_status()
79 | for line in response.text.splitlines():
80 | if line.strip():
81 | sha256sum, file_name = tuple(line.strip().split(" "))
82 | sha256sum_lookup[file_name] = sha256sum
83 |
84 | return sha256sum_lookup
--------------------------------------------------------------------------------
/pushshift/generate_urls.py:
--------------------------------------------------------------------------------
1 | """
2 | This script produces files containing the urls and associated Reddit metadata
3 | for a given period, defaulting to 100,000 urls per file. Using lm_dataformat,
4 | one record is stored for each URL, along with the metadata of all submissions for
5 | that particular URL.
6 |
7 | No separate URL based deduplication is required unless you run multiple iterations of
8 | the script across different time periods (unimplemented but very simple).
9 |
10 | Note that we don't filter by score at this stage as the full pipeline scrapes all urls
11 | and leaves the filtering to be done by the user if they don't want the plug and play version.
12 |
13 | Arguments
14 | ---------
15 | --start_period (-s)
16 | Month and Year of first URLs. Defaults to None (query all URLs).
17 | --finish_period (-f)
18 | Month and Year of final URLs. Defaults to None (query all URLs).
19 | --output_directory (-dir)
20 | Base directory that will contain the urls subdirectory created as part of the process.
21 | --urls_per_file
22 | Maximum number of urls per file. Defaults to 100,000.
23 | --min_score
24 | Minimum aggregate submissions score to include url.
25 | --data_source
26 | Where to find sorted URLs: "db" or "tsv". tsv doesn't support date ranges.
27 |
28 | If both start_period and finish_period are blank then we can use a faster query on the reddit_submission
29 | table.
30 | """
31 |
32 | import datetime
33 | from dateutil.relativedelta import *
34 | import argparse
35 | import os
36 | import json
37 | from functools import partial
38 | import sys
39 |
40 | from utils.archiver import Archive
41 | from .models import RedditSubmission, get_db_session
42 |
43 | import logging
44 | from utils.logger import setup_logger_tqdm
45 | logger = logging.getLogger(__name__)
46 |
47 | def get_from_db(start_date, end_date):
48 | db_session = get_db_session()
49 |
50 | # SELECT id, url, score, title, subreddit, created_utc
51 | # FROM reddit_submission
52 | # WHERE created_utc >= start_date and created_utc <= end_date
53 | # ORDER BY url
54 | select_fields = (RedditSubmission.id, RedditSubmission.url, RedditSubmission.score,
55 | RedditSubmission.title, RedditSubmission.subreddit, RedditSubmission.created_utc)
56 |
57 | if start_date or end_date:
58 | month_end = end_date + relativedelta(months=+1)
59 | query = db_session.query(*select_fields) \
60 | .filter(RedditSubmission.created_utc >= start_date) \
61 | .filter(RedditSubmission.created_utc < month_end) \
62 | .order_by(RedditSubmission.url) \
63 | .yield_per(1000)
64 | else:
65 | query = db_session.query(*select_fields) \
66 | .order_by(RedditSubmission.url) \
67 | .yield_per(1000)
68 |
69 | logger.info("Querying sqlite database for submissions")
70 | logger.info(query)
71 | return query
72 |
73 | def get_from_tsv():
74 | pass
75 |
76 |
77 | def generate_urls(url_directory, urls_per_file, min_score, source):
78 |
79 | url_batch = 0
80 | url_file_path = os.path.join(url_directory, f"urls_{url_batch}.jsonl.zst")
81 | archiver = Archive(url_file_path)
82 |
83 | current_url = ""
84 | current_meta = {}
85 | current_meta["id"] = []
86 | current_meta["score"] = []
87 | current_meta["title"] = []
88 | current_meta["subreddit"] = []
89 | current_meta["created_utc"] = []
90 |
91 | total_url_count = 0
92 | url_count = 0
93 | logger.info("Generating now...")
94 | for submission_id, url, score, title, subreddit, created_utc in source():
95 | if not current_url:
96 | current_url = url
97 | elif url != current_url:
98 | # New URL - Add Old URL and meta to archive if score is high enough
99 | total_score = sum(current_meta["score"])
100 | if (total_score >= min_score):
101 | archiver.add_data(current_url, current_meta)
102 | url_count += 1
103 | total_url_count += 1
104 |
105 | # Commit and Init New Archive if full
106 | if url_count == urls_per_file:
107 | archiver.commit()
108 | url_batch += 1
109 | url_file_path = os.path.join(url_directory, f"urls_{url_batch}.jsonl.zst")
110 | archiver = Archive(url_file_path)
111 | url_count = 0
112 |
113 | current_url = url
114 | current_meta = {}
115 | current_meta["id"] = []
116 | current_meta["score"] = []
117 | current_meta["title"] = []
118 | current_meta["subreddit"] = []
119 | current_meta["created_utc"] = []
120 |
121 | current_meta["id"].append(submission_id)
122 | current_meta["score"].append(score)
123 | current_meta["title"].append(title)
124 | current_meta["subreddit"].append(subreddit)
125 | current_meta["created_utc"].append(created_utc)
126 |
127 | if url_count > 0:
128 | archiver.add_data(current_url, current_meta)
129 | total_url_count += 1
130 | archiver.commit()
131 |
132 | url_count_path = os.path.join(url_directory, "url_count.json")
133 | json.dump(total_url_count, open(url_count_path, "w"))
134 |
135 | parser_description = 'Generate URL files from sqlite database containing URLs and reddit metadata.'
136 | parser = argparse.ArgumentParser(description=parser_description)
137 | parser.add_argument("-s", "--start_period", default=None)
138 | parser.add_argument("-f", "--finish_period", default=None)
139 | parser.add_argument("-dir", "--output_directory", default="")
140 | parser.add_argument("--urls_per_file", type=int, default=100000)
141 | parser.add_argument("-score", "--min_score", type=int, default=3)
142 | parser.add_argument("-source", "--data_source", default="db")
143 |
144 | if __name__ == '__main__':
145 | args = parser.parse_args()
146 |
147 | logfile_path = "generate_urls.log"
148 | setup_logger_tqdm(logfile_path)
149 |
150 | # If both none we can just query all and use the index on url field
151 | # Otherwise we use the index on created_utc and have to do a costly url sort
152 | if args.start_period or args.finish_period:
153 | if args.start_period:
154 | start_month, start_year = tuple(map(int,args.start_period.split(",")))
155 | start_date = datetime.datetime(start_year, start_month, 1)
156 | else:
157 | start_date = datetime.datetime(2005, 6, 1)
158 |
159 | if args.finish_period:
160 | finish_month, finish_year = tuple(map(int,args.finish_period.split(",")))
161 | end_date = datetime.datetime(finish_year, finish_month, 1)
162 | else:
163 | end_date = datetime.datetime.now()
164 |
165 | logger.info(f"Finding URLs between {start_date.strftime('%Y-%m')} and {end_date.strftime('%Y-%m')}")
166 | else:
167 | logger.info(f"Finding all URLs.")
168 | start_date = None
169 | end_date = None
170 |
171 | urls_directory = os.path.join(args.output_directory, "urls")
172 |
173 | logger.info(f"Urls output directory: {urls_directory}")
174 | logger.info(f"Minimum score: {args.min_score}")
175 | logger.info(f"URLs per file: {args.urls_per_file}")
176 |
177 | if args.data_source == "db":
178 | source = partial(get_from_db, start_date, end_date)
179 | elif args.data_source == "tsv":
180 | source = get_from_tsv
181 | else:
182 | logger.info(f"Invalid source {args.data_source}")
183 | sys.exit(-1)
184 |
185 | logger.info(f"Data source: {args.data_source}")
186 |
187 | generate_urls(urls_directory, args.urls_per_file, args.min_score, source)
188 |
189 |
--------------------------------------------------------------------------------
/pushshift/models.py:
--------------------------------------------------------------------------------
1 | """
2 | SqlAlchemy model for RedditSubmission along with some database helper functions.
3 | To populate the database see pushshift_to_sqlite.py.
4 | """
5 |
6 | import os
7 | import configparser
8 |
9 | from sqlalchemy.ext.declarative import declarative_base
10 | from sqlalchemy import Column, Integer, Text, DateTime
11 | from sqlalchemy import create_engine
12 | from sqlalchemy.schema import Index
13 | from sqlalchemy.orm import sessionmaker
14 |
15 | base = declarative_base()
16 |
17 | class RedditSubmission(base):
18 | __tablename__ = "reddit_submission"
19 | id = Column(Integer, primary_key=True) # Converted from Base36
20 | url = Column(Text, index=True)
21 | score = Column(Integer)
22 | title = Column(Text)
23 | subreddit = Column(Text)
24 | created_utc = Column(DateTime, index=True)
25 |
26 | # Index('idx_url_created', RedditSubmission.url, RedditSubmission.created_utc)
27 |
28 | def get_db_session():
29 | config = configparser.ConfigParser()
30 | config.read('alembic.ini')
31 | db_url = config["alembic"]["sqlalchemy.url"]
32 |
33 | fresh_db = False
34 | db_file_path = db_url.replace("sqlite:///","")
35 | if not os.path.exists(db_file_path):
36 | fresh_db = True
37 |
38 | engine = create_engine(db_url)
39 | if fresh_db:
40 | base.metadata.create_all(engine)
41 |
42 | Session = sessionmaker(bind=engine)
43 | db_session = Session()
44 | return db_session
45 |
46 | def recreate_db():
47 | config = configparser.ConfigParser()
48 | config.read('alembic.ini')
49 | db_url = config["alembic"]["sqlalchemy.url"]
50 |
51 | db_file_path = db_url.replace("sqlite:///","")
52 | if os.path.exists(db_file_path):
53 | os.remove(db_file_path) # sqlite doesn't truncate on drop_all
54 |
55 | engine = create_engine(db_url)
56 | base.metadata.create_all(engine)
57 |
58 | if __name__ == '__main__':
59 | recreate_db()
--------------------------------------------------------------------------------
/pushshift/process_dump_files_sqlite.py:
--------------------------------------------------------------------------------
1 | """
2 | Called from pushshift_to_sqlite.py.
3 |
4 | Processes a PushShift submission dump file, storing the url and relevant Reddit metadata
5 | into the sqlite database specified in alembic.ini. Note the reddit submission
6 | id is converted from base36 first and the created_utc is stored as a datetime
7 | object.
8 |
9 | process_dump_file is the entry point, requiring you to specify 'dump_file_path'
10 | and 'output_directory'. Supports tqdm-multiprocess.
11 |
12 | metadata = {}
13 | metadata["id"] = base36.loads(post["id"])
14 | metadata["subreddit"] = post.get("subreddit")
15 | metadata["title"] = post.get("title")
16 | metadata["score"] = post.get("score")
17 | metadata["created_utc"] = datetime.datetime.fromtimestamp(int(post["created_utc"]))
18 | """
19 |
20 | import json
21 | import os
22 | import math
23 | import datetime
24 |
25 | import base36
26 | from sqlalchemy import exc
27 |
28 | from .models import RedditSubmission
29 | from utils.archive_stream_readers import get_archive_stream_reader
30 |
31 | import logging
32 | logger = logging.getLogger()
33 |
34 | million = math.pow(10, 6)
35 |
36 | chunk_size = int(16 * million) # Must be larger than the size of any post
37 |
38 | def process_reddit_post(post):
39 | is_self = post.get("is_self")
40 | if is_self is None or is_self:
41 | return None
42 |
43 | url = post.get("url")
44 | if url is None or url == "":
45 | return None
46 |
47 | reddit_submission = RedditSubmission()
48 | reddit_submission.id = base36.loads(post["id"])
49 | reddit_submission.subreddit = post.get("subreddit")
50 | reddit_submission.title = post.get("title")
51 | reddit_submission.score = post.get("score", 0)
52 | reddit_submission.created_utc = datetime.datetime.fromtimestamp(int(post["created_utc"]))
53 | reddit_submission.url = url
54 |
55 | return reddit_submission
56 |
57 | def process_dump_file(dump_file_path, db_session, tqdm_func):
58 | logging.info(f"Processing dump file '{dump_file_path}'")
59 | dump_file_size = os.path.getsize(dump_file_path)
60 |
61 | previous_file_position = 0
62 | count = 0
63 | insert_batch_size = 100000
64 | with get_archive_stream_reader(dump_file_path) as reader, \
65 | tqdm_func(total=dump_file_size, unit="byte", unit_scale=1) as progress:
66 |
67 | progress.set_description(f"Processing {os.path.basename(dump_file_path)}")
68 |
69 | previous_line = ""
70 | while True:
71 | chunk = reader.read(chunk_size)
72 | if not chunk:
73 | break
74 |
75 | # Update Progress Bar
76 | current_file_position = reader.tell()
77 | progress.update(current_file_position - previous_file_position)
78 | previous_file_position = current_file_position
79 |
80 | # Process chunk + leftover, ignore possibly incomplete last line
81 | try:
82 | string_data = chunk.decode("utf-8")
83 | except UnicodeDecodeError as ex:
84 | logger.info(f"Error in position {current_file_position} in file {dump_file_path}")
85 | logger.info(ex)
86 | continue
87 | lines = string_data.split("\n")
88 | for i, line in enumerate(lines[:-1]):
89 | if i == 0:
90 | line = previous_line + line
91 |
92 | reddit_post = None
93 | try:
94 | reddit_post = json.loads(line)
95 | except Exception as ex:
96 | logger.info(f"JSON decoding failed: {ex}")
97 | continue
98 |
99 | reddit_submission = process_reddit_post(reddit_post)
100 | if reddit_submission:
101 | db_session.add(reddit_submission)
102 | count += 1
103 |
104 | if count == insert_batch_size:
105 | logging.info(f"Committing {count} records to db.")
106 | try:
107 | db_session.commit()
108 | except exc.IntegrityError:
109 | logger.info(f"Duplicate INSERT, ignoring.")
110 | db_session.rollback()
111 | count = 0
112 |
113 | previous_line = lines[-1]
114 |
115 | if count > 0:
116 | logging.info(f"Committing {count} records to db.")
117 | try:
118 | db_session.commit()
119 | except exc.IntegrityError:
120 | logger.info(f"Duplicate INSERT, ignoring.")
121 | db_session.rollback()
122 | count = 0
123 |
124 | logging.info("Done with file.")
125 |
--------------------------------------------------------------------------------
/pushshift/pushshift_to_sqlite.py:
--------------------------------------------------------------------------------
1 | """
2 | Builds a list of PushShift submission dump files located in "https://files.pushshift.io/reddit/submissions"
3 | within the desired date range, and then performs the following steps for each file. Note this can't be done
4 | with a multiprocessing pool due to locking issues with sqlite.
5 |
6 | 1. Download and verify the file using the available sha256 sums
7 | 2. Process the file, storing the url and relevant Reddit metadata into the sqlite database
8 | specified in alembic.ini (copy alembic.ini.template and set sqlalchemy.url).
9 | 3. Create a .dbdone file to mark the particular file as being processed, allowing script resume.
10 | 4. Delete the PushShift dump file to save storage space if --keep_dumps not specified.
11 |
12 | Arguments
13 | ---------
14 | --start_period (-s)
15 | Month and Year of first pushshift dump. Default: 6,2005
16 | --finish_period (-f)
17 | Month and Year of final pushshift dump. Defaults to current month, ignoring any missing months.
18 | --output_directory (-dir)
19 | Base directory that will contain the dumps subdirectory created as part of the process.
20 | --keep_dumps (-kd)
21 | If specified the dump won't be deleted after successful processing.
22 | """
23 |
24 | import datetime
25 | import os
26 | import argparse
27 | import sys
28 |
29 | from best_download import download_file
30 | import cutie
31 | import tqdm
32 |
33 | from .download_pushshift_dumps import build_file_list, get_sha256sums
34 | from .process_dump_files_sqlite import process_dump_file
35 | from .models import get_db_session
36 |
37 | import logging
38 | from utils.logger import setup_logger_tqdm
39 | logger = logging.getLogger(__name__)
40 |
41 | def reddit_processing(url, sha256sums, dumps_directory, keep_dumps):
42 |
43 | base_name = url.split('/')[-1]
44 | dump_file_path = os.path.join(dumps_directory, base_name)
45 | db_done_file = dump_file_path + ".dbdone"
46 |
47 | if os.path.exists(db_done_file):
48 | return True
49 |
50 | try:
51 | download_file(url, dump_file_path, sha256sums.get(base_name))
52 | except Exception as ex:
53 | logger.info(f"Download failed {ex}, skipping processing.")
54 | return False
55 |
56 | db_session = get_db_session()
57 | process_dump_file(dump_file_path, db_session, tqdm.tqdm)
58 |
59 | with open(db_done_file, "w") as fh:
60 | fh.write("Done!")
61 |
62 | if not keep_dumps:
63 | os.remove(dump_file_path)
64 |
65 | return True
66 |
67 | parser = argparse.ArgumentParser(description='Download PushShift submission dumps, extract urls')
68 | parser.add_argument("-s", "--start_period", default="6,2005")
69 | parser.add_argument("-f", "--finish_period", default=None)
70 | parser.add_argument("-dir", "--output_directory", default="")
71 | parser.add_argument("-kd", "--keep_dumps", action='store_true')
72 |
73 | # First available file: https://files.pushshift.io/reddit/submissions/RS_v2_2005-06.xz
74 | def main():
75 | logfile_path = "download_pushshift_dumps.log"
76 | setup_logger_tqdm(logfile_path) # Logger will write messages using tqdm.write
77 |
78 | args = parser.parse_args()
79 |
80 | start_month, start_year = tuple(map(int,args.start_period.split(",")))
81 | start_date = datetime.datetime(start_year, start_month, 1)
82 |
83 | if args.finish_period:
84 | finish_month, finish_year = tuple(map(int,args.finish_period.split(",")))
85 | end_date = datetime.datetime(finish_year, finish_month, 1)
86 | else:
87 | end_date = datetime.datetime.now()
88 |
89 | logger.info("Running Script - PushShift submission dumps to sqlite")
90 | logger.info("Downloading and processing dumps in the following range:")
91 | logger.info(start_date.strftime("Start Period: %m-%Y"))
92 | logger.info(end_date.strftime("End Period: %m-%Y"))
93 |
94 | dumps_directory = os.path.join(args.output_directory, "dumps")
95 |
96 | if os.path.isdir(dumps_directory):
97 | message = f"Directory '{dumps_directory}' already exists, if there are done files" \
98 | " in the directory then these particular months will be skipped. Delete" \
99 | " these files or the directory to avoid this."
100 | logger.info(message)
101 | if not cutie.prompt_yes_or_no('Do you want to continue?'):
102 | sys.exit(0)
103 |
104 | os.makedirs(dumps_directory, exist_ok=True)
105 |
106 | logger.info("Building PushShift submission dump file list...")
107 | url_list = build_file_list(start_date, end_date)
108 |
109 | logger.info("Getting sha256sums")
110 | sha256sums = get_sha256sums()
111 |
112 | # Download and Process
113 | logger.info("Commencing download and processing into sqlite.")
114 | results = []
115 | for url in url_list:
116 | result = reddit_processing(url, sha256sums, dumps_directory, args.keep_dumps)
117 | results.append(result)
118 |
119 | if __name__ == '__main__':
120 | main()
121 |
122 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | zstandard
2 | requests
3 | python-dateutil
4 | tqdm
5 | tldextract
6 | beautifulsoup4
7 | newspaper3k
8 | htmlmin
9 | lm_dataformat
10 | jsonlines
11 | datasketch
12 | colorama
13 | cutie
14 | sqlalchemy
15 | tqdm-multiprocess
16 | base36
17 | alembic
18 | cassandra-driver
19 | best-download
20 |
21 | mkdocs
--------------------------------------------------------------------------------
/scraping/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/scraping/__init__.py
--------------------------------------------------------------------------------
/scraping/filter.py:
--------------------------------------------------------------------------------
1 | import tldextract
2 | import re
3 |
4 | import logging
5 | logger = logging.getLogger("filelock")
6 | logger.setLevel(logging.WARNING)
7 |
8 | # https://stackoverflow.com/questions/7160737/python-how-to-validate-a-url-in-python-malformed-or-not
9 | url_regex = re.compile(
10 | r'^(?:http)s?://' # http:// or https://
11 | r'(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|' #domain...
12 | r'\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})' # ...or ip
13 | r'(?::\d+)?' # optional port
14 | r'(?:/?|[/?]\S+)$', re.IGNORECASE)
15 |
16 | # domains that aren't scraper friendly. do not include subdomains!
17 | exclude_domains = set([
18 | # image & video hosting sites
19 | 'imgur.com',
20 | 'redd.it',
21 | 'instagram.com',
22 | 'discord.gg',
23 | 'gfycat.com',
24 | 'giphy.com',
25 | 'reddituploads.com',
26 | 'redditmedia.com',
27 | 'twimg.com',
28 | 'sli.mg',
29 | 'magaimg.net',
30 | 'flickr.com',
31 | 'imgflip.com',
32 | 'youtube.com',
33 | 'youtu.be',
34 | 'youtubedoubler.com',
35 | 'vimeo.com',
36 | 'twitch.tv',
37 | 'streamable.com',
38 | 'bandcamp.com',
39 | 'soundcloud.com',
40 |
41 | # not scraper friendly
42 | 'reddit.com',
43 | 'gyazo.com',
44 | 'github.com',
45 | 'xkcd.com',
46 | 'twitter.com',
47 | 'spotify.com',
48 | 'itunes.apple.com',
49 | 'facebook.com',
50 | 'gunprime.com',
51 | 'strawpoll.me',
52 | 'voyagefusion.com',
53 | 'rollingstone.com',
54 | 'google.com',
55 | 'timeanddate.com',
56 | 'walmart.com',
57 | 'roanoke.com',
58 | 'spotrac.com',
59 |
60 | # original paper excluded wikipedia
61 | 'wikipedia.org',
62 |
63 | # lots of top posts for this one
64 | 'battleforthenet.com',
65 | ])
66 |
67 | exclude_extensions = (
68 | '.png',
69 | '.jpg',
70 | '.jpeg',
71 | '.gif',
72 | '.gifv',
73 | '.pdf',
74 | '.mp4',
75 | '.mp3',
76 | '.ogv',
77 | '.webm',
78 | '.doc',
79 | '.docx',
80 | '.log',
81 | '.csv',
82 | '.dat',
83 | '.iso',
84 | '.bin',
85 | '.exe',
86 | '.apk',
87 | '.jar',
88 | '.app',
89 | '.ppt',
90 | '.pps',
91 | '.pptx',
92 | '.xml',
93 | '.gz',
94 | '.xz',
95 | '.bz2',
96 | '.tgz',
97 | '.tar',
98 | '.zip',
99 | '.wma',
100 | '.mov',
101 | '.wmv',
102 | '.3gp',
103 | '.svg',
104 | '.rar',
105 | '.wav',
106 | '.avi',
107 | '.7z'
108 | )
109 |
110 | def should_exclude(url):
111 |
112 | ext = tldextract.extract(url)
113 | domain = '.'.join([x for x in ext if x])
114 | basedomain = '.'.join(ext[-2:])
115 |
116 | # Ignore non-URLs
117 | if len(url) <= 8 or ' ' in url or re.match(url_regex, url) is None:
118 | return True
119 |
120 | # Ignore excluded domains
121 | if basedomain in exclude_domains or domain in exclude_domains:
122 | return True
123 |
124 | # Ignore case-insensitive matches for excluded extensions
125 | if url.lower().split('?')[0].endswith(exclude_extensions):
126 | return True
127 |
128 | return False
--------------------------------------------------------------------------------
/scraping/scrape_urls.py:
--------------------------------------------------------------------------------
1 | """
2 | This program iterates through URL files generated in step 2 above. For each file it hands out the URLs
3 | to a multiprocessing pool for scraping. Once all URLs in the batch are scraped, the successful results are
4 | archived using a slightly modified version of [lm_dataformat](https://github.com/leogao2/lm_dataformat)
5 | (thanks @bmk). The following metadata fields are saved in the metadata dict offered by lm_dataformat:
6 |
7 | title: Web Page Title
8 | lang: Language detected by Newspaper scraper.
9 | url: Original URL.
10 | word_count: Total words outputted by Newspaper.
11 | elapsed: Scraping time.
12 | scraper: Always "newspaper".
13 | domain: Top level domain for the original URL.
14 | reddit_id: List of submission IDs containing URL - converted from base36.
15 | subreddit: List of subreddits for the corresponding submissions.
16 | reddit_score: List of reddit scores for the corresponding submissions.
17 | reddit_title: List of submissions titles for the corresponding submissions.
18 | reddit_created_utc: List of submissions created times for the corresponding submissions.
19 |
20 | Arguments
21 | ---------
22 | --job_directory (-dir)
23 | Base directory containing the urls subdirectory and location where the scrapes subdirectory
24 | will be created.
25 | --process_count (-procs)
26 | Number of worker processes in the pool. Defaults to 60. Don't go above this on Windows.
27 | --request_timeout (-timeout)
28 | Scraping timeout for each URL. Defaults to 30 seconds.
29 | """
30 |
31 | import os
32 | import sys
33 | import glob
34 | import json
35 | import argparse
36 |
37 | import tqdm
38 | from tqdm_multiprocess import TqdmMultiProcessPool
39 |
40 | from scraping.scrapers import newspaper_scraper
41 | from utils.archiver import Reader, Archive
42 | from utils.utils import Timer
43 |
44 | import logging
45 | from utils.logger import setup_logger_tqdm
46 | logger = logging.getLogger(__name__)
47 |
48 | # Multiprocessed
49 | def download(url_entry, request_timeout, scraper,
50 | memoize, tqdm_func, global_tqdm):
51 |
52 | url, reddit_meta = url_entry
53 | text, meta, success = scraper(url, memoize, request_timeout=request_timeout)
54 |
55 | if not success or text is None or text.strip() == "":
56 | if global_tqdm:
57 | global_tqdm.update()
58 | return (text, meta, False)
59 |
60 | # Add extra meta
61 | meta["reddit_id"] = reddit_meta["id"]
62 | meta["subreddit"] = reddit_meta["subreddit"]
63 | meta["reddit_score"] = reddit_meta["score"]
64 | meta["reddit_title"] = reddit_meta["title"]
65 | meta["reddit_created_utc"] = reddit_meta["created_utc"]
66 |
67 | if global_tqdm:
68 | global_tqdm.update()
69 |
70 | return (text, meta, success)
71 |
72 | def scrape_urls(urls_directory, scrapes_directory, process_count, request_timeout):
73 |
74 | url_files = glob.glob(os.path.join(urls_directory, "urls_*.jsonl.zst"))
75 |
76 | # Get Total URL count
77 | url_count_path = os.path.join(urls_directory, "url_count.json")
78 | if os.path.exists(url_count_path):
79 | total_url_count = json.load(open(url_count_path, "r"))
80 | else:
81 | logger.info("Getting total URL count...")
82 | total_url_count = 0
83 | for url_file_path in tqdm.tqdm(url_files, dynamic_ncols=True):
84 | reader = Reader()
85 | url_data = []
86 | for url in reader.read_jsonl(url_file_path, get_meta=False):
87 | total_url_count += 1
88 | json.dump(total_url_count, open(url_count_path, "w"))
89 |
90 | # overall progress bar
91 | progress = tqdm.tqdm(total=total_url_count, dynamic_ncols=True)
92 | progress.set_description("Total URLs")
93 |
94 | for url_file_path in url_files:
95 | # Skip if previously done
96 | done_file_path = url_file_path + ".done"
97 | if os.path.exists(done_file_path):
98 | batch_url_count = json.load(open(done_file_path, "r"))
99 | progress.update(batch_url_count)
100 | logger.info(f"'{os.path.basename(url_file_path)}' already scraped, skipping.")
101 | continue
102 |
103 | logger.info(f"Scraping URLs from '{os.path.basename(url_file_path)}'.")
104 |
105 | reader = Reader()
106 | url_data = []
107 | for url, reddit_meta in reader.read_jsonl(url_file_path, get_meta=True):
108 | url_data.append((url, reddit_meta))
109 |
110 | timer = Timer().start()
111 |
112 | # Download and Process With Pool
113 | pool = TqdmMultiProcessPool(process_count)
114 | tasks = []
115 | for url_entry in url_data:
116 | arguments = (url_entry, request_timeout, newspaper_scraper, False)
117 | task = (download, arguments)
118 | tasks.append(task)
119 |
120 | # tqdm-multiprocess doesn't support multiple global tqdms, use on_done as well
121 | on_done = lambda _ : progress.update()
122 | on_error = lambda _ : None
123 |
124 | with tqdm.tqdm(total=len(url_data), dynamic_ncols=True) as batch_progress:
125 | batch_progress.set_description(f"{os.path.basename(url_file_path)}")
126 | results = pool.map(batch_progress, tasks, on_error, on_done)
127 |
128 | logger.info("Archiving chunk with lm_dataformat...")
129 | # urls_*.jsonl.zst -> scrapes_*.jsonl.zst
130 | output_archive_name = os.path.basename(url_file_path).replace("urls", "scrapes")
131 | output_archive_path = os.path.join(scrapes_directory, output_archive_name)
132 | archiver = Archive(output_archive_path)
133 | batch_error_count = 0
134 | for text, meta, status in results:
135 | if not status:
136 | batch_error_count += 1
137 | else:
138 | archiver.add_data(text, meta)
139 | archiver.commit()
140 |
141 | error_percentage = batch_error_count / len(url_data) * 100
142 | logger.info(f"Errors: {batch_error_count} / {len(url_data)} ({error_percentage:0.2f}%)")
143 | logger.info(f"Batch time: {timer.stop():0.2f} seconds")
144 |
145 | json.dump(len(url_data), open(done_file_path, "w"))
146 |
147 | progress.close()
148 | logger.info("Done!")
149 |
150 | parser_description = 'Scrape urls extracted from Reddit.'
151 | parser = argparse.ArgumentParser(description=parser_description)
152 | parser.add_argument("-dir", "--job_directory", default="")
153 | parser.add_argument("-procs", "--process_count", type=int, default=60)
154 | parser.add_argument("-timeout", "--request_timeout", type=int, default=30)
155 |
156 | if __name__ == "__main__":
157 | logfile_name = "scrape_urls.log"
158 | setup_logger_tqdm(logfile_name)
159 |
160 | args = parser.parse_args()
161 |
162 | urls_directory = os.path.join(args.job_directory, "urls")
163 | if not os.path.exists(urls_directory):
164 | logger.info(f"No 'urls' directory found in '{args.job_directory}', aborting")
165 | sys.exit(0)
166 |
167 | scrapes_directory = os.path.join(args.job_directory, "scrapes")
168 | os.makedirs(scrapes_directory, exist_ok=True)
169 |
170 | logger.info(f"Scrapes outputting to: '{scrapes_directory}'")
171 |
172 | scrape_urls(urls_directory, scrapes_directory, args.process_count, args.request_timeout)
173 |
--------------------------------------------------------------------------------
/scraping/scrapers.py:
--------------------------------------------------------------------------------
1 | # Code taken in large part from https://github.com/jcpeterson/openwebtext
2 |
3 | import time
4 | import unicodedata
5 |
6 | import bs4
7 | import newspaper
8 |
9 | from lxml.html.clean import Cleaner
10 | from htmlmin import minify
11 | from scraping.filter import should_exclude
12 |
13 |
14 | def find_and_filter_tag(tag, soup):
15 | """tag specific filter logic"""
16 |
17 | candidates = soup.find_all(tag)
18 | candidates = [
19 | unicodedata.normalize("NFKD", x.string)
20 | for x in candidates
21 | if x.string is not None
22 | ]
23 |
24 | if tag == "p":
25 | candidates = [y.strip() for y in candidates if len(y.split(" ")) >= 4]
26 | count = sum(len(y.split(" ")) for y in candidates)
27 | else:
28 | raise NotImplementedError
29 |
30 | return (candidates, count)
31 |
32 |
33 | def raw_scraper(url, memoize):
34 | t1 = time.time()
35 | if should_exclude(url):
36 | # heuristic to make downloading faster
37 | return None, {
38 | "url": url,
39 | "scraper": "raw",
40 | }
41 |
42 | try:
43 | cleaner = Cleaner()
44 | cleaner.javascript = True
45 | cleaner.style = True
46 | article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
47 | article.download()
48 | html = minify(article.html)
49 | html = cleaner.clean_html(html)
50 | article.parse()
51 | except:
52 | return None, {
53 | "url": url,
54 | "scraper": "raw",
55 | }
56 | if article.text == "":
57 | return None, {
58 | "url": url,
59 | "scraper": "raw",
60 | }
61 |
62 | metadata = {"url": url, "elapsed": time.time() - t1, "scraper": "raw"}
63 | return html, metadata
64 |
65 |
66 | def newspaper_scraper(url, memoize, request_timeout):
67 | t1 = time.time()
68 |
69 | if should_exclude(url):
70 | # heuristic to make downloading faster
71 | return None, {
72 | "url": url,
73 | "scraper": "newspaper",
74 | }, False
75 |
76 | try:
77 | article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize,
78 | request_timeout=request_timeout)
79 | article.download()
80 | article.parse()
81 | except Exception as ex:
82 | return None, ex, False
83 |
84 | #print(article.__dict__.keys())
85 | #dict_keys(['config', 'extractor', 'source_url', 'url', 'title', 'top_img', 'top_image', 'meta_img', 'imgs', 'images',
86 | # 'movies', 'text', 'keywords', 'meta_keywords', 'tags', 'authors', 'publish_date', 'summary', 'html', 'article_html',
87 | # 'is_parsed', 'download_state', 'download_exception_msg', 'meta_description', 'meta_lang', 'meta_favicon', 'meta_data',
88 | # 'canonical_link', 'top_node', 'clean_top_node', 'doc', 'clean_doc', 'additional_data', 'link_hash'])
89 |
90 | #print(article.title)
91 | # print(article.meta_lang)
92 | # if article.meta_lang != "en" and article.meta_lang:
93 | # print(article.text)
94 |
95 | text = article.text
96 | count = len(text.split())
97 |
98 | metadata = {
99 | "title": article.title,
100 | "lang": article.meta_lang,
101 | "url": url,
102 | "word_count": count,
103 | "elapsed": time.time() - t1,
104 | "scraper": "newspaper",
105 | }
106 | return text, metadata, True
107 |
108 | def bs4_scraper(url, memoize):
109 | t1 = time.time()
110 | if should_exclude(url):
111 | # heuristic to make downloading faster
112 | return None, {
113 | "url": url,
114 | "scraper": "bs4",
115 | }
116 |
117 | try:
118 | article = newspaper.Article(url, fetch_images=False, memoize_articles=memoize)
119 | article.download()
120 | html = article.html
121 | soup = bs4.BeautifulSoup(html, "lxml")
122 | text, count = find_and_filter_tag("p", soup)
123 | # DDB: keep text as a single string for consistency with
124 | # newspaper_scraper
125 | text = " ".join(text)
126 | except:
127 | return None, {
128 | "url": url,
129 | "scraper": "bs4",
130 | }
131 |
132 | metadata = {
133 | "url": url,
134 | "word_count": count,
135 | "elapsed": time.time() - t1,
136 | "scraper": "bs4",
137 | }
138 | return text, metadata
139 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/EleutherAI/openwebtext2/c967e201bcf643f425a5048a3a1681f592148585/utils/__init__.py
--------------------------------------------------------------------------------
/utils/archive_stream_readers.py:
--------------------------------------------------------------------------------
1 | import lzma
2 | import zstandard as zstd
3 | import bz2
4 |
5 | # Lets you access the underlying file's tell function for updating tqdm
6 | class ArchiveStreamReader(object):
7 | def __init__(self, file_path, decompressor):
8 | self.file_path = file_path
9 | self.file_handle = None
10 | self.decompressor = decompressor
11 | self.stream_reader = None
12 |
13 | def __enter__(self):
14 | self.file_handle = open(self.file_path, 'rb')
15 | self.stream_reader = self.decompressor(self.file_handle)
16 | return self
17 |
18 | def __exit__(self, exc_type, exc_value, traceback):
19 | self.stream_reader.close()
20 | self.file_handle.close()
21 |
22 | def tell(self):
23 | return self.file_handle.tell()
24 |
25 | def read(self, size):
26 | return self.stream_reader.read(size)
27 |
28 | def get_archive_stream_reader(file_path):
29 | extension = file_path.split(".")[-1]
30 |
31 | if extension == "zst":
32 | return ArchiveStreamReader(file_path, zstd.ZstdDecompressor().stream_reader)
33 | elif extension == "bz2":
34 | return ArchiveStreamReader(file_path, bz2.BZ2File)
35 | elif extension == "xz":
36 | return ArchiveStreamReader(file_path, lzma.open)
--------------------------------------------------------------------------------
/utils/archiver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zstandard
3 | import json
4 | import jsonlines
5 | import io
6 | import datetime
7 |
8 | def json_serial(obj):
9 | """JSON serializer for objects not serializable by default json code"""
10 |
11 | if isinstance(obj, (datetime.datetime,)):
12 | return obj.isoformat()
13 | raise TypeError ("Type %s not serializable" % type(obj))
14 |
15 | # Modified version of lm_dataformat Archive for single file.
16 | class Archive:
17 | def __init__(self, file_path, compression_level=3):
18 | self.file_path = file_path
19 | dir_name = os.path.dirname(file_path)
20 | if dir_name:
21 | os.makedirs(dir_name, exist_ok=True)
22 | self.fh = open(self.file_path, 'wb')
23 | self.cctx = zstandard.ZstdCompressor(level=compression_level)
24 | self.compressor = self.cctx.stream_writer(self.fh)
25 |
26 | def add_data(self, data, meta={}):
27 | self.compressor.write(json.dumps({'text': data, 'meta': meta}, default=json_serial).encode('UTF-8') + b'\n')
28 |
29 | def commit(self):
30 | self.compressor.flush(zstandard.FLUSH_FRAME)
31 | self.fh.flush()
32 | self.fh.close()
33 |
34 | # Modified version of lm_dataformat Reader with self.fh set, allowing peeking for tqdm.
35 | class Reader:
36 | def __init__(self):
37 | pass
38 |
39 | def read_jsonl(self, file, get_meta=False, autojoin_paragraphs=True, para_joiner='\n\n'):
40 | with open(file, 'rb') as fh:
41 | self.fh = fh
42 | cctx = zstandard.ZstdDecompressor()
43 | reader = io.BufferedReader(cctx.stream_reader(fh))
44 | rdr = jsonlines.Reader(reader)
45 | for ob in rdr:
46 | # naive jsonl where each object is just the string itself, with no meta. For legacy compatibility.
47 | if isinstance(ob, str):
48 | assert not get_meta
49 | yield ob
50 | continue
51 |
52 | text = ob['text']
53 |
54 | if autojoin_paragraphs and isinstance(text, list):
55 | text = para_joiner.join(text)
56 |
57 | if get_meta:
58 | yield text, (ob['meta'] if 'meta' in ob else {})
59 | else:
60 | yield text
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from datetime import timedelta
4 | from tqdm import tqdm
5 |
6 | class LogFormatter():
7 |
8 | def __init__(self):
9 | self.start_time = time.time()
10 |
11 | def format(self, record):
12 | elapsed_seconds = round(record.created - self.start_time)
13 |
14 | prefix = "%s - %s - %s" % (
15 | record.levelname,
16 | time.strftime('%x %X'),
17 | timedelta(seconds=elapsed_seconds)
18 | )
19 | message = record.getMessage()
20 | message = message.replace('\n', '\n' + ' ' * (len(prefix) + 3))
21 | return "%s - %s" % (prefix, message) if message else ''
22 |
23 | def reset_time(self):
24 | self.start_time = time.time()
25 |
26 | def setup_logger(filepath=None, to_console=True, formatter=LogFormatter()):
27 |
28 | # create logger
29 | logger = logging.getLogger()
30 | logger.setLevel(logging.DEBUG)
31 | logger.propagate = False
32 |
33 | logger.handlers = []
34 |
35 | # create file handler
36 | if filepath is not None:
37 | file_handler = logging.FileHandler(filepath, "a")
38 | file_handler.setLevel(logging.DEBUG)
39 | file_handler.setFormatter(formatter)
40 | logger.addHandler(file_handler)
41 |
42 | # create console handler
43 | if to_console:
44 | console_handler = logging.StreamHandler()
45 | console_handler.setLevel(logging.INFO)
46 | console_handler.setFormatter(formatter)
47 | logger.addHandler(console_handler)
48 |
49 | class ChildProcessHandler(logging.StreamHandler):
50 | def __init__(self, message_queue):
51 | self.message_queue = message_queue
52 | logging.StreamHandler.__init__(self)
53 |
54 | def emit(self, record):
55 | self.message_queue.put(record)
56 |
57 | def setup_logger_child_process(message_queue):
58 | # create logger
59 | logger = logging.getLogger()
60 | logger.setLevel(logging.DEBUG)
61 | logger.propagate = False
62 |
63 | logger.handlers = []
64 |
65 | # create queue handler
66 | child_process_handler = ChildProcessHandler(message_queue)
67 | child_process_handler.setLevel(logging.INFO)
68 | logger.addHandler(child_process_handler)
69 |
70 | class TqdmHandler(logging.StreamHandler):
71 | def __init__(self):
72 | logging.StreamHandler.__init__(self)
73 |
74 | def emit(self, record):
75 | msg = self.format(record)
76 | tqdm.write(msg)
77 |
78 | def setup_logger_tqdm(filepath=None, formatter=LogFormatter()):
79 |
80 | # create logger
81 | logger = logging.getLogger()
82 | logger.setLevel(logging.DEBUG)
83 | logger.propagate = False
84 |
85 | logger.handlers = []
86 |
87 | # create file handler
88 | if filepath is not None:
89 | file_handler = logging.FileHandler(filepath, "a")
90 | file_handler.setLevel(logging.DEBUG)
91 | file_handler.setFormatter(formatter)
92 | logger.addHandler(file_handler)
93 |
94 | # create tqdm handler
95 | tqdm_handler = TqdmHandler()
96 | tqdm_handler.setLevel(logging.INFO)
97 | tqdm_handler.setFormatter(formatter)
98 | logger.addHandler(tqdm_handler)
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import collections.abc
2 | import os
3 | import time
4 | import pickle
5 |
6 | import logging
7 | logger = logging.getLogger(__name__)
8 |
9 | def timed_pickle_load(file_name, pickle_description):
10 | logger.info(f"Unpickling {pickle_description}...")
11 | timer = Timer().start()
12 | unpickled = pickle.load(open(file_name, "rb"))
13 | logger.info(timer.stop_string())
14 | return unpickled
15 |
16 | def timed_pickle_dump(the_object, file_name, pickle_description):
17 | logger.info(f"Pickling {pickle_description}...")
18 | timer = Timer().start()
19 | pickle.dump(the_object, open(file_name, "wb"))
20 | logger.info(timer.stop_string())
21 |
22 | class TimerError(Exception):
23 | """A custom exception used to report errors in use of Timer class"""
24 |
25 | class Timer:
26 | def __init__(self):
27 | self._start_time = None
28 |
29 | def start(self):
30 | """Start a new timer"""
31 | if self._start_time is not None:
32 | raise TimerError(f"Timer is running. Use .stop() to stop it")
33 |
34 | self._start_time = time.perf_counter()
35 |
36 | return self
37 |
38 | def stop(self):
39 | """Stop the timer, and report the elapsed time"""
40 | if self._start_time is None:
41 | raise TimerError(f"Timer is not running. Use .start() to start it")
42 |
43 | elapsed_time = time.perf_counter() - self._start_time
44 | self._start_time = None
45 | return elapsed_time
46 |
47 | def stop_string(self):
48 | elapsed = self.stop()
49 | return f"Took {elapsed:0.2f}s"
50 |
51 | def linecount(filename):
52 | f = open(filename, 'rb')
53 | lines = 0
54 | buf_size = 1024 * 1024
55 | read_f = f.raw.read
56 |
57 | buf = read_f(buf_size)
58 | while buf:
59 | lines += buf.count(b'\n')
60 | buf = read_f(buf_size)
61 |
62 | return lines
63 |
64 | def chunker(l, n, s=0):
65 | """Yield successive n-sized chunks from l, skipping the first s chunks."""
66 | if isinstance(l, collections.abc.Iterable):
67 | chnk = []
68 | for i, elem in enumerate(l):
69 | if i < s:
70 | continue
71 |
72 | chnk.append(elem)
73 | if len(chnk) == n:
74 | yield chnk
75 | chnk = []
76 | if len(chnk) != 0:
77 | yield chnk
78 |
79 | else:
80 | for i in range(s, len(l), n):
81 | yield l[i : i + n]
82 |
83 | def mkdir(fp):
84 | try:
85 | os.makedirs(fp)
86 | except FileExistsError:
87 | pass
88 | return fp
--------------------------------------------------------------------------------