├── .coveragerc ├── .github └── workflows │ ├── build.yml │ └── flake8.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.rst ├── codecov.yml ├── libsvmdata ├── __init__.py ├── abstraction.py ├── core.py ├── datasets.py ├── libsvm.py └── tests │ ├── test_core.py │ ├── test_datasets.py │ └── test_libsvm.py ├── requirements.txt ├── setup.cfg └── setup.py /.coveragerc: -------------------------------------------------------------------------------- 1 | # Configuration for coverage.py 2 | 3 | [run] 4 | branch = True 5 | source = libsvmdata 6 | include = */libsvmdata/* 7 | omit = */setup.py 8 | 9 | [report] 10 | exclude_lines = 11 | pragma: no cover 12 | def __repr__ 13 | if self.debug: 14 | if settings.DEBUG 15 | raise AssertionError 16 | raise NotImplementedError 17 | if 0: 18 | if __name__ == .__main__.: 19 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: 4 | push: 5 | branches: 6 | - 'main' 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | build-linux: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 3.8 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: 3.8 21 | 22 | - name: install 23 | run: | 24 | pip install -U pip 25 | pip install -e . 26 | 27 | - name: test 28 | run: | 29 | pip install pytest pytest-cov coverage numpydoc codecov 30 | pytest -lv --cov-report term-missing libsvmdata --cov=libsvmdata --cov-config .coveragerc 31 | codecov 32 | - name: codecov 33 | uses: codecov/codecov-action@v1 34 | with: 35 | files: .coveragerc 36 | flags: unittests 37 | fail_ci_if_error: true 38 | verbose: true 39 | -------------------------------------------------------------------------------- /.github/workflows/flake8.yml: -------------------------------------------------------------------------------- 1 | name: linter 2 | 3 | on: push 4 | 5 | jobs: 6 | lint: 7 | name: Lint code base 8 | runs-on: ubuntu-latest 9 | 10 | steps: 11 | - name: Checkout code 12 | uses: actions/checkout@v2 13 | 14 | - name: Setup Python 3.8 15 | uses: actions/setup-python@v2 16 | with: 17 | python-version: 3.8 18 | 19 | - name: Lint with flake 20 | run: | 21 | pip install --upgrade pip 22 | pip install flake8 23 | flake8 libsvmdata/ 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | libsvmdata.egg-info/* 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2018-2020, libsvmdata 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include *.rst 2 | include *.md 3 | include *.in 4 | include LICENSE 5 | include libsvmdata/__init__.py 6 | include requirements.txt 7 | 8 | recursive-include libsvmdata *.py 9 | 10 | ### Exclude 11 | 12 | exclude .coveragerc 13 | exclude *.yml 14 | recursive-exclude libsvmdata *.pyc 15 | 16 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | |image0| |image1| 2 | 3 | A python util to fetch datasets from different databases. 4 | 5 | Currently supported databases are: 6 | 7 | - LIBSVM (libsvm_) 8 | 9 | Getting design matrix and target variable is as easy as: 10 | 11 | :: 12 | 13 | from libsvmdata import fetch_dataset 14 | X, y = fetch_dataset("news20.binary") 15 | 16 | Currently supported datasets are in ``libsvmdata.supported`` and can be displayed as: 17 | 18 | :: 19 | 20 | from libsvmdata import print_supported_datasets 21 | print_supported_datasets() 22 | 23 | There is no need to specify the database name. 24 | 25 | Files are saved under ``DATA_HOME/``, where the value of ``DATA_HOME`` is: 26 | 27 | - the environment variable ``LIBSVMDATA_HOME`` if it exists, 28 | 29 | - else, the environment variable ``XDG_DATA_HOME`` if it exists, 30 | 31 | - else, ``$HOME/data``. 32 | 33 | 34 | 35 | .. |image0| image:: https://github.com/mathurinm/libsvmdata/actions/workflows/build.yml/badge.svg?branch=main 36 | :target: https://github.com/mathurinm/libsvmdata/actions/workflows/build.yml 37 | .. |image1| image:: https://codecov.io/gh/mathurinm/libsvmdata/branch/main/graphs/badge.svg?branch=main 38 | :target: https://codecov.io/gh/mathurinm/libsvmdata 39 | .. _libsvm: https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ 40 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 2 3 | round: down 4 | range: "70...100" 5 | status: 6 | project: 7 | default: 8 | target: auto 9 | threshold: 0.01 10 | patch: false 11 | changes: false 12 | comment: 13 | layout: "header, diff, sunburst, uncovered" 14 | behavior: default 15 | -------------------------------------------------------------------------------- /libsvmdata/__init__.py: -------------------------------------------------------------------------------- 1 | from libsvmdata.datasets import fetch_libsvm, download_libsvm 2 | from libsvmdata.core import fetch_dataset, print_supported_datasets, ALL_DATASETS 3 | 4 | supported = list(ALL_DATASETS.keys()) 5 | 6 | __version__ = '0.5dev0' 7 | -------------------------------------------------------------------------------- /libsvmdata/abstraction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | from abc import ABC, abstractmethod 5 | from download import download 6 | from pathlib import Path 7 | from scipy import sparse 8 | 9 | 10 | def _get_data_home(subdir_name=""): 11 | """ 12 | Defines the data home folder. The top priority is the environment 13 | variable $LIBSVMDATA_HOME which is specific to this package. Otherwise, we 14 | seek for the variable $XDG_DATA_HOME. Finally, the fallback is $HOME/data. 15 | """ 16 | data_home = os.environ.get("LIBSVMDATA_HOME", None) 17 | 18 | if data_home is None: 19 | data_home = os.environ.get("XDG_DATA_HOME", None) 20 | if data_home is None: 21 | data_home = Path.home() / "data" 22 | else: 23 | data_home = Path(data_home) 24 | 25 | return data_home / subdir_name 26 | 27 | 28 | class AbstractDataset(ABC): 29 | """Base class defining a dataset along with its fetching methods.""" 30 | 31 | # In the derived class, __init__() must set the following attributes : 32 | dataset_name = None # dataset name 33 | dataset_file = None # dataset file (with potential extensions) 34 | dataset_dir = None # subdirectory name (see _get_data_home()) 35 | dataset_url = None # dataset download url 36 | 37 | @abstractmethod 38 | def __init__(self): 39 | """ 40 | In the derived class, this function must define the class attributes. 41 | It can also be used to pass additional information required in the 42 | function _load_file_and_save_data() of the derived class. 43 | """ 44 | pass 45 | 46 | @abstractmethod 47 | def _load_file_and_save_data(self, raw_dataset_path, ext_dataset_path): 48 | """ 49 | In the derived class, this function is responsible of the 50 | transformation of the raw dataset file into two .npy/.npz files 51 | containing the feature matrix X and the response vector/matrix y. These 52 | files must be named _X. and 53 | _y.. This function is also responsible for 54 | removing the raw dataset file when needed. 55 | """ 56 | pass 57 | 58 | def _load_data(self, ext_dataset_path): 59 | """Load data from the extracted .npz/.npy files.""" 60 | 61 | try: 62 | X = sparse.load_npz(str(ext_dataset_path) + "_X.npz") 63 | except FileNotFoundError: 64 | X = np.load(str(ext_dataset_path) + "_X.npy") 65 | 66 | try: 67 | y = sparse.load_npz(str(ext_dataset_path) + "_y.npz") 68 | except FileNotFoundError: 69 | y = np.load(str(ext_dataset_path) + "_y.npy") 70 | 71 | return X, y 72 | 73 | def get_X_y(self, replace=False, verbose=False): 74 | """ 75 | Load a dataset as matrix X and vector y. If X and y already exist as 76 | .npz and/or .npy files, they are not redownloaded, unless replace=True. 77 | """ 78 | 79 | raw_dataset_path = self.dataset_dir / self.dataset_file 80 | ext_dataset_path = self.dataset_dir / self.dataset_name 81 | 82 | # Check if the dataset already exists 83 | if self.dataset_dir.exists(): 84 | regex = re.compile(f"{self.dataset_name}_(X|y).(npz|npy)") 85 | files = os.listdir(self.dataset_dir) 86 | found = [f for f in files if re.search(regex, f)] 87 | exists = len(found) == 2 88 | else: 89 | found = [] 90 | exists = False 91 | 92 | if replace or not exists: 93 | 94 | # Remove existing dataset files if there are any 95 | if raw_dataset_path.exists(): 96 | raw_dataset_path.unlink() 97 | for file in found: 98 | Path(self.dataset_dir / file).unlink() 99 | 100 | # Path of the raw dataset file 101 | if verbose: 102 | print("Downloading...") 103 | download( 104 | self.dataset_url, 105 | raw_dataset_path, 106 | progressbar=verbose, 107 | replace=replace, 108 | verbose=verbose, 109 | ) 110 | 111 | if verbose: 112 | print("Loading file and saving data...") 113 | X, y = self._load_file_and_save_data( 114 | raw_dataset_path, 115 | ext_dataset_path 116 | ) 117 | 118 | else: 119 | if verbose: 120 | print("Loading data...") 121 | X, y = self._load_data(ext_dataset_path) 122 | 123 | return X, y 124 | -------------------------------------------------------------------------------- /libsvmdata/core.py: -------------------------------------------------------------------------------- 1 | from libsvmdata.libsvm import DATASETS as libsvm_datasets 2 | 3 | ALL_DATABASES = {"LIBSVM": libsvm_datasets} 4 | 5 | ALL_DATASETS = { 6 | dataset.dataset_name: dataset 7 | for datasets in ALL_DATABASES.values() 8 | for dataset in datasets 9 | } 10 | 11 | 12 | def fetch_dataset(dataset_name, replace=False, verbose=False): 13 | """ 14 | Load a dataset. It is downloaded only if not present or when replace=True. 15 | 16 | Parameters 17 | ---------- 18 | dataset_name : string 19 | Dataset name. 20 | 21 | replace : bool, default=False 22 | Whether to re-download the dataset if it is already downloaded. 23 | 24 | verbose : bool, default=False 25 | Whether or not to print information about dataset loading. 26 | 27 | 28 | Returns 29 | ------- 30 | X : np.ndarray or scipy.sparse.csc_matrix 31 | Design matrix, as 2D array or column sparse format depending on the 32 | dataset. 33 | 34 | y : 1D or 2D np.ndarray 35 | Design vector (or matrix in multiclass setting). 36 | """ 37 | 38 | if dataset_name not in ALL_DATASETS.keys(): 39 | raise ValueError( 40 | f"Unsupported dataset `{dataset_name}`. Supported datasets can be " 41 | "displayed using the `libsvmdata.print_supported_datasets` " 42 | "function." 43 | ) 44 | 45 | dataset = ALL_DATASETS[dataset_name] 46 | 47 | X, y = dataset.get_X_y(replace=replace, verbose=verbose) 48 | 49 | return X, y 50 | 51 | 52 | def print_supported_datasets(): 53 | print("Supported datasets") 54 | for database_name, datasets in ALL_DATABASES.items(): 55 | print(f"- {database_name}: ") 56 | print(", ".join(dataset.dataset_name for dataset in datasets)) 57 | -------------------------------------------------------------------------------- /libsvmdata/datasets.py: -------------------------------------------------------------------------------- 1 | # This file aims to avoid compatibility issues with versions of the libsvmdata 2 | # package anterior to https://github.com/mathurinm/libsvmdata/pull/37. 3 | 4 | import download 5 | import numpy as np 6 | import warnings 7 | from sklearn import preprocessing 8 | from scipy import sparse 9 | from libsvmdata.libsvm import DATASETS as libsvm_datasets 10 | from libsvmdata.core import fetch_dataset 11 | 12 | # The `NAMES` variable before the pull request #37 can be reconstructed from 13 | # the `DATASETS` variable in the `libsvm.py` file. 14 | NAMES = { 15 | dataset.dataset_name: "/".join([dataset.task_name, dataset.dataset_file]) 16 | for dataset in libsvm_datasets 17 | } 18 | 19 | 20 | def download_libsvm(dataset, destination, replace=False, verbose=False): 21 | """Download a dataset from LIBSVM website.""" 22 | 23 | warnings.warn( 24 | "The function `download_libsvm` is deprecated in `v0.5` and " 25 | "removed in `v0.6`. See " 26 | "https://github.com/mathurinm/libsvmdata/pull/37 for more details.", 27 | FutureWarning 28 | ) 29 | 30 | url = ( 31 | "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/" 32 | + NAMES[dataset] 33 | ) 34 | path = download(url, destination, replace=replace, verbose=verbose) 35 | return path 36 | 37 | 38 | def fetch_libsvm(dataset, replace=False, normalize=False, min_nnz=0, 39 | verbose=False): 40 | """ 41 | Download a dataset from LIBSVM website. 42 | Parameters 43 | ---------- 44 | dataset : string 45 | Dataset name. Must be in libsvmdata.supported. 46 | replace : bool, default=False 47 | Whether to force download of dataset if already downloaded. 48 | normalize : bool, default=False 49 | If True, columns of X are set to unit norm. This may make little sense 50 | for a sparse matrix since centering is not performed. 51 | y is centered and set to unit norm if the dataset is a regression one. 52 | min_nnz : int, default=0 53 | When X is sparse, columns of X with strictly less than min_nnz 54 | non-zero entries are discarded. 55 | verbose : bool, default=False 56 | Whether or not to print information about dataset loading. 57 | Returns 58 | ------- 59 | X : np.ndarray or scipy.sparse.csc_matrix 60 | Design matrix, as 2D array or column sparse format depending on the 61 | dataset. 62 | y : 1D or 2D np.ndarray 63 | Design vector or matrix (in multiclass setting). 64 | References 65 | ---------- 66 | https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/ 67 | """ 68 | 69 | warnings.warn( 70 | "The function `fetch_libsvm` is deprecated in `v0.5` and " 71 | "replaced by `fetch_dataset`. It will be removed in `v0.6`. See " 72 | "https://github.com/mathurinm/libsvmdata/pull/37 for more details.", 73 | FutureWarning 74 | ) 75 | 76 | if dataset not in NAMES: 77 | raise ValueError("Unsupported dataset %s. " % dataset + 78 | "Supported datasets are: \n" + ', '.join(NAMES)) 79 | is_regression = NAMES[dataset].split('/')[0] == 'regression' 80 | 81 | if verbose: 82 | print("Dataset: %s" % dataset) 83 | 84 | # Does exactly the same as the original `_get_X_y` function but without 85 | # the normalization and the removing of too sparse columns done in 86 | # post-processing step when. These steps are therefore done just below. 87 | X, y = fetch_dataset(dataset, replace=replace, verbose=verbose) 88 | 89 | # removing columns with to few non zero entries when using sparse X 90 | if sparse.issparse(X) and min_nnz != 0: 91 | X = X[:, np.diff(X.indptr) >= min_nnz] 92 | 93 | if normalize: 94 | X = preprocessing.normalize(X, axis=0) 95 | if is_regression: 96 | y -= np.mean(y) 97 | y /= np.std(y) 98 | 99 | return X, y 100 | -------------------------------------------------------------------------------- /libsvmdata/libsvm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from bz2 import BZ2Decompressor 3 | from scipy import sparse 4 | from sklearn.datasets import load_svmlight_file 5 | from libsvmdata.abstraction import _get_data_home, AbstractDataset 6 | 7 | LIBSVM_DATA_HOME = _get_data_home("libsvm") 8 | LIBSVM_BASE_URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets" 9 | 10 | 11 | class LibsvmDataset(AbstractDataset): 12 | def __init__(self, dataset_name, dataset_file, task_name, n_features): 13 | 14 | self.dataset_name = dataset_name 15 | self.dataset_file = dataset_file 16 | self.dataset_dir = LIBSVM_DATA_HOME 17 | self.dataset_url = "/".join([LIBSVM_BASE_URL, task_name, dataset_file]) 18 | self.task_name = task_name 19 | self.n_features = n_features 20 | 21 | def _load_file_and_save_data(self, raw_dataset_path, ext_dataset_path): 22 | 23 | # Handle .bz2 compressed datasets 24 | if str(raw_dataset_path).endswith(".bz2"): 25 | tmp_dataset_path = raw_dataset_path.with_suffix("") 26 | decompressor = BZ2Decompressor() 27 | # TODO : why not using bz2.decompress(raw_dataset_path) only ? 28 | with open(tmp_dataset_path, "wb") as f, open( 29 | raw_dataset_path, "rb" 30 | ) as g: 31 | for data in iter(lambda: g.read(100 * 1024), b""): 32 | f.write(decompressor.decompress(data)) 33 | raw_dataset_path.unlink() 34 | raw_dataset_path = tmp_dataset_path 35 | 36 | with open(raw_dataset_path, "rb") as file: 37 | X, y = load_svmlight_file( 38 | file, 39 | n_features=self.n_features, 40 | multilabel=(self.task_name == "multilabel"), 41 | ) 42 | raw_dataset_path.unlink() 43 | 44 | X_path = str(ext_dataset_path) + "_X" 45 | if len(X.data) >= 0.5 * X.shape[0] * X.shape[1]: 46 | X = X.toarray(order="F") 47 | np.save(X_path, X) 48 | else: 49 | X = sparse.csc_matrix(X) 50 | X.sort_indices() 51 | sparse.save_npz(X_path, X) 52 | 53 | y_path = str(ext_dataset_path) + "_y" 54 | if self.task_name == "multilabel": 55 | indices = np.array([lab for labels in y for lab in labels]) 56 | indptr = np.cumsum([0] + [len(labels) for labels in y]) 57 | data = np.ones_like(indices) 58 | y = sparse.csr_matrix((data, indices, indptr)) 59 | sparse.save_npz(y_path, y) 60 | else: 61 | np.save(y_path, y) 62 | 63 | return X, y 64 | 65 | 66 | DATASETS = [ 67 | LibsvmDataset("a1a", "a1a", "binary", 123), 68 | LibsvmDataset("a1a_test", "a1a.t", "binary", 123), 69 | LibsvmDataset("a2a", "a2a", "binary", 123), 70 | LibsvmDataset("a2a_test", "a2a.t", "binary", 123), 71 | LibsvmDataset("a3a", "a3a", "binary", 123), 72 | LibsvmDataset("a3a_test", "a3a.t", "binary", 123), 73 | LibsvmDataset("a4a", "a4a", "binary", 123), 74 | LibsvmDataset("a4a_test", "a4a.t", "binary", 123), 75 | LibsvmDataset("a5a", "a5a", "binary", 123), 76 | LibsvmDataset("a5a_test", "a5a.t", "binary", 123), 77 | LibsvmDataset("a6a", "a6a", "binary", 123), 78 | LibsvmDataset("a6a_test", "a6a.t", "binary", 123), 79 | LibsvmDataset("a7a", "a7a", "binary", 123), 80 | LibsvmDataset("a7a_test", "a7a.t", "binary", 123), 81 | LibsvmDataset("a8a", "a8a", "binary", 123), 82 | LibsvmDataset("a8a_test", "a8a.t", "binary", 123), 83 | LibsvmDataset("a9a", "a9a", "binary", 123), 84 | LibsvmDataset("a9a_test", "a9a.t", "binary", 123), 85 | LibsvmDataset("abalone", "abalone", "regression", 8), 86 | LibsvmDataset("abalone_scale", "abalone_scale", "regression", 8), 87 | LibsvmDataset("aloi", "aloi.bz2", "multiclass", 128), 88 | LibsvmDataset("australian", "australian", "binary", 14), 89 | LibsvmDataset("australian_scale", "australian_scale", "binary", 14), 90 | LibsvmDataset("bibtex", "bibtex.bz2", "multilabel", 1836), 91 | LibsvmDataset("bodyfat", "bodyfat", "regression", 14), 92 | LibsvmDataset("breast-cancer", "breast-cancer", "binary", 10), 93 | LibsvmDataset("breast-cancer_scale", "breast-cancer_scale", "binary", 10), 94 | LibsvmDataset("cadata", "cadata", "regression", 8), 95 | LibsvmDataset("cifar10", "cifar10.bz2", "multiclass", 3072), 96 | LibsvmDataset("cifar10_test", "cifar10.t.bz2", "multiclass", 3072), 97 | LibsvmDataset("cod-rna", "cod-rna", "binary", 8), 98 | LibsvmDataset("cod-rna_test", "cod-rna.t", "binary", 8), 99 | LibsvmDataset("colon-cancer", "colon-cancer.bz2", "binary", 2000), 100 | LibsvmDataset("connect-4", "connect-4", "multiclass", 126), 101 | LibsvmDataset("covtype.binary", "covtype.libsvm.binary.bz2", "binary", 54), 102 | LibsvmDataset("covtype.multiclass", "covtype.bz2", "multiclass", 54), 103 | LibsvmDataset("covtype.multiclass_scale", "covtype.scale01.bz2", "multiclass", 54), # noqa 104 | LibsvmDataset("cpusmall", "cpusmall", "regression", 12), 105 | LibsvmDataset("delicious", "delicious.bz2", "multilabel", 500), 106 | LibsvmDataset("diabetes", "diabetes", "binary", 8), 107 | LibsvmDataset("diabetes_scale", "diabetes_scale", "binary", 8), 108 | LibsvmDataset("dna", "dna.scale", "multiclass", 180), 109 | LibsvmDataset("duke breast-cancer", "duke.bz2", "binary", 7129), 110 | LibsvmDataset("epsilon", "epsilon_normalized.bz2", "binary", 2000), 111 | LibsvmDataset("epsilon_test", "epsilon_normalized.t.bz2", "binary", 2000), 112 | LibsvmDataset("eunite2001", "eunite2001", "regression", 16), 113 | LibsvmDataset("finance", "log1p.E2006.train.bz2", "regression", 4272227), 114 | LibsvmDataset("finance-tf-idf", "E2006.train.bz2", "regression", 150360), 115 | LibsvmDataset("fourclass", "fourclass", "binary", 2), 116 | LibsvmDataset("fourclass_scale", "fourclass_scale", "binary", 2), 117 | LibsvmDataset("german.numer", "german.numer", "binary", 24), 118 | LibsvmDataset("german.numer_scale", "german.numer_scale", "binary", 24), 119 | LibsvmDataset("gisette", "gisette_scale.bz2", "binary", 5000), 120 | LibsvmDataset("glass", "glass.scale", "multiclass", 9), 121 | LibsvmDataset("heart", "heart", "binary", 13), 122 | LibsvmDataset("heart_scale", "heart_scale", "binary", 13), 123 | LibsvmDataset("HIGGS", "HIGGS.bz2", "binary", 28), 124 | LibsvmDataset("housing", "housing", "regression", 13), 125 | LibsvmDataset("ijcnn1", "ijcnn1.bz2", "binary", 22), 126 | LibsvmDataset("ijcnn1_test", "ijcnn1.t.bz2", "binary", 22), 127 | LibsvmDataset("ionosphere", "ionosphere_scale", "binary", 34), 128 | LibsvmDataset("iris", "iris.scale", "multiclass", 4), 129 | LibsvmDataset("kdda_train", "kdda.bz2", "binary", 20216830), 130 | LibsvmDataset("letter", "letter.scale", "multiclass", 16), 131 | LibsvmDataset("leukemia", "leu.bz2", "binary", 7129), 132 | LibsvmDataset("leukemia_test", "leu.t.bz2", "binary", 7129), 133 | LibsvmDataset("liver-disorders", "liver-disorders", "binary", 5), 134 | LibsvmDataset("liver-disorders_scale", "liver-disorders_scale", "binary", 5), # noqa 135 | LibsvmDataset("liver-disorders_test", "liver-disorders.t", "binary", 5), 136 | LibsvmDataset("madelon", "madelon", "binary", 500), 137 | LibsvmDataset("madelon_test", "madelon.t", "binary", 500), 138 | LibsvmDataset("mediamill", "mediamill/train-exp1.svm.bz2", "multilabel", 120), # noqa 139 | LibsvmDataset("mediamill_test", "mediamill/test-exp1.svm.bz2", "multilabel", 120), # noqa 140 | LibsvmDataset("mnist", "mnist.bz2", "multiclass", 780), 141 | LibsvmDataset("news20.binary", "news20.binary.bz2", "binary", 1355191), 142 | LibsvmDataset("news20.multiclass", "news20.bz2", "multiclass", 62061), 143 | LibsvmDataset("pendigits", "pendigits", "multiclass", 16), 144 | LibsvmDataset("pendigits_test", "pendigits.t", "multiclass", 16), 145 | LibsvmDataset("phishing", "phishing", "binary", 68), 146 | LibsvmDataset("rcv1.binary", "rcv1_train.binary.bz2", "binary", 47236), 147 | LibsvmDataset("rcv1.binary_test", "rcv1_test.binary.bz2", "binary", 47236), 148 | LibsvmDataset("rcv1.multiclass", "rcv1_train.multiclass.bz2", "multiclass", 47236), # noqa 149 | LibsvmDataset("rcv1.multiclass_test", "rcv1_test.multiclass.bz2", "multiclass", 47236), # noqa 150 | LibsvmDataset("rcv1_topics_test", "rcv1_topics_test_2.svm.bz2", "multilabel", 47236), # noqa 151 | LibsvmDataset("real-sim", "real-sim.bz2", "binary", 20958), 152 | LibsvmDataset("scene-classification", "scene_train.bz2", "multilabel", 294), # noqa 153 | LibsvmDataset("scene-classification_test", "scene_test.bz2", "multilabel", 294), # noqa 154 | LibsvmDataset("sector.scale", "sector/sector.scale.bz2", "multiclass", 55197), # noqa 155 | LibsvmDataset("sector.scale_test", "sector/sector.t.scale.bz2", "multiclass", 55197), # noqa 156 | LibsvmDataset("sector", "sector/sector.bz2", "multiclass", 55197), 157 | LibsvmDataset("sector_test", "sector/sector.t.bz2", "multiclass", 55197), 158 | LibsvmDataset("sensit", "vehicle/combined.bz2", "multiclass", 100), 159 | LibsvmDataset("siam-competition2007", "tmc2007_train.svm.bz2", "multilabel", 30438), # noqa 160 | LibsvmDataset("siam-competition2007_test", "tmc2007_test.svm.bz2", "multilabel", 30438,), # noqa 161 | LibsvmDataset("skin_nonskin", "skin_nonskin", "binary", 3), 162 | LibsvmDataset("smallNORB", "smallNORB.bz2", "multiclass", 18432), 163 | LibsvmDataset("sonar", "sonar_scale", "binary", 60), 164 | LibsvmDataset("splice", "splice", "binary", 60), 165 | LibsvmDataset("splice_scale", "splice_scale", "binary", 60), 166 | LibsvmDataset("splice_test", "splice.t", "binary", 60), 167 | LibsvmDataset("SUSY", "SUSY.bz2", "binary", 18), 168 | LibsvmDataset("svmguide1", "svmguide1", "binary", 4), 169 | LibsvmDataset("svmguide1_test", "svmguide1.t", "binary", 4), 170 | LibsvmDataset("url", "url_combined.bz2", "binary", 3231961), 171 | LibsvmDataset("usps", "usps.bz2", "multiclass", 7291), 172 | LibsvmDataset("usps_test", "usps.t.bz2", "multiclass", 7291), 173 | LibsvmDataset("w1a", "w1a", "binary", 300), 174 | LibsvmDataset("w1a_test", "w1a.t", "binary", 300), 175 | LibsvmDataset("w2a", "w2a", "binary", 300), 176 | LibsvmDataset("w2a_test", "w2a.t", "binary", 300), 177 | LibsvmDataset("w3a", "w1a", "binary", 300), 178 | LibsvmDataset("w3a_test", "w1a.t", "binary", 300), 179 | LibsvmDataset("w4a", "w1a", "binary", 300), 180 | LibsvmDataset("w4a_test", "w1a.t", "binary", 300), 181 | LibsvmDataset("w5a", "w1a", "binary", 300), 182 | LibsvmDataset("w5a_test", "w1a.t", "binary", 300), 183 | LibsvmDataset("w6a", "w1a", "binary", 300), 184 | LibsvmDataset("w6a_test", "w1a.t", "binary", 300), 185 | LibsvmDataset("w7a", "w7a", "binary", 300), 186 | LibsvmDataset("w7a_test", "w7a.t", "binary", 300), 187 | LibsvmDataset("w8a", "w8a", "binary", 300), 188 | LibsvmDataset("w8a_test", "w8a.t", "binary", 300), 189 | LibsvmDataset("w9a", "w9a", "binary", 300), 190 | LibsvmDataset("w9a_test", "w9a.t", "binary", 300), 191 | LibsvmDataset("webspam", "webspam_wc_normalized_trigram.svm.bz2", "binary", 16609143), # noqa 192 | LibsvmDataset("YearPredictionMSD", "YearPredictionMSD.bz2", "regression", 90), # noqa 193 | LibsvmDataset("YearPredictionMSD_test", "YearPredictionMSD.t.bz2", "regression", 90), # noqa 194 | LibsvmDataset("yeast", "yeast_train.svm.bz2", "multilabel", 103), 195 | LibsvmDataset("yeast_test", "yeast_test.svm.bz2", "multilabel", 103), 196 | ] 197 | -------------------------------------------------------------------------------- /libsvmdata/tests/test_core.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from numpy.testing import assert_equal 3 | from libsvmdata import fetch_dataset, print_supported_datasets 4 | 5 | 6 | # TODO : add other datasets to test ? 7 | def test_replace(): 8 | X_first, y_first = fetch_dataset("iris", replace=True) 9 | X_second, y_second = fetch_dataset("iris") 10 | assert_equal(X_first, X_second) 11 | assert_equal(y_first, y_second) 12 | 13 | 14 | def test_wrong_dataset(): 15 | with pytest.raises(ValueError): 16 | fetch_dataset("unknowndataset") 17 | 18 | 19 | def test_print_supported_datasets(): 20 | print_supported_datasets() 21 | -------------------------------------------------------------------------------- /libsvmdata/tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from numpy.testing import assert_, assert_allclose, assert_equal 4 | from libsvmdata import fetch_libsvm 5 | 6 | TEST_DATASETS = { 7 | "regression": [ 8 | ("abalone", 4_177, 8), 9 | ("bodyfat", 252, 14), 10 | ], 11 | "binary": [ 12 | ("a1a", 1_605, 123), 13 | ("breast-cancer", 683, 10), 14 | ], 15 | "multiclass": [ 16 | ("dna", 2_000, 180), 17 | ("iris", 150, 4), 18 | ], 19 | "multilabel": [ 20 | ("bibtex", 7_395, 1_836), 21 | ("scene-classification", 1_211, 294), 22 | ], 23 | } 24 | 25 | 26 | @pytest.mark.filterwarnings("ignore:FutureWarning") 27 | @pytest.mark.parametrize("dataset_name,n,p", TEST_DATASETS["regression"]) 28 | def test_regression(dataset_name, n, p): 29 | X, y = fetch_libsvm(dataset_name) 30 | assert_equal(X.shape[0], n) 31 | assert_equal(X.shape[1], p) 32 | assert_equal(y.shape[0], n) 33 | 34 | 35 | @pytest.mark.filterwarnings("ignore:FutureWarning") 36 | @pytest.mark.parametrize("dataset_name,n,p", TEST_DATASETS["binary"]) 37 | def test_binary(dataset_name, n, p): 38 | X, y = fetch_libsvm(dataset_name, n, p) 39 | assert_equal(X.shape[0], n) 40 | assert_equal(X.shape[1], p) 41 | assert_equal(y.shape[0], n) 42 | assert_equal(len(np.unique(y)), 2) 43 | 44 | 45 | @pytest.mark.filterwarnings("ignore:FutureWarning") 46 | @pytest.mark.parametrize("dataset_name,n,p", TEST_DATASETS["multiclass"]) 47 | def test_multiclass(dataset_name, n, p): 48 | X, y = fetch_libsvm(dataset_name) 49 | assert_equal(X.shape[0], n) 50 | assert_equal(X.shape[1], p) 51 | assert_equal(y.shape[0], n) 52 | assert_(len(np.unique(y)) > 2) 53 | 54 | 55 | @pytest.mark.filterwarnings("ignore:FutureWarning") 56 | @pytest.mark.parametrize("dataset_name,n,p", TEST_DATASETS["multilabel"]) 57 | def test_multilabel(dataset_name, n, p): 58 | X, y = fetch_libsvm(dataset_name) 59 | assert_equal(X.shape[0], n) 60 | assert_equal(X.shape[1], p) 61 | assert_equal(y.shape[0], n) 62 | assert_(y.shape[1] > 2) 63 | 64 | 65 | @pytest.mark.filterwarnings("ignore:FutureWarning") 66 | def test_normalization(): 67 | X, y = fetch_libsvm("abalone", normalize=True) 68 | assert_allclose(np.linalg.norm(X, axis=0), 1.) 69 | assert_allclose(np.mean(y), 0., atol=1e-07) 70 | assert_allclose(np.std(y), 1.) 71 | -------------------------------------------------------------------------------- /libsvmdata/tests/test_libsvm.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from numpy.testing import assert_, assert_equal 4 | from libsvmdata import fetch_dataset 5 | 6 | # TODO : add more datasets to test ? 7 | TEST_DATASETS = { 8 | "regression": [ 9 | ("abalone", 4_177, 8), 10 | ("bodyfat", 252, 14), 11 | ], 12 | "binary": [ 13 | ("a1a", 1_605, 123), 14 | ("breast-cancer", 683, 10), 15 | ], 16 | "multiclass": [ 17 | ("dna", 2_000, 180), 18 | ("iris", 150, 4), 19 | ], 20 | "multilabel": [ 21 | ("bibtex", 7_395, 1_836), 22 | ("scene-classification", 1_211, 294), 23 | ], 24 | } 25 | 26 | 27 | @pytest.mark.parametrize("dataset_name,n,p", TEST_DATASETS["regression"]) 28 | def test_regression(dataset_name, n, p): 29 | X, y = fetch_dataset(dataset_name) 30 | assert_equal(X.shape[0], n) 31 | assert_equal(X.shape[1], p) 32 | assert_equal(y.shape[0], n) 33 | 34 | 35 | @pytest.mark.parametrize("dataset_name,n,p", TEST_DATASETS["binary"]) 36 | def test_binary(dataset_name, n, p): 37 | X, y = fetch_dataset(dataset_name, n, p) 38 | assert_equal(X.shape[0], n) 39 | assert_equal(X.shape[1], p) 40 | assert_equal(y.shape[0], n) 41 | assert_equal(len(np.unique(y)), 2) 42 | 43 | 44 | @pytest.mark.parametrize("dataset_name,n,p", TEST_DATASETS["multiclass"]) 45 | def test_multiclass(dataset_name, n, p): 46 | X, y = fetch_dataset(dataset_name) 47 | assert_equal(X.shape[0], n) 48 | assert_equal(X.shape[1], p) 49 | assert_equal(y.shape[0], n) 50 | assert_(len(np.unique(y)) > 2) 51 | 52 | 53 | @pytest.mark.parametrize("dataset_name,n,p", TEST_DATASETS["multilabel"]) 54 | def test_multilabel(dataset_name, n, p): 55 | X, y = fetch_dataset(dataset_name) 56 | assert_equal(X.shape[0], n) 57 | assert_equal(X.shape[1], p) 58 | assert_equal(y.shape[0], n) 59 | assert_(y.shape[1] > 2) 60 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | download 2 | numpy 3 | scikit-learn 4 | scipy -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = __init__.py -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools.command.build_ext import build_ext 3 | from setuptools import dist, setup, Extension, find_packages 4 | 5 | descr = 'Fetcher for datasets' 6 | 7 | version = None 8 | with open(os.path.join('libsvmdata', '__init__.py'), 'r') as fid: 9 | for line in (line.strip() for line in fid): 10 | if line.startswith('__version__'): 11 | version = line.split('=')[1].strip().strip('\'') 12 | break 13 | if version is None: 14 | raise RuntimeError('Could not determine version') 15 | 16 | DISTNAME = 'libsvmdata' 17 | DESCRIPTION = descr 18 | MAINTAINER = 'Mathurin Massias' 19 | MAINTAINER_EMAIL = 'mathurin.massias@gmail.com' 20 | LICENSE = 'BSD (3-clause)' 21 | DOWNLOAD_URL = 'https://github.com/mathurinm/libsvmdata.git' 22 | VERSION = version 23 | 24 | setup(name='libsvmdata', 25 | version=VERSION, 26 | description=DESCRIPTION, 27 | long_description=open('README.rst').read(), 28 | license=LICENSE, 29 | maintainer=MAINTAINER, 30 | maintainer_email=MAINTAINER_EMAIL, 31 | download_url=DOWNLOAD_URL, 32 | install_requires=['download', 'numpy>=1.12', 'scikit-learn', 'scipy'], 33 | packages=find_packages(), 34 | ) 35 | --------------------------------------------------------------------------------