├── ml_datasets ├── test │ ├── __init__.py │ ├── test_util.py │ └── test_datasets.py ├── _registry.py ├── loaders │ ├── __init__.py │ ├── stack_exchange.py │ ├── wikiner.py │ ├── quora.py │ ├── snli.py │ ├── imdb.py │ ├── dbpedia.py │ ├── cmu.py │ ├── cifar.py │ ├── universal_dependencies.py │ ├── reuters.py │ └── mnist.py ├── __init__.py ├── spacy_readers.py └── util.py ├── requirements.txt ├── setup.py ├── pyproject.toml ├── setup.cfg ├── LICENSE ├── .github └── workflows │ └── pytest.yml ├── .gitignore └── README.md /ml_datasets/test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ml_datasets/_registry.py: -------------------------------------------------------------------------------- 1 | import catalogue 2 | 3 | register_loader = catalogue.create("ml-datasets", entry_points=True) 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cloudpickle>=2.2 2 | numpy>=1.18 3 | scipy>=1.7.0 4 | tqdm>=4.10.0,<5.0.0 5 | # Our libraries 6 | srsly>=1.0.1,<4.0.0 7 | catalogue>=0.2.0,<3.0.0 8 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | 4 | if __name__ == "__main__": 5 | from setuptools import setup, find_packages 6 | 7 | setup(name="ml_datasets", packages=find_packages()) 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.pytest.ini_options] 2 | addopts = "--strict-markers --strict-config -v -r sxfE --color=yes --durations=10" 3 | xfail_strict = true 4 | filterwarnings = [ 5 | "error", 6 | # FIXME spurious random download warnings; will cause trouble in downstream CI 7 | "ignore:Implicitly cleaning up =3.8 15 | install_requires = 16 | cloudpickle>=2.2 17 | numpy>=1.18 18 | tqdm>=4.10.0,<5.0.0 19 | # Our libraries 20 | srsly>=1.0.1,<4.0.0 21 | catalogue>=0.2.0,<3.0.0 22 | 23 | [options.entry_points] 24 | spacy_readers = 25 | ml_datasets.imdb_sentiment.v1 = ml_datasets.spacy_readers:imdb_reader 26 | ml_datasets.cmu_movies.v1 = ml_datasets.spacy_readers:cmu_reader 27 | ml_datasets.dbpedia.v1 = ml_datasets.spacy_readers:dbpedia_reader 28 | 29 | [flake8] 30 | ignore = E203, E266, E501, E731, W503, E741 31 | max-line-length = 80 32 | select = B,C,E,F,W,T4,B9 33 | exclude = 34 | ml_datasets/__init__.py 35 | -------------------------------------------------------------------------------- /ml_datasets/test/test_util.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from urllib.error import HTTPError, URLError 3 | from ml_datasets.util import get_file 4 | 5 | 6 | def test_get_file_domain_resolution_fails(): 7 | with pytest.raises( 8 | URLError, match=r"test_non_existent_file.*(not known|getaddrinfo failed)" 9 | ): 10 | get_file( 11 | "non_existent_file.txt", 12 | "http://test_notexist.wth/test_non_existent_file.txt" 13 | ) 14 | 15 | 16 | def test_get_file_404_file_not_found(): 17 | with pytest.raises(HTTPError, match="test_non_existent_file.*404.*Not Found") as e: 18 | get_file( 19 | "non_existent_file.txt", 20 | "http://google.com/test_non_existent_file.txt" 21 | ) 22 | 
assert e.value.code == 404 23 | # Suppress pytest.PytestUnraisableExceptionWarning: 24 | # Exception ignored while calling deallocator 25 | # This questionable design quirk comes from urllib.request.urlretrieve, 26 | # so we shouldn't shim around it. 27 | e.value.close() 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 ExplosionAI GmbH 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ml_datasets/loaders/quora.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import csv 3 | 4 | from ..util import partition, get_file 5 | from .._registry import register_loader 6 | 7 | 8 | QUORA_QUESTIONS_URL = "http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv" 9 | 10 | 11 | @register_loader("quora_questions") 12 | def quora_questions(loc=None): 13 | if loc is None: 14 | loc = get_file("quora_similarity.tsv", QUORA_QUESTIONS_URL) 15 | if isinstance(loc, str): 16 | loc = Path(loc) 17 | is_header = True 18 | lines = [] 19 | with loc.open("r", encoding="utf8") as file_: 20 | for row in csv.reader(file_, delimiter="\t"): 21 | if is_header: 22 | is_header = False 23 | continue 24 | id_, qid1, qid2, sent1, sent2, is_duplicate = row 25 | if not isinstance(sent1, str): 26 | sent1 = sent1.decode("utf8").strip() 27 | if not isinstance(sent2, str): 28 | sent2 = sent2.decode("utf8").strip() 29 | if sent1 and sent2: 30 | lines.append(((sent1, sent2), int(is_duplicate))) 31 | train, dev = partition(lines, 0.9) 32 | return train, dev 33 | -------------------------------------------------------------------------------- /ml_datasets/loaders/snli.py: -------------------------------------------------------------------------------- 1 | from srsly import json_loads 2 | from pathlib import Path 3 | 4 | from ..util import get_file 5 | from .._registry import register_loader 6 | 7 | 8 | SNLI_URL = "http://nlp.stanford.edu/projects/snli/snli_1.0.zip" 9 | THREE_LABELS = {"entailment": 2, "contradiction": 1, "neutral": 0} 10 | TWO_LABELS = {"entailment": 1, "contradiction": 0, "neutral": 0} 11 | 12 | 13 | @register_loader("snli") 14 | def snli(loc=None, ternary=False): 15 | label_scheme = THREE_LABELS if ternary else TWO_LABELS 16 | if loc is None: 17 | loc = 
get_file("snli_1.0", SNLI_URL, unzip=True) 18 | if isinstance(loc, str): 19 | loc = Path(loc) 20 | train = read_snli(Path(loc) / "snli_1.0_train.jsonl", label_scheme) 21 | dev = read_snli(Path(loc) / "snli_1.0_dev.jsonl", label_scheme) 22 | return train, dev 23 | 24 | 25 | def read_snli(loc, label_scheme): 26 | rows = [] 27 | with loc.open("r", encoding="utf8") as file_: 28 | for line in file_: 29 | eg = json_loads(line) 30 | label = eg["gold_label"] 31 | if label == "-": 32 | continue 33 | rows.append(((eg["sentence1"], eg["sentence2"]), label_scheme[label])) 34 | return rows 35 | -------------------------------------------------------------------------------- /ml_datasets/loaders/imdb.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import random 3 | 4 | from ..util import get_file 5 | from .._registry import register_loader 6 | 7 | IMDB_URL = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" 8 | 9 | 10 | @register_loader("imdb") 11 | def imdb(loc=None, *, train_limit=0, dev_limit=0): 12 | if loc is None: 13 | loc = get_file("aclImdb", IMDB_URL, untar=True, unzip=True) 14 | train_loc = Path(loc) / "train" 15 | test_loc = Path(loc) / "test" 16 | return read_imdb(train_loc, limit=train_limit), read_imdb(test_loc, limit=dev_limit) 17 | 18 | 19 | def read_imdb(data_dir, *, limit=0): 20 | locs = [] 21 | for subdir in ("pos", "neg"): 22 | for filename in (data_dir / subdir).iterdir(): 23 | locs.append((filename, subdir)) 24 | 25 | # shuffle and filter the file locations 26 | random.shuffle(locs) 27 | if limit >= 1: 28 | locs = locs[:limit] 29 | 30 | examples = [] 31 | for loc, gold_label in locs: 32 | with loc.open("r", encoding="utf8") as file_: 33 | text = file_.read() 34 | text = text.replace("
", "\n\n") 35 | if text.strip(): 36 | examples.append((text, gold_label)) 37 | return examples 38 | -------------------------------------------------------------------------------- /ml_datasets/loaders/dbpedia.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import csv 3 | import random 4 | 5 | from ..util import get_file 6 | from .._registry import register_loader 7 | 8 | 9 | # DBPedia Ontology from https://course.fast.ai/datasets 10 | DBPEDIA_ONTOLOGY_URL = "https://s3.amazonaws.com/fast-ai-nlp/dbpedia_csv.tgz" 11 | 12 | 13 | @register_loader("dbpedia") 14 | def dbpedia(loc=None, *, train_limit=0, dev_limit=0): 15 | if loc is None: 16 | loc = get_file("dbpedia_csv", DBPEDIA_ONTOLOGY_URL, untar=True, unzip=True) 17 | train_loc = Path(loc) / "train.csv" 18 | test_loc = Path(loc) / "test.csv" 19 | return ( 20 | read_dbpedia_ontology(train_loc, limit=train_limit), 21 | read_dbpedia_ontology(test_loc, limit=dev_limit), 22 | ) 23 | 24 | 25 | def read_dbpedia_ontology(data_file, *, limit=0): 26 | examples = [] 27 | with open(data_file, newline="", encoding="utf-8") as f: 28 | reader = csv.reader(f) 29 | for row in reader: 30 | label = row[0] 31 | title = row[1] 32 | text = row[2] 33 | examples.append((title + "\n" + text, label)) 34 | random.shuffle(examples) 35 | if limit >= 1: 36 | examples = examples[:limit] 37 | return examples 38 | -------------------------------------------------------------------------------- /ml_datasets/loaders/cmu.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | import random 4 | import csv 5 | 6 | from ..util import get_file, partition 7 | from .._registry import register_loader 8 | 9 | CMU_URL = "http://www.cs.cmu.edu/~ark/personas/data/MovieSummaries.tar.gz" 10 | 11 | 12 | @register_loader("cmu") 13 | def cmu(loc=None, *, limit=0, shuffle=True, labels=None, split=0.9): 14 | if loc is None: 15 | loc = get_file("MovieSummaries", CMU_URL, untar=True, unzip=True) 16 | meta_loc = Path(loc) / "movie.metadata.tsv" 17 | text_loc = Path(loc) / "plot_summaries.txt" 18 | 19 | data = read_cmu(meta_loc, text_loc, limit=limit, shuffle=shuffle, labels=labels) 20 | train, dev = partition(data, split) 21 | return train, dev 22 | 23 | 24 | def read_cmu(meta_loc, text_loc, *, limit, shuffle, labels): 25 | genre_by_id = {} 26 | title_by_id = {} 27 | with meta_loc.open("r", encoding="utf8") as file_: 28 | for row in csv.reader(file_, delimiter="\t"): 29 | movie_id = row[0] 30 | title = row[2] 31 | annot = row[8] 32 | d = json.loads(annot) 33 | genres = set(d.values()) 34 | genre_by_id[movie_id] = genres 35 | title_by_id[movie_id] = title 36 | 37 | examples = [] 38 | with text_loc.open("r", encoding="utf8") as file_: 39 | for row in csv.reader(file_, delimiter="\t"): 40 | movie_id = row[0] 41 | text = row[1] 42 | genres = genre_by_id.get(movie_id, None) 43 | title = title_by_id.get(movie_id, "") 44 | if genres: 45 | if not labels or [g for g in genres if g in labels]: 46 | examples.append((title + "\n" + text, list(genres))) 47 | if shuffle: 48 | random.shuffle(examples) 49 | if limit >= 1: 50 | examples = examples[:limit] 51 | return examples 52 | -------------------------------------------------------------------------------- /ml_datasets/test/test_datasets.py: -------------------------------------------------------------------------------- 1 | # TODO the tests below only verify that the various functions don't crash. 
2 | # Expand them to test the actual output contents. 3 | 4 | import platform 5 | 6 | import pytest 7 | import numpy as np 8 | 9 | import ml_datasets 10 | 11 | NP_VERSION = tuple(int(x) for x in np.__version__.split(".")[:2]) 12 | 13 | # FIXME warning on NumPy 2.4 when downloading pre-computed pickles: 14 | # Python or NumPy boolean but got `align=0`. 15 | # Did you mean to pass a tuple to create a subarray type? (Deprecated NumPy 2.4) 16 | if NP_VERSION >= (2, 4): 17 | np_24_deprecation = pytest.mark.filterwarnings( 18 | "ignore::numpy.exceptions.VisibleDeprecationWarning", 19 | 20 | ) 21 | else: 22 | # Note: can't use `condition=NP_VERSION >= (2, 4)` on the decorator directly 23 | # as numpy.exceptions did not exist in old NumPy versions. 24 | np_24_deprecation = lambda x: x 25 | 26 | 27 | @np_24_deprecation 28 | def test_cifar(): 29 | (X_train, y_train), (X_test, y_test) = ml_datasets.cifar() 30 | 31 | 32 | @pytest.mark.skip(reason="very slow download") 33 | def test_cmu(): 34 | train, dev = ml_datasets.cmu() 35 | 36 | 37 | def test_dbpedia(): 38 | train, dev = ml_datasets.dbpedia() 39 | 40 | 41 | def test_imdb(): 42 | train, dev = ml_datasets.imdb() 43 | 44 | 45 | @np_24_deprecation 46 | def test_mnist(): 47 | (X_train, y_train), (X_test, y_test) = ml_datasets.mnist() 48 | 49 | 50 | @pytest.mark.xfail(reason="403 Forbidden") 51 | def test_quora_questions(): 52 | train, dev = ml_datasets.quora_questions() 53 | 54 | 55 | @np_24_deprecation 56 | def test_reuters(): 57 | (X_train, y_train), (X_test, y_test) = ml_datasets.reuters() 58 | 59 | 60 | @pytest.mark.xfail(platform.system() == "Windows", reason="path issues") 61 | def test_snli(): 62 | train, dev = ml_datasets.snli() 63 | 64 | 65 | @pytest.mark.xfail(reason="no default path") 66 | def test_stack_exchange(): 67 | train, dev = ml_datasets.stack_exchange() 68 | 69 | 70 | def test_ud_ancora_pos_tags(): 71 | (train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ancora_pos_tags() 72 | 73 | 74 | @pytest.mark.xfail(reason="str column where int expected") 75 | def test_ud_ewtb_pos_tags(): 76 | (train_X, train_y), (dev_X, dev_y) = ml_datasets.ud_ewtb_pos_tags() 77 | 78 | 79 | @pytest.mark.xfail(reason="no default path") 80 | def test_wikiner(): 81 | train, dev = ml_datasets.wikiner() 82 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: ["*"] 8 | workflow_dispatch: # allows you to trigger manually 9 | 10 | # When this workflow is queued, automatically cancel any previous running 11 | # or pending jobs from the same branch 12 | concurrency: 13 | group: pytest-${{ github.ref }} 14 | cancel-in-progress: true 15 | 16 | defaults: 17 | run: 18 | shell: bash -l {0} 19 | 20 | jobs: 21 | test: 22 | name: ${{ matrix.os }} Python ${{ matrix.python-version }} NumPy ${{ matrix.numpy-version}} 23 | runs-on: ${{ matrix.os}} 24 | strategy: 25 | fail-fast: false 26 | matrix: 27 | os: [ubuntu-latest] 28 | python-version: ["3.11", "3.14"] 29 | numpy-version: ["2.3.0", latest] 30 | include: 31 | # Test oldest supported Python and NumPy versions 32 | - os: ubuntu-latest 33 | python-version: "3.8" 34 | numpy-version: "1.18.0" 35 | # Test vs. NumPy nightly wheels 36 | - os: ubuntu-latest 37 | python-version: "3.14" 38 | numpy-version: "nightly" 39 | # Test issues re. 
preinstalled SSL certificates on different OSes 40 | - os: windows-latest 41 | python-version: "3.14" 42 | numpy-version: latest 43 | - os: macos-latest 44 | python-version: "3.14" 45 | numpy-version: latest 46 | 47 | steps: 48 | - name: Checkout 49 | uses: actions/checkout@v6 50 | 51 | - name: Set up Python ${{ matrix.python-version }} 52 | uses: actions/setup-python@v4 53 | with: 54 | python-version: ${{ matrix.python-version }} 55 | 56 | - name: Install pinned NumPy 57 | if: matrix.numpy-version != 'latest' && matrix.numpy-version != 'nightly' 58 | run: python -m pip install numpy==${{ matrix.numpy-version }} 59 | 60 | - name: Install nightly NumPy wheels 61 | if: matrix.numpy-version == 'nightly' 62 | run: pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple/ numpy 63 | 64 | - name: Install package 65 | run: pip install . 66 | 67 | - name: Smoke test 68 | run: python -c "import ml_datasets" 69 | 70 | - name: Install test dependencies 71 | run: pip install pytest 72 | 73 | - name: Run tests 74 | run: pytest 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 
94 | #Pipfile.lock 95 | 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # Spyder project settings 112 | .spyderproject 113 | .spyproject 114 | 115 | # Rope project settings 116 | .ropeproject 117 | 118 | # mkdocs documentation 119 | /site 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | # Pycharm project files 130 | /.idea/ 131 | -------------------------------------------------------------------------------- /ml_datasets/spacy_readers.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Iterable, Callable, Dict 3 | from pathlib import Path 4 | 5 | from .loaders import cmu, dbpedia, imdb 6 | 7 | 8 | def cmu_reader( 9 | path: Path = None, *, freq_cutoff: int = 0, limit: int = 0, split=0.9 10 | ) -> Dict[str, Callable[["Language"], Iterable["Example"]]]: 11 | from spacy.training.example import Example 12 | 13 | # Deduce the categories above threshold by inspecting all data 14 | all_train_data, _ = list(cmu(path, limit=0, split=1)) 15 | counted_cats = {} 16 | for text, cats in all_train_data: 17 | for cat in cats: 18 | counted_cats[cat] = counted_cats.get(cat, 0) + 1 19 | # filter labels by frequency 20 | unique_labels = [ 21 | l for l in sorted(counted_cats.keys()) if counted_cats[l] >= freq_cutoff 22 | ] 23 | train_data, dev_data = cmu(path, limit=limit, shuffle=False, labels=unique_labels, split=split) 24 | 25 | def read_examples(data, nlp): 26 | for text, cats in data: 27 | doc = nlp.make_doc(text) 28 | assert isinstance(cats, list) 29 | cat_dict = {label: float(label in cats) for label in unique_labels} 30 | yield Example.from_dict(doc, {"cats": cat_dict}) 31 | 32 | return { 33 | "train": partial(read_examples, train_data), 34 | "dev": partial(read_examples, dev_data), 35 | } 36 | 37 | 38 | def dbpedia_reader( 39 | path: Path = None, *, train_limit: int = 0, dev_limit: int = 0 40 | ) -> Dict[str, Callable[["Language"], Iterable["Example"]]]: 41 | from spacy.training.example import Example 42 | 43 | all_train_data, _ = dbpedia(path, train_limit=0, dev_limit=1) 44 | unique_labels = set() 45 | for text, gold_label in all_train_data: 46 | assert isinstance(gold_label, str) 47 | unique_labels.add(gold_label) 48 | train_data, dev_data = dbpedia(path, train_limit=train_limit, dev_limit=dev_limit) 49 | 50 | def read_examples(data, nlp): 51 | for text, gold_label in data: 52 | doc = nlp.make_doc(text) 53 | cat_dict = {label: float(gold_label == label) for label in unique_labels} 54 | yield Example.from_dict(doc, {"cats": cat_dict}) 55 | 56 | return { 57 | "train": partial(read_examples, train_data), 58 | "dev": partial(read_examples, dev_data), 59 | } 60 | 61 | 62 | def imdb_reader( 63 | path: Path = None, *, train_limit: int = 0, dev_limit: int = 0 64 | ) -> Dict[str, Callable[["Language"], Iterable["Example"]]]: 65 | from spacy.training.example import Example 66 | 67 | train_data, dev_data = imdb(path, train_limit=train_limit, dev_limit=dev_limit) 68 | unique_labels = ["pos", "neg"] 69 | 70 | def read_examples(data, nlp): 71 | for text, gold_label in data: 72 | doc = nlp.make_doc(text) 73 | cat_dict = {label: float(gold_label == label) for label in unique_labels} 74 | yield Example.from_dict(doc, {"cats": cat_dict}) 75 | 76 | return { 77 | 
"train": partial(read_examples, train_data), 78 | "dev": partial(read_examples, dev_data), 79 | } 80 | -------------------------------------------------------------------------------- /ml_datasets/loaders/cifar.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import tarfile 3 | import random 4 | import numpy 5 | 6 | from ..util import get_file, unzip, to_categorical 7 | 8 | CIFAR10_URL = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 9 | CIFAR100_URL = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz" 10 | 11 | 12 | def cifar(variant='10', channels_last=False, shuffle=True): 13 | if variant == '10': 14 | data = load_cifar10() 15 | elif variant == '100': 16 | data = load_cifar100(coarse=False) 17 | elif variant == '100-coarse': 18 | data = load_cifar100(coarse=True) 19 | else: 20 | raise ValueError("Variant must be one of: '10', '100', 100-coarse") 21 | X_train, y_train, X_test, y_test = data 22 | X_train = X_train.astype("float32") 23 | X_test = X_test.astype("float32") 24 | X_train /= 255.0 25 | X_test /= 255.0 26 | if shuffle: 27 | train_data = list(zip(X_train, y_train)) 28 | random.shuffle(train_data) 29 | X_train, y_train = unzip(train_data) 30 | if channels_last: 31 | X_train = X_train.reshape(X_train.shape[0], 32, 32, 3) 32 | X_test = X_test.reshape(X_test.shape[0], 32, 32, 3) 33 | else: 34 | X_train = X_train.reshape(X_train.shape[0], 3, 32, 32) 35 | X_test = X_test.reshape(X_test.shape[0], 3, 32, 32) 36 | y_train = to_categorical(y_train) 37 | y_test = to_categorical(y_test) 38 | return (X_train, y_train), (X_test, y_test) 39 | 40 | 41 | def load_cifar10(path='cifar-10-python.tar.gz'): 42 | path = get_file(path, origin=CIFAR10_URL) 43 | train_images = [] 44 | train_labels = [] 45 | with tarfile.open(path) as cifarf: 46 | for name in cifarf.getnames(): 47 | # data is stored in batches 48 | if 'data_batch' in name: 49 | decompressed = cifarf.extractfile(name) 50 | data = pickle.load(decompressed, encoding='bytes') 51 | train_images.append(data[b'data']) 52 | train_labels += data[b'labels'] 53 | elif 'test_batch' in name: 54 | decompressed = cifarf.extractfile(name) 55 | data = pickle.load(decompressed, encoding='bytes') 56 | test_images = data[b'data'] 57 | test_labels = data[b'labels'] 58 | train_images = numpy.vstack(train_images) 59 | train_labels = numpy.asarray(train_labels) 60 | test_labels = numpy.asarray(test_labels) 61 | return train_images, train_labels, test_images, test_labels 62 | 63 | 64 | def load_cifar100(path='cifar-100-python.tar.gz', coarse=False): 65 | path = get_file(path, origin=CIFAR10_URL) 66 | with tarfile.open(path) as cifarf: 67 | train_decomp = cifarf.extractfile('cifar-100-python/train') 68 | test_decomp = cifarf.extractfile('cifar-100-python/test') 69 | train_data = pickle.load(train_decomp, encoding='bytes') 70 | test_data = pickle.load(test_decomp, encoding='bytes') 71 | train_images = train_data[b'data'] 72 | test_images = test_data[b'data'] 73 | if coarse: 74 | train_labels = train_data[b'coarse_labels'] 75 | test_labels = test_data[b'coarse_labels'] 76 | else: 77 | train_labels = train_data[b'fine_labels'] 78 | test_labels = test_data[b'fine_labels'] 79 | train_labels = numpy.asarray(train_labels) 80 | test_labels = numpy.asarray(test_labels) 81 | return train_images, train_labels, test_images, test_labels 82 | -------------------------------------------------------------------------------- /ml_datasets/loaders/universal_dependencies.py: 
-------------------------------------------------------------------------------- 1 | import numpy 2 | from collections import Counter 3 | from pathlib import Path 4 | 5 | from ..util import get_file, to_categorical 6 | from .._registry import register_loader 7 | 8 | 9 | GITHUB = "https://github.com/UniversalDependencies/" 10 | TEMPLATE = "{github}/{repo}/archive/r1.4.zip" 11 | ANCORA_1_4_ZIP = TEMPLATE.format(github=GITHUB, repo="UD_Spanish-AnCora") 12 | EWTB_1_4_ZIP = TEMPLATE.format(github=GITHUB, repo="UD_English") 13 | 14 | 15 | @register_loader("ud_ancora_pos_tags") 16 | def ud_ancora_pos_tags(encode_words=False, limit=None): 17 | data_dir = Path(get_file("UD_Spanish-AnCora-r1.4", ANCORA_1_4_ZIP, unzip=True)) 18 | train_loc = data_dir / "es_ancora-ud-train.conllu" 19 | dev_loc = data_dir / "es_ancora-ud-dev.conllu" 20 | return ud_pos_tags(train_loc, dev_loc, encode_words=encode_words, limit=limit) 21 | 22 | 23 | @register_loader("ud_ewtb_pos_tags") 24 | def ud_ewtb_pos_tags(encode_tags=False, encode_words=False, limit=None): 25 | data_dir = Path(get_file("UD_English-EWT-r1.4", EWTB_1_4_ZIP, unzip=True)) 26 | train_loc = data_dir / "en-ud-train.conllu" 27 | dev_loc = data_dir / "en-ud-dev.conllu" 28 | return ud_pos_tags( 29 | train_loc, 30 | dev_loc, 31 | encode_tags=encode_tags, 32 | encode_words=encode_words, 33 | limit=limit, 34 | ) 35 | 36 | 37 | def ud_pos_tags(train_loc, dev_loc, encode_tags=True, encode_words=True, limit=None): 38 | train_sents = list(read_conll(train_loc)) 39 | dev_sents = list(read_conll(dev_loc)) 40 | tagmap = {} 41 | freqs = Counter() 42 | for words, tags in train_sents: 43 | for tag in tags: 44 | tagmap.setdefault(tag, len(tagmap)) 45 | for word in words: 46 | freqs[word] += 1 47 | vocab = {w: i for i, (w, freq) in enumerate(freqs.most_common()) if (freq >= 5)} 48 | 49 | def _encode(sents): 50 | X = [] 51 | y = [] 52 | for words, tags in sents: 53 | if encode_words: 54 | arr = [vocab.get(word, len(vocab)) for word in words] 55 | X.append(numpy.asarray(arr, dtype="uint64")) 56 | else: 57 | X.append(words) 58 | if encode_tags: 59 | y.append(numpy.asarray([tagmap[tag] for tag in tags], dtype="int32")) 60 | else: 61 | y.append(tags) 62 | return zip(X, y) 63 | 64 | train_data = _encode(train_sents) 65 | check_data = _encode(dev_sents) 66 | train_X, train_y = zip(*train_data) 67 | dev_X, dev_y = zip(*check_data) 68 | nb_tag = max(max(y) for y in train_y) + 1 69 | train_X = list(train_X) 70 | dev_X = list(dev_X) 71 | train_y = [to_categorical(y, nb_tag) for y in train_y] 72 | dev_y = [to_categorical(y, nb_tag) for y in dev_y] 73 | if limit is not None: 74 | train_X = train_X[:limit] 75 | train_y = train_y[:limit] 76 | return (train_X, train_y), (dev_X, dev_y) 77 | 78 | 79 | def read_conll(loc): 80 | with Path(loc).open(encoding="utf8") as file_: 81 | sent_strs = file_.read().strip().split("\n\n") 82 | for sent_str in sent_strs: 83 | lines = [li.split() for li in sent_str.split("\n") if not li.startswith("#")] 84 | words = [] 85 | tags = [] 86 | for i, pieces in enumerate(lines): 87 | if len(pieces) == 4: 88 | word, pos, head, label = pieces 89 | else: 90 | idx, word, lemma, pos1, pos, morph, head, label, _, _2 = pieces 91 | if "-" in idx: 92 | continue 93 | words.append(word) 94 | tags.append(pos) 95 | yield words, tags 96 | -------------------------------------------------------------------------------- /ml_datasets/loaders/reuters.py: -------------------------------------------------------------------------------- 1 | import cloudpickle as pickle 2 | import 
numpy 3 | 4 | from ..util import get_file 5 | from .._registry import register_loader 6 | 7 | URL = "https://s3.amazonaws.com/text-datasets/reuters.pkl" 8 | WORD_INDEX_URL = "https://s3.amazonaws.com/text-datasets/reuters_word_index.pkl" 9 | 10 | 11 | @register_loader("reuters") 12 | def reuters(): 13 | (X_train, y_train), (X_test, y_test) = load_reuters() 14 | return (X_train, y_train), (X_test, y_test) 15 | 16 | 17 | def get_word_index(path="reuters_word_index.pkl"): 18 | path = get_file(path, origin=WORD_INDEX_URL) 19 | f = open(path, "rb") 20 | data = pickle.load(f, encoding="latin1") 21 | f.close() 22 | return data 23 | 24 | 25 | def load_reuters( 26 | path="reuters.pkl", 27 | nb_words=None, 28 | skip_top=0, 29 | maxlen=None, 30 | test_split=0.2, 31 | seed=113, 32 | start_char=1, 33 | oov_char=2, 34 | index_from=3, 35 | ): 36 | """Loads the Reuters newswire classification dataset. 37 | 38 | # Arguments 39 | path: where to store the data (in `/.keras/dataset`) 40 | nb_words: max number of words to include. Words are ranked 41 | by how often they occur (in the training set) and only 42 | the most frequent words are kept 43 | skip_top: skip the top N most frequently occuring words 44 | (which may not be informative). 45 | maxlen: truncate sequences after this length. 46 | test_split: Fraction of the dataset to be used as test data. 47 | seed: random seed for sample shuffling. 48 | start_char: The start of a sequence will be marked with this character. 49 | Set to 1 because 0 is usually the padding character. 50 | oov_char: words that were cut out because of the `nb_words` 51 | or `skip_top` limit will be replaced with this character. 52 | index_from: index actual words with this index and higher. 53 | 54 | Note that the 'out of vocabulary' character is only used for 55 | words that were present in the training set but are not included 56 | because they're not making the `nb_words` cut here. 57 | Words that were not seen in the trining set but are in the test set 58 | have simply been skipped. 
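    Example (an illustrative sketch only; the keyword values below are not
    defaults of this module, just sample settings):

        (X_train, y_train), (X_test, y_test) = load_reuters(nb_words=10000, skip_top=20)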
59 | """ 60 | # https://raw.githubusercontent.com/fchollet/keras/master/keras/datasets/mnist.py 61 | # Copyright Francois Chollet, Google, others (2015) 62 | # Under MIT license 63 | path = get_file(path, origin=URL) 64 | f = open(path, "rb") 65 | X, labels = pickle.load(f) 66 | f.close() 67 | numpy.random.seed(seed) 68 | numpy.random.shuffle(X) 69 | numpy.random.seed(seed) 70 | numpy.random.shuffle(labels) 71 | if start_char is not None: 72 | X = [[start_char] + [w + index_from for w in x] for x in X] 73 | elif index_from: 74 | X = [[w + index_from for w in x] for x in X] 75 | if maxlen: 76 | new_X = [] 77 | new_labels = [] 78 | for x, y in zip(X, labels): 79 | if len(x) < maxlen: 80 | new_X.append(x) 81 | new_labels.append(y) 82 | X = new_X 83 | labels = new_labels 84 | if not nb_words: 85 | nb_words = max([max(x) for x in X]) 86 | # by convention, use 2 as OOV word 87 | # reserve 'index_from' (=3 by default) characters: 0 (padding), 1 (start), 2 (OOV) 88 | if oov_char is not None: 89 | X = [[oov_char if (w >= nb_words or w < skip_top) else w for w in x] for x in X] 90 | else: 91 | nX = [] 92 | for x in X: 93 | nx = [] 94 | for w in x: 95 | if w >= nb_words or w < skip_top: 96 | nx.append(w) 97 | nX.append(nx) 98 | X = nX 99 | X_train = X[: int(len(X) * (1 - test_split))] 100 | y_train = labels[: int(len(X) * (1 - test_split))] 101 | X_test = X[int(len(X) * (1 - test_split)) :] 102 | y_test = labels[int(len(X) * (1 - test_split)) :] 103 | return (X_train, y_train), (X_test, y_test) 104 | -------------------------------------------------------------------------------- /ml_datasets/util.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import tarfile 3 | import zipfile 4 | import os 5 | import shutil 6 | from urllib.error import URLError, HTTPError 7 | from urllib.request import urlretrieve 8 | import tqdm 9 | 10 | 11 | def get_file(fname, origin, untar=False, unzip=False, cache_subdir="datasets"): 12 | """Downloads a file from a URL if it not already in the cache.""" 13 | # https://raw.githubusercontent.com/fchollet/keras/master/keras/utils/data_utils.py 14 | # Copyright Francois Chollet, Google, others (2015) 15 | # Under MIT license 16 | datadir_base = os.path.expanduser(os.path.join("~", ".keras")) 17 | if not os.access(datadir_base, os.W_OK): 18 | datadir_base = os.path.join("/tmp", ".keras") 19 | datadir = os.path.join(datadir_base, cache_subdir) 20 | if not os.path.exists(datadir): 21 | os.makedirs(datadir) 22 | if untar or unzip: 23 | untar_fpath = os.path.join(datadir, fname) 24 | if unzip: 25 | fpath = untar_fpath + ".zip" 26 | else: 27 | fpath = untar_fpath + ".tar.gz" 28 | else: 29 | fpath = os.path.join(datadir, fname) 30 | global progbar 31 | progbar = None 32 | 33 | def dl_progress(count, block_size, total_size): 34 | global progbar 35 | if progbar is None: 36 | progbar = tqdm.tqdm(total=total_size) 37 | else: 38 | progbar.update(block_size) 39 | 40 | if not os.path.exists(fpath): 41 | try: 42 | try: 43 | urlretrieve(origin, fpath, dl_progress) 44 | # Enrich download exceptions with full file name 45 | # HTTPError is a subclass of URLError, so it must be caught first 46 | except HTTPError as e: 47 | error_msg = "URL fetch failure on {} : {} -- {}" 48 | e.msg = error_msg.format(origin, e.code, e.msg) 49 | raise 50 | except URLError as e: 51 | error_msg = "URL fetch failure on {} -- {}" 52 | e.reason = error_msg.format(origin, e.reason) 53 | raise 54 | except (Exception, KeyboardInterrupt): 55 | if os.path.exists(fpath): 
56 | os.remove(fpath) 57 | raise 58 | progbar = None 59 | 60 | if untar: 61 | if not os.path.exists(untar_fpath): 62 | print("Untaring file...") 63 | tfile = tarfile.open(fpath, "r:gz") 64 | try: 65 | tfile.extractall(path=datadir) 66 | except (Exception, KeyboardInterrupt): 67 | if os.path.exists(untar_fpath): 68 | if os.path.isfile(untar_fpath): 69 | os.remove(untar_fpath) 70 | else: 71 | shutil.rmtree(untar_fpath) 72 | raise 73 | tfile.close() 74 | return untar_fpath 75 | elif unzip: 76 | if not os.path.exists(untar_fpath): 77 | print("Unzipping file...") 78 | with zipfile.ZipFile(fpath) as file_: 79 | try: 80 | file_.extractall(path=datadir) 81 | except (Exception, KeyboardInterrupt): 82 | if os.path.exists(untar_fpath): 83 | if os.path.isfile(untar_fpath): 84 | os.remove(untar_fpath) 85 | else: 86 | shutil.rmtree(untar_fpath) 87 | raise 88 | return untar_fpath 89 | return fpath 90 | 91 | 92 | def partition(examples, split_size): 93 | examples = list(examples) 94 | numpy.random.shuffle(examples) 95 | n_docs = len(examples) 96 | split = int(n_docs * split_size) 97 | return examples[:split], examples[split:] 98 | 99 | 100 | def unzip(data): 101 | x, y = zip(*data) 102 | return numpy.asarray(x), numpy.asarray(y) 103 | 104 | 105 | def to_categorical(Y, n_classes=None): 106 | # From keras 107 | Y = numpy.array(Y, dtype="int").ravel() 108 | if not n_classes: 109 | n_classes = numpy.max(Y) + 1 110 | n = Y.shape[0] 111 | categorical = numpy.zeros((n, n_classes), dtype="float32") 112 | categorical[numpy.arange(n), Y] = 1 113 | return numpy.asarray(categorical) 114 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Machine learning dataset loaders for testing and examples 4 | 5 | Loaders for various machine learning datasets for testing and example scripts. 6 | Previously in `thinc.extra.datasets`. 7 | 8 | [![PyPi Version](https://img.shields.io/pypi/v/ml-datasets.svg?style=flat-square&logo=pypi&logoColor=white)](https://pypi.python.org/pypi/ml-datasets) 9 | 10 | ## Setup and installation 11 | 12 | The package can be installed via pip: 13 | 14 | ```bash 15 | pip install ml-datasets 16 | ``` 17 | 18 | ## Loaders 19 | 20 | Loaders can be imported directly or used via their string name (which is useful if they're set via command line arguments). Some loaders may take arguments – see the source for details. 
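For instance, the `imdb` loader accepts optional `train_limit` and `dev_limit` keyword arguments (see `ml_datasets/loaders/imdb.py`); the limits below are purely illustrative:

```python
# Load a smaller subset of the IMDB data (illustrative limits)
from ml_datasets import imdb
train_data, dev_data = imdb(train_limit=1000, dev_limit=1000)
```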
21 | 22 | ```python 23 | # Import directly 24 | from ml_datasets import imdb 25 | train_data, dev_data = imdb() 26 | ``` 27 | 28 | ```python 29 | # Load via registry 30 | from ml_datasets import loaders 31 | imdb_loader = loaders.get("imdb") 32 | train_data, dev_data = imdb_loader() 33 | ``` 34 | 35 | ### Available loaders 36 | 37 | #### NLP datasets 38 | 39 | | ID / Function | Description | NLP task | From URL | 40 | | -------------------- | -------------------------------------------- | ----------------------------------------- | :------: | 41 | | `imdb` | IMDB sentiment dataset | Binary classification: sentiment analysis | ✓ | 42 | | `dbpedia` | DBPedia ontology dataset | Multi-class single-label classification | ✓ | 43 | | `cmu` | CMU movie genres dataset | Multi-class, multi-label classification | ✓ | 44 | | `quora_questions` | Duplicate Quora questions dataset | Detecting duplicate questions | ✓ | 45 | | `reuters` | Reuters dataset (texts not included) | Multi-class multi-label classification | ✓ | 46 | | `snli` | Stanford Natural Language Inference corpus | Recognizing textual entailment | ✓ | 47 | | `stack_exchange` | Stack Exchange dataset | Question Answering | | 48 | | `ud_ancora_pos_tags` | Universal Dependencies Spanish AnCora corpus | POS tagging | ✓ | 49 | | `ud_ewtb_pos_tags` | Universal Dependencies English EWT corpus | POS tagging | ✓ | 50 | | `wikiner` | WikiNER data | Named entity recognition | | 51 | 52 | #### Other ML datasets 53 | 54 | | ID / Function | Description | ML task | From URL | 55 | | ------------- | ----------- | ----------------- | :------: | 56 | | `mnist` | MNIST data | Image recognition | ✓ | 57 | 58 | ### Dataset details 59 | 60 | #### IMDB 61 | 62 | Each instance contains the text of a movie review, and a sentiment expressed as `0` or `1`. 63 | 64 | ```python 65 | train_data, dev_data = ml_datasets.imdb() 66 | for text, annot in train_data[0:5]: 67 | print(f"Review: {text}") 68 | print(f"Sentiment: {annot}") 69 | ``` 70 | 71 | - Download URL: [http://ai.stanford.edu/~amaas/data/sentiment/](http://ai.stanford.edu/~amaas/data/sentiment/) 72 | - Citation: [Andrew L. Maas et al., 2011](https://www.aclweb.org/anthology/P11-1015/) 73 | 74 | | Property | Training | Dev | 75 | | ------------------- | ---------------- | ---------------- | 76 | | # Instances | 25000 | 25000 | 77 | | Label values | {`0`, `1`} | {`0`, `1`} | 78 | | Labels per instance | Single | Single | 79 | | Label distribution | Balanced (50/50) | Balanced (50/50) | 80 | 81 | #### DBPedia 82 | 83 | Each instance contains an ontological description, and a classification into one of the 14 distinct labels. 84 | 85 | ```python 86 | train_data, dev_data = ml_datasets.dbpedia() 87 | for text, annot in train_data[0:5]: 88 | print(f"Text: {text}") 89 | print(f"Category: {annot}") 90 | ``` 91 | 92 | - Download URL: [Via fast.ai](https://course.fast.ai/datasets) 93 | - Original citation: [Xiang Zhang et al., 2015](https://arxiv.org/abs/1509.01626) 94 | 95 | | Property | Training | Dev | 96 | | ------------------- | -------- | -------- | 97 | | # Instances | 560000 | 70000 | 98 | | Label values | `1`-`14` | `1`-`14` | 99 | | Labels per instance | Single | Single | 100 | | Label distribution | Balanced | Balanced | 101 | 102 | #### CMU 103 | 104 | Each instance contains a movie description, and a classification into a list of appropriate genres. 
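In addition to the basic call shown below, the `cmu` loader takes optional keyword arguments – `limit`, `shuffle`, `labels` and `split` (see `ml_datasets/loaders/cmu.py`). A minimal sketch that keeps only examples tagged with at least one of a few genres (the label values here are illustrative):

```python
# Restrict the CMU data to a handful of genres (illustrative labels)
train_data, dev_data = ml_datasets.cmu(labels=["Drama", "Comedy"], split=0.8)
```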
105 | 106 | ```python 107 | train_data, dev_data = ml_datasets.cmu() 108 | for text, annot in train_data[0:5]: 109 | print(f"Text: {text}") 110 | print(f"Genres: {annot}") 111 | ``` 112 | 113 | - Download URL: [http://www.cs.cmu.edu/~ark/personas/](http://www.cs.cmu.edu/~ark/personas/) 114 | - Original citation: [David Bamman et al., 2013](https://www.aclweb.org/anthology/P13-1035/) 115 | 116 | | Property | Training | Dev | 117 | | ------------------- | --------------------------------------------------------------------------------------------- | --- | 118 | | # Instances | 41793 | 0 | 119 | | Label values | 363 different genres | - | 120 | | Labels per instance | Multiple | - | 121 | | Label distribution | Imbalanced: 147 labels with less than 20 examples, while `Drama` occurs more than 19000 times | - | 122 | 123 | #### Quora 124 | 125 | ```python 126 | train_data, dev_data = ml_datasets.quora_questions() 127 | for questions, annot in train_data[0:50]: 128 | q1, q2 = questions 129 | print(f"Question 1: {q1}") 130 | print(f"Question 2: {q2}") 131 | print(f"Similarity: {annot}") 132 | ``` 133 | 134 | Each instance contains two quora questions, and a label indicating whether or not they are duplicates (`0`: no, `1`: yes). 135 | The ground-truth labels contain some amount of noise: they are not guaranteed to be perfect. 136 | 137 | - Download URL: [http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv](http://qim.fs.quoracdn.net/quora_duplicate_questions.tsv) 138 | - Original citation: [Kornél Csernai et al., 2017](https://www.quora.com/q/quoradata/First-Quora-Dataset-Release-Question-Pairs) 139 | 140 | | Property | Training | Dev | 141 | | ------------------- | ------------------------- | ------------------------- | 142 | | # Instances | 363859 | 40429 | 143 | | Label values | {`0`, `1`} | {`0`, `1`} | 144 | | Labels per instance | Single | Single | 145 | | Label distribution | Imbalanced: 63% label `0` | Imbalanced: 63% label `0` | 146 | 147 | ### Registering loaders 148 | 149 | Loaders can be registered externally using the `loaders` registry as a decorator. 
For example: 150 | 151 | ```python 152 | @ml_datasets.loaders("my_custom_loader") 153 | def my_custom_loader(): 154 | return load_some_data() 155 | 156 | assert "my_custom_loader" in ml_datasets.loaders 157 | ``` 158 | -------------------------------------------------------------------------------- /ml_datasets/loaders/mnist.py: -------------------------------------------------------------------------------- 1 | import random 2 | import zipfile 3 | import gzip 4 | 5 | import cloudpickle as pickle 6 | import numpy as np 7 | 8 | from ..util import unzip, to_categorical, get_file 9 | from .._registry import register_loader 10 | 11 | 12 | MNIST_URL = "https://s3.amazonaws.com/img-datasets/mnist.pkl.gz" 13 | EMNIST_URL = "http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip" 14 | EMNIST_FILE = "gzip.zip" 15 | 16 | FA_TRAIN_IMG_URL = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz" 17 | FA_TRAIN_LBL_URL = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz" 18 | FA_TEST_IMG_URL = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz" 19 | FA_TEST_LBL_URL = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz" 20 | 21 | KU_TRAIN_IMG_URL = "http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-imgs.npz" 22 | KU_TRAIN_LBL_URL = ( 23 | "http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-train-labels.npz" 24 | ) 25 | KU_TEST_IMG_URL = "http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-imgs.npz" 26 | KU_TEST_LBL_URL = "http://codh.rois.ac.jp/kmnist/dataset/kmnist/kmnist-test-labels.npz" 27 | 28 | 29 | @register_loader("mnist") 30 | def mnist(variant="mnist", shuffle=True): 31 | if variant == "mnist": 32 | (X_train, y_train), (X_test, y_test) = load_mnist() 33 | elif variant == "fashion": 34 | (X_train, y_train), (X_test, y_test) = load_fashion_mnist() 35 | elif variant == "kuzushiji": 36 | (X_train, y_train), (X_test, y_test) = load_kuzushiji_mnist() 37 | elif variant.startswith("emnist-"): 38 | if len(variant.split("-")) != 2: 39 | raise ValueError( 40 | "EMNIST data set should be given in format " 41 | "'emnist-subset', where 'subset' can be " 42 | "'digits', 'letters', 'balanced' " 43 | "'byclass', 'bymerge' and 'mnist'. " 44 | f"{variant} was provided." 45 | ) 46 | subset = variant.split("-")[1] 47 | if subset not in [ 48 | "digits", 49 | "letters", 50 | "balanced", 51 | "byclass", 52 | "bymerge", 53 | "mnist", 54 | ]: 55 | raise ValueError( 56 | "To load EMNIST use the format " 57 | "'emnist-subset' where 'subset' can be" 58 | "'digits', 'letters', 'balanced' " 59 | "'byclass', 'bymerge' and 'mnist'." 60 | ) 61 | else: 62 | (X_train, y_train), (X_test, y_test) = load_emnist(subset=subset) 63 | else: 64 | raise ValueError( 65 | "Variant must be one of: " 66 | "'mnist', 'fashion', 'kuzushiji', " 67 | "'emnist-digits', 'emnist-letters', " 68 | "'emnist-balanced', 'emnist-byclass', " 69 | "'emnist-bymerge', 'emnist-mnist'." 
70 | ) 71 | n_train = X_train.shape[0] 72 | n_test = X_test.shape[0] 73 | n_classes = len(np.unique(y_train)) 74 | X_train = X_train.reshape(n_train, 784) 75 | X_test = X_test.reshape(n_test, 784) 76 | X_train = X_train.astype("float32") 77 | X_test = X_test.astype("float32") 78 | X_train /= 255.0 79 | X_test /= 255.0 80 | if shuffle: 81 | train_data = list(zip(X_train, y_train)) 82 | random.shuffle(train_data) 83 | X_train, y_train = unzip(train_data) 84 | y_train = to_categorical(y_train, n_classes=n_classes) 85 | y_test = to_categorical(y_test, n_classes=n_classes) 86 | return (X_train, y_train), (X_test, y_test) 87 | 88 | 89 | def load_mnist(path="mnist.pkl.gz"): 90 | path = get_file(path, origin=MNIST_URL) 91 | if path.endswith(".gz"): 92 | f = gzip.open(path, "rb") 93 | else: 94 | f = open(path, "rb") 95 | data = pickle.load(f, encoding="bytes") 96 | f.close() 97 | return data # (X_train, y_train), (X_test, y_test) 98 | 99 | 100 | def load_fashion_mnist( 101 | train_img_path="train-images-idx3-ubyte.gz", 102 | train_label_path="train-labels-idx1-ubyte.gz", 103 | test_img_path="t10k-images-idx3-ubyte.gz", 104 | test_label_path="t10k-labels-idz1-ubyte.gz", 105 | ): 106 | train_img_path = get_file(train_img_path, origin=FA_TRAIN_IMG_URL) 107 | train_label_path = get_file(train_label_path, origin=FA_TRAIN_LBL_URL) 108 | test_img_path = get_file(test_img_path, origin=FA_TEST_IMG_URL) 109 | test_label_path = get_file(test_label_path, origin=FA_TEST_LBL_URL) 110 | # Based on https://github.com/zalandoresearch/fashion-mnist/blob/master/utils/mnist_reader.py 111 | with gzip.open(train_label_path, "rb") as trlbpath: 112 | train_labels = np.frombuffer(trlbpath.read(), dtype=np.uint8, offset=8) 113 | with gzip.open(train_img_path, "rb") as trimgpath: 114 | train_images = np.frombuffer( 115 | trimgpath.read(), dtype=np.uint8, offset=16 116 | ).reshape(len(train_labels), 28, 28) 117 | with gzip.open(test_label_path, "rb") as telbpath: 118 | test_labels = np.frombuffer(telbpath.read(), dtype=np.uint8, offset=8) 119 | with gzip.open(test_img_path, "rb") as teimgpath: 120 | test_images = np.frombuffer( 121 | teimgpath.read(), dtype=np.uint8, offset=16 122 | ).reshape(len(test_labels), 28, 28) 123 | return (train_images, train_labels), (test_images, test_labels) 124 | 125 | 126 | def load_kuzushiji_mnist( 127 | train_img_path="kmnist-train-imgs.npz", 128 | train_label_path="kmnist-train-labels.npz", 129 | test_img_path="kmnist-test-imgs.npz", 130 | test_label_path="kmnist-test-labels.npz", 131 | ): 132 | train_img_path = get_file(train_img_path, origin=KU_TRAIN_IMG_URL) 133 | train_label_path = get_file(train_label_path, origin=KU_TRAIN_LBL_URL) 134 | test_img_path = get_file(test_img_path, origin=KU_TEST_IMG_URL) 135 | test_label_path = get_file(test_label_path, origin=KU_TEST_LBL_URL) 136 | train_images = np.load(train_img_path)["arr_0"] 137 | train_labels = np.load(train_label_path)["arr_0"] 138 | test_images = np.load(test_img_path)["arr_0"] 139 | test_labels = np.load(test_label_path)["arr_0"] 140 | return (train_images, train_labels), (test_images, test_labels) 141 | 142 | 143 | def _decode_idx(archive, path): 144 | comp = archive.read(path) 145 | data = bytes(gzip.decompress(comp)) 146 | axes = data[3] 147 | shape = [] 148 | dtype = np.dtype("ubyte").newbyteorder(">") 149 | for axis in range(axes): 150 | offset = 4 * (axis + 1) 151 | size = int(np.frombuffer(data[offset : offset + 4], dtype=">u4")) 152 | shape.append(size) 153 | shape = tuple(shape) 154 | offset = 4 * (axes + 1) 155 | flat = 
np.frombuffer(data[offset:], dtype=dtype) 156 | reshaped = flat.reshape(shape) 157 | return reshaped 158 | 159 | 160 | def load_emnist(path=EMNIST_FILE, subset="digits"): 161 | emnist_path = get_file(path, origin=EMNIST_URL) 162 | train_X_path = f"gzip/emnist-{subset}-train-images-idx3-ubyte.gz" 163 | train_y_path = f"gzip/emnist-{subset}-train-labels-idx1-ubyte.gz" 164 | test_X_path = f"gzip/emnist-{subset}-test-images-idx3-ubyte.gz" 165 | test_y_path = f"gzip/emnist-{subset}-test-labels-idx1-ubyte.gz" 166 | with zipfile.ZipFile(emnist_path, "r") as archive: 167 | train_X = _decode_idx(archive, train_X_path) 168 | train_y = _decode_idx(archive, train_y_path) 169 | test_X = _decode_idx(archive, test_X_path) 170 | test_y = _decode_idx(archive, test_y_path) 171 | # For some reason in this data set the labels start from 1 172 | if subset == "letters": 173 | train_y = train_y - 1 174 | test_y = test_y - 1 175 | return (train_X, train_y), (test_X, test_y) 176 | --------------------------------------------------------------------------------