├── .github └── workflows │ └── testing.yml ├── .gitignore ├── CHANGES.txt ├── LICENSE ├── MANIFEST.in ├── README.rst ├── __init__.py ├── examples ├── abod_example.py ├── ecod_example.py ├── hbos_example.py ├── knn_example.py ├── lof_example.py └── pca_example.py ├── figs ├── abstraction.png ├── abstraction_example.png └── run_time.png ├── pypi_build_commands.txt ├── pytod ├── __init__.py ├── models │ ├── __init__.py │ ├── abod.py │ ├── base.py │ ├── basic_operators.py │ ├── basic_operators_batch.py │ ├── ecod.py │ ├── functional_operators.py │ ├── hbos.py │ ├── intermediate_layers.py │ ├── knn.py │ ├── lof.py │ ├── pca.py │ ├── quantization.py │ └── sklearn_base.py ├── test │ ├── test_abod.py │ ├── test_base.py │ ├── test_basic_operators.py │ ├── test_ecod.py │ ├── test_hbos.py │ ├── test_knn.py │ ├── test_lof.py │ └── test_pca.py ├── utils │ ├── __init__.py │ ├── data.py │ └── utility.py └── version.py ├── reproducibility ├── __init__.py ├── additional_scripts │ ├── multi-knn.py │ ├── numpy_vs_torch_batch.py │ ├── quant_odsys.py │ ├── quant_odsys_f.py │ ├── quantization.py │ ├── quantization_memory.py │ ├── readme.MD │ └── single-knn.py ├── compare_real_data.py ├── compare_real_data_adbench.py ├── compare_real_data_quant.py ├── compare_synthetic.py ├── datasets │ └── ODDS │ │ ├── annthyroid.mat │ │ ├── arrhythmia.mat │ │ ├── breastw.mat │ │ ├── glass.mat │ │ ├── ionosphere.mat │ │ ├── letter.mat │ │ ├── lympho.mat │ │ ├── mammography.mat │ │ ├── mnist.mat │ │ ├── musk.mat │ │ ├── optdigits.mat │ │ ├── pendigits.mat │ │ ├── pima.mat │ │ ├── satellite.mat │ │ ├── satimage-2.mat │ │ ├── shuttle.mat │ │ ├── smtp_n.mat │ │ ├── speech.mat │ │ ├── thyroid.mat │ │ ├── vertebral.mat │ │ ├── vowels.mat │ │ ├── wbc.mat │ │ └── wine.mat ├── fusion_experiment.py ├── implement_new.py ├── knn_classification.py ├── results.txt ├── results_synthetic.txt ├── time_break.xlsx └── time_breakdown.py ├── requirements.txt └── setup.py /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Testing 5 | 6 | on: 7 | push: 8 | branches: 9 | - main 10 | - dev 11 | pull_request: 12 | branches: 13 | - main 14 | - dev 15 | 16 | jobs: 17 | build: 18 | runs-on: ${{ matrix.os }} 19 | 20 | strategy: 21 | fail-fast: false 22 | matrix: 23 | os: [windows-latest] 24 | python-version: [3.6, 3.7, 3.8, 3.9] 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Python ${{ matrix.python-version }} 29 | uses: actions/setup-python@v2 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | 33 | - name: Install dependencies 34 | run: | 35 | python -m pip install --upgrade pip 36 | pip install -r requirements.txt 37 | pip install pytest 38 | pip install coverage 39 | pip install coveralls 40 | 41 | - name: Test with pytest 42 | run: | 43 | coverage run --source=pytod -m pytest 44 | - name: coverage report 45 | env: 46 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 47 | run: | 48 | coveralls --service=github 49 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # 
Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CHANGES.txt: -------------------------------------------------------------------------------- 1 | v<0.0.0>, <11/22/2021> -- Initial release. 2 | v<0.0.1>, <04/12/2021> -- Add LOF. 3 | v<0.0.1>, <04/23/2021> -- Add ABOD. 4 | v<0.0.2>, <06/19/2021> -- Add PCA and HBOS. 5 | v<0.0.2>, <06/19/2021> -- Turn on test suites. 6 | v<0.0.3>, <06/19/2021> -- Hotfix. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2021, Yue Zhao, George H. Chen, and Zhihao Jia 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 
15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | prune examples 2 | prune pytod/test 3 | prune figs 4 | prune reproducibility 5 | include README.rst 6 | include requirements.txt 7 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/__init__.py -------------------------------------------------------------------------------- /examples/abod_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Example of using Angle-base outlier detection (ABOD) for outlier detection 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | import os 8 | import sys 9 | import time 10 | 11 | import torch 12 | from pyod.models.abod import ABOD as ABOD_PyOD 13 | from pyod.utils.data import evaluate_print 14 | from pyod.utils.data import generate_data 15 | 16 | # temporary solution for relative imports in case pyod is not installed 17 | # if pyod is installed, no need to use the following line 18 | sys.path.append( 19 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 20 | 21 | from pytod.models.abod import ABOD 22 | from pytod.utils.utility import validate_device 23 | 24 | contamination = 0.1 # percentage of outliers 25 | n_train = 10000 # number of training points 26 | n_test = 5000 # number of testing points 27 | n_features = 20 28 | k = 10 29 | 30 | # Generate sample data 31 | X_train, X_test, y_train, y_test = \ 32 | generate_data(n_train=n_train, 33 | n_test=n_test, 34 | n_features=n_features, 35 | contamination=contamination, 36 | random_state=42) 37 | 38 | clf_name = 'ABOD-PyOD' 39 | clf = ABOD_PyOD(n_neighbors=k) 40 | start = time.time() 41 | clf.fit(X_train) 42 | end = time.time() 43 | # get the prediction labels and outlier scores of the training data 44 | y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) 45 | y_train_scores = clf.decision_scores_ # raw outlier scores 46 | 47 | # evaluate and print the results 48 | print("\nOn Training Data:") 49 | evaluate_print(clf_name, y_train, y_train_scores) 50 | pyod_time = end - start 51 | print('PyOD execution time', pyod_time) 52 | 53 | X_train, y_train, X_test, y_test = torch.from_numpy(X_train), \ 54 | torch.from_numpy(y_train), \ 55 | torch.from_numpy(X_test), \ 56 | torch.from_numpy(y_test) 57 | 58 | print() 59 | print() 60 | # try to access the GPU, fall back to cpu if no gpu is available 61 | 
device = validate_device(0) 62 | # device = 'cpu' 63 | clf_name = 'abod-PyTOD' 64 | clf = ABOD(n_neighbors=k, batch_size=10000, device=device) 65 | # clf = ABOD(n_neighbors=k, batch_size=None, device=device) 66 | start = time.time() 67 | clf.fit(X_train) 68 | end = time.time() 69 | # get the prediction labels and outlier scores of the training data 70 | y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) 71 | y_train_scores = clf.decision_scores_ # raw outlier scores 72 | 73 | # evaluate and print the results 74 | print("\nOn Training Data:") 75 | evaluate_print(clf_name, y_train, y_train_scores) 76 | tod_time = end - start 77 | print('TOD execution time', tod_time) 78 | 79 | print('TOD is', round(pyod_time / tod_time, ndigits=2), 80 | 'times faster than PyOD') 81 | -------------------------------------------------------------------------------- /examples/ecod_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Example of using Empirical Cumulative Distribution Functions (ECOD) for 3 | outlier detection 4 | """ 5 | # Author: Yue Zhao 6 | # License: BSD 2 clause 7 | 8 | import os 9 | import sys 10 | import time 11 | 12 | import torch 13 | from pyod.models.ecod import ECOD as ECOD_PyOD 14 | from pyod.utils.data import evaluate_print 15 | from pyod.utils.data import generate_data 16 | 17 | # temporary solution for relative imports in case pyod is not installed 18 | # if pyod is installed, no need to use the following line 19 | sys.path.append( 20 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 21 | 22 | from pytod.models.ecod import ECOD 23 | from pytod.utils.utility import validate_device 24 | 25 | contamination = 0.1 # percentage of outliers 26 | n_train = 10000 # number of training points 27 | n_test = 5000 # number of testing points 28 | n_features = 5000 29 | k = 10 30 | 31 | # Generate sample data 32 | X_train, X_test, y_train, y_test = \ 33 | generate_data(n_train=n_train, 34 | n_test=n_test, 35 | n_features=n_features, 36 | contamination=contamination, 37 | random_state=42) 38 | 39 | clf_name = 'ECOD-PyOD' 40 | clf = ECOD_PyOD() 41 | start = time.time() 42 | clf.fit(X_train) 43 | end = time.time() 44 | 45 | pyod_time = end - start 46 | print('PyOD execution time', pyod_time) 47 | 48 | X_train, y_train, X_test, y_test = torch.from_numpy(X_train), \ 49 | torch.from_numpy(y_train), \ 50 | torch.from_numpy(X_test), \ 51 | torch.from_numpy(y_test) 52 | 53 | print() 54 | print() 55 | # try to access the GPU, fall back to cpu if no gpu is available 56 | device = validate_device(0) 57 | # device = 'cpu' 58 | clf_name = 'ECOD-PyTOD' 59 | clf = ECOD(device=device) 60 | start = time.time() 61 | clf.fit(X_train) 62 | end = time.time() 63 | 64 | tod_time = end - start 65 | print('TOD execution time', tod_time) 66 | 67 | print('TOD is', round(pyod_time / tod_time, ndigits=2), 68 | 'times faster than PyOD') 69 | -------------------------------------------------------------------------------- /examples/hbos_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Example of using Histogram- based outlier detection (HBOS) for 3 | outlier detection 4 | """ 5 | # Author: Yue Zhao 6 | # License: BSD 2 clause 7 | 8 | import os 9 | import sys 10 | import time 11 | 12 | import torch 13 | from pyod.models.hbos import HBOS as HBOS_PyOD 14 | from pyod.utils.data import evaluate_print 15 | from pyod.utils.data import generate_data 
16 | 17 | # temporary solution for relative imports in case pyod is not installed 18 | # if pyod is installed, no need to use the following line 19 | sys.path.append( 20 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 21 | 22 | from pytod.models.hbos import HBOS 23 | from pytod.utils.utility import validate_device 24 | 25 | contamination = 0.1 # percentage of outliers 26 | n_train = 500000 # number of training points 27 | n_test = 5000 # number of testing points 28 | n_features = 1000 29 | k = 10 30 | 31 | # Generate sample data 32 | X_train, X_test, y_train, y_test = \ 33 | generate_data(n_train=n_train, 34 | n_test=n_test, 35 | n_features=n_features, 36 | contamination=contamination, 37 | random_state=42) 38 | 39 | clf_name = 'HBOS-PyOD' 40 | clf = HBOS_PyOD() 41 | start = time.time() 42 | clf.fit(X_train) 43 | end = time.time() 44 | # get the prediction labels and outlier scores of the training data 45 | y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) 46 | y_train_scores = clf.decision_scores_ # raw outlier scores 47 | 48 | # evaluate and print the results 49 | print("\nOn Training Data:") 50 | evaluate_print(clf_name, y_train, y_train_scores) 51 | pyod_time = end - start 52 | print('PyOD execution time', pyod_time) 53 | 54 | X_train, y_train, X_test, y_test = torch.from_numpy(X_train), \ 55 | torch.from_numpy(y_train), \ 56 | torch.from_numpy(X_test), \ 57 | torch.from_numpy(y_test) 58 | 59 | print() 60 | print() 61 | # try to access the GPU, fall back to cpu if no gpu is available 62 | device = validate_device(0) 63 | # device = 'cpu' 64 | clf_name = 'hbos-PyTOD' 65 | clf = HBOS(device=device) 66 | # clf = HBOS(n_neighbors=k, batch_size=None, device=device) 67 | start = time.time() 68 | clf.fit(X_train) 69 | end = time.time() 70 | # get the prediction labels and outlier scores of the training data 71 | y_train_pred = clf.labels_ # binary labels (0: inliers, 1: outliers) 72 | y_train_scores = clf.decision_scores_ # raw outlier scores 73 | 74 | # evaluate and print the results 75 | print("\nOn Training Data:") 76 | evaluate_print(clf_name, y_train, y_train_scores) 77 | tod_time = end - start 78 | print('TOD execution time', tod_time) 79 | 80 | print('TOD is', round(pyod_time / tod_time, ndigits=2), 81 | 'times faster than PyOD') 82 | -------------------------------------------------------------------------------- /examples/knn_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Example of using kNN for outlier detection 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | import os 8 | import sys 9 | import time 10 | 11 | import torch 12 | from pyod.models.knn import KNN as KNN_PyOD 13 | from pyod.utils.data import evaluate_print 14 | from pyod.utils.data import generate_data 15 | 16 | # temporary solution for relative imports in case pyod is not installed 17 | # if pyod is installed, no need to use the following line 18 | sys.path.append( 19 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 20 | 21 | from pytod.models.knn import KNN 22 | from pytod.utils.utility import validate_device 23 | 24 | contamination = 0.1 # percentage of outliers 25 | n_train = 30000 # number of training points 26 | n_test = 5000 # number of testing points 27 | n_features = 20 28 | k = 10 29 | 30 | # Generate sample data 31 | X_train, X_test, y_train, y_test = \ 32 | generate_data(n_train=n_train, 33 | n_test=n_test, 34 | n_features=n_features, 35 | 
contamination=contamination,
36 |                   random_state=42)
37 | 
38 | clf_name = 'KNN-PyOD'
39 | clf = KNN_PyOD(n_neighbors=k)
40 | start = time.time()
41 | clf.fit(X_train)
42 | end = time.time()
43 | # get the prediction labels and outlier scores of the training data
44 | y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)
45 | y_train_scores = clf.decision_scores_  # raw outlier scores
46 | 
47 | # evaluate and print the results
48 | print("\nOn Training Data:")
49 | evaluate_print(clf_name, y_train, y_train_scores)
50 | pyod_time = end - start
51 | print('PyOD execution time', pyod_time)
52 | 
53 | X_train, y_train, X_test, y_test = torch.from_numpy(X_train), \
54 |                                    torch.from_numpy(y_train), \
55 |                                    torch.from_numpy(X_test), \
56 |                                    torch.from_numpy(y_test)
57 | 
58 | print()
59 | print()
60 | # try to access the GPU, fall back to cpu if no gpu is available
61 | device = validate_device(0)
62 | # device = 'cpu'
63 | clf_name = 'KNN-PyTOD'
64 | clf = KNN(n_neighbors=k, batch_size=10000, device=device)
65 | # clf = KNN(n_neighbors=k, batch_size=None, device=device)
66 | start = time.time()
67 | clf.fit(X_train)
68 | end = time.time()
69 | # get the prediction labels and outlier scores of the training data
70 | y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)
71 | y_train_scores = clf.decision_scores_  # raw outlier scores
72 | 
73 | # evaluate and print the results
74 | print("\nOn Training Data:")
75 | evaluate_print(clf_name, y_train, y_train_scores)
76 | tod_time = end - start
77 | print('TOD execution time', tod_time)
78 | 
79 | print('TOD is', round(pyod_time / tod_time, ndigits=2),
80 |       'times faster than PyOD')
81 | 
--------------------------------------------------------------------------------
/examples/lof_example.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Example of using LOF for outlier detection
3 | """
4 | # Author: Yue Zhao
5 | # License: BSD 2 clause
6 | 
7 | import os
8 | import sys
9 | import time
10 | 
11 | import torch
12 | from pyod.models.lof import LOF as LOF_PyOD
13 | from pyod.utils.data import evaluate_print
14 | from pyod.utils.data import generate_data
15 | 
16 | # temporary solution for relative imports in case pytod is not installed
17 | # if pytod is installed, no need to use the following line
18 | sys.path.append(
19 |     os.path.abspath(os.path.join(os.path.dirname("__file__"), '..')))
20 | 
21 | from pytod.models.lof import LOF
22 | from pytod.utils.utility import validate_device
23 | 
24 | contamination = 0.1  # percentage of outliers
25 | n_train = 30000  # number of training points
26 | n_test = 5000  # number of testing points
27 | n_features = 20
28 | k = 10
29 | 
30 | # Generate sample data
31 | X_train, X_test, y_train, y_test = \
32 |     generate_data(n_train=n_train,
33 |                   n_test=n_test,
34 |                   n_features=n_features,
35 |                   contamination=contamination,
36 |                   random_state=42)
37 | 
38 | clf_name = 'LOF-PyOD'
39 | clf = LOF_PyOD(n_neighbors=k)
40 | start = time.time()
41 | clf.fit(X_train)
42 | end = time.time()
43 | # get the prediction labels and outlier scores of the training data
44 | y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)
45 | y_train_scores = clf.decision_scores_  # raw outlier scores
46 | 
47 | # evaluate and print the results
48 | print("\nOn Training Data:")
49 | evaluate_print(clf_name, y_train, y_train_scores)
50 | pyod_time = end - start
51 | print('PyOD execution time', pyod_time)
52 | 
53 | X_train, y_train, X_test, y_test = torch.from_numpy(X_train), \
54 |                                    torch.from_numpy(y_train), \
55 |                                    torch.from_numpy(X_test), \
56 |                                    torch.from_numpy(y_test)
57 | 
58 | print()
59 | print()
60 | # try to access the GPU, fall back to cpu if no gpu is available
61 | device = validate_device(0)
62 | # device = 'cpu'
63 | clf_name = 'lof-PyTOD'
64 | clf = LOF(n_neighbors=k, batch_size=10000, device=device)
65 | # clf = LOF(n_neighbors=k, batch_size=None, device=device)
66 | start = time.time()
67 | clf.fit(X_train)
68 | end = time.time()
69 | # get the prediction labels and outlier scores of the training data
70 | y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)
71 | y_train_scores = clf.decision_scores_  # raw outlier scores
72 | 
73 | # evaluate and print the results
74 | print("\nOn Training Data:")
75 | evaluate_print(clf_name, y_train, y_train_scores)
76 | tod_time = end - start
77 | print('TOD execution time', tod_time)
78 | 
79 | print('TOD is', round(pyod_time / tod_time, ndigits=2),
80 |       'times faster than PyOD')
81 | 
--------------------------------------------------------------------------------
/examples/pca_example.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Example of using PCA for outlier detection
3 | """
4 | # Author: Yue Zhao
5 | # License: BSD 2 clause
6 | import os
7 | import sys
8 | import time
9 | 
10 | import torch
11 | from pyod.models.pca import PCA as PCA_PyOD
12 | from pyod.utils.data import evaluate_print
13 | from pyod.utils.data import generate_data
14 | 
15 | # temporary solution for relative imports in case pytod is not installed
16 | # if pytod is installed, no need to use the following line
17 | sys.path.append(
18 |     os.path.abspath(os.path.join(os.path.dirname("__file__"), '..')))
19 | 
20 | from pytod.models.pca import PCA
21 | from pytod.utils.utility import validate_device
22 | 
23 | contamination = 0.1  # percentage of outliers
24 | n_train = 1000000  # number of training points
25 | n_test = 5000  # number of testing points
26 | n_features = 200
27 | k = 10
28 | 
29 | # Generate sample data
30 | X_train, X_test, y_train, y_test = \
31 |     generate_data(n_train=n_train,
32 |                   n_test=n_test,
33 |                   n_features=n_features,
34 |                   contamination=contamination,
35 |                   random_state=42)
36 | 
37 | clf_name = 'PCA-PyOD'
38 | clf = PCA_PyOD(n_components=5)
39 | start = time.time()
40 | clf.fit(X_train)
41 | end = time.time()
42 | # get the prediction labels and outlier scores of the training data
43 | y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)
44 | y_train_scores = clf.decision_scores_  # raw outlier scores
45 | 
46 | # evaluate and print the results
47 | print("\nOn Training Data:")
48 | evaluate_print(clf_name, y_train, y_train_scores)
49 | pyod_time = end - start
50 | print('PyOD execution time', pyod_time)
51 | 
52 | X_train, y_train, X_test, y_test = torch.from_numpy(X_train), \
53 |                                    torch.from_numpy(y_train), \
54 |                                    torch.from_numpy(X_test), \
55 |                                    torch.from_numpy(y_test)
56 | 
57 | print()
58 | print()
59 | # try to access the GPU, fall back to cpu if no gpu is available
60 | device = validate_device(0)
61 | # device = 'cpu'
62 | clf_name = 'PCA-PyTOD'
63 | clf = PCA(n_components=5, device=device)  # match the PyOD setting above
64 | # clf = PCA(n_components=5, device='cpu')
65 | start = time.time()
66 | clf.fit(X_train)
67 | end = time.time()
68 | # get the prediction labels and outlier scores of the training data
69 | y_train_pred = clf.labels_  # binary labels (0: inliers, 1: outliers)
70 | y_train_scores = clf.decision_scores_  # raw outlier scores
71 | 
72 | # evaluate and print the results
73 | print("\nOn Training Data:")
74 | evaluate_print(clf_name, y_train, y_train_scores)
75 | tod_time = end - start
76 | print('TOD execution time', tod_time)
77 | 
78 | print('TOD is', round(pyod_time / tod_time, ndigits=2),
79 |       'times faster than PyOD')
80 | 
--------------------------------------------------------------------------------
/figs/abstraction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/figs/abstraction.png
--------------------------------------------------------------------------------
/figs/abstraction_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/figs/abstraction_example.png
--------------------------------------------------------------------------------
/figs/run_time.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/figs/run_time.png
--------------------------------------------------------------------------------
/pypi_build_commands.txt:
--------------------------------------------------------------------------------
1 | # This is a command list for building pypi packages
2 | 
3 | python setup.py sdist
4 | twine check dist/*
5 | 
6 | # docstring check
7 | pytest --doctest-modules pytod/
8 | 
9 | twine upload --repository pypitest dist/*
10 | # https://test.pypi.org/project/pytod/
11 | 
12 | twine upload dist/*
13 | 
14 | https://pypi.org/project/pytod/
15 | 
16 | 
17 | #######################################################
18 | # For newly added models, conduct the following checks:
19 | 
20 | 1. check the license, author information, and imports
21 | 2. read the algorithm introduction and citation
22 | 3. check the parameter order and correctness
23 | 4. check comment formats
24 | 5. make sure the tests run locally (ROC floor)
25 | 6. make sure the examples look consistent
26 | 7. add the algorithm to pytod.model.rst
27 | 8. add the algorithm to index.rst and README.rst
28 | 9. add to benchmark.py and compare_all_models.py; when applicable, update the notebooks as well
29 | 
30 | conda activate myenv
31 | conda deactivate
32 | 
--------------------------------------------------------------------------------
/pytod/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/pytod/__init__.py
--------------------------------------------------------------------------------
/pytod/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | 
--------------------------------------------------------------------------------
/pytod/models/basic_operators.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Implementation of some basic operators.
3 | """
4 | # Author: Yue Zhao
5 | # License: BSD 2 clause
6 | 
7 | 
8 | import torch
9 | from torch import cdist as torch_cdist
10 | 
11 | # disable autograd since no grad is needed
12 | torch.set_grad_enabled(False)
13 | 
14 | 
15 | def cdist(a, b=None, p=2, device='cpu'):
16 |     """Basic cdist without using batching.
17 | 
18 |     Parameters
19 |     ----------
20 |     a : tensor of shape (n_samples_a, n_features)
21 |     b : tensor of shape (n_samples_b, n_features), or None to reuse ``a``
22 |     p : float, the norm degree (default=2)
23 |     device : str, a valid device id, e.g., 'cuda:0' or 'cpu'
24 | 
25 |     Returns
26 |     -------
27 |     dist : tensor of shape (n_samples_a, n_samples_b)
28 |     """
29 |     if b is None:
30 |         b = a
31 |         return torch_cdist(b.to(device), b.to(device), p=p)
32 |     else:
33 |         return torch_cdist(a.to(device), b.to(device), p=p)
34 | 
35 | 
36 | # def cdist_s(a, b):
37 | #     """Memory saving version of cdist
38 | #
39 | #     Parameters
40 | #     ----------
41 | #     a
42 | #     b
43 | #
44 | #     Returns
45 | #     -------
46 | #
47 | #     """
48 | #     norm_a = torch.norm(a, dim=1).reshape(a.shape[0], 1)
49 | #     norm_b = torch.norm(b, dim=1).reshape(1, b.shape[0])
50 | #
51 | #     w = norm_a ** 2 + norm_b ** 2 - 2 * torch.matmul(a, b.T)
52 | #     return torch.sqrt(w)
53 | 
54 | def topk(A, k, dim=1, device='cpu'):
55 |     """Returns the k largest elements of the given input tensor along
56 |     the given dimension.
57 | 
58 |     Parameters
59 |     ----------
60 |     A : tensor of shape (n_samples, n_samples)
61 |     k : int, the number of elements to return
62 |     dim : int, the dimension to reduce (default=1)
63 | 
64 |     Returns
65 |     -------
66 |     values : tensor of shape (n_samples, k)
67 |         Top k values.
68 | 
69 |     index : tensor of shape (n_samples, k)
70 |         Top k indexes.
71 |     """
72 |     if len(A.shape) == 1:
73 |         dim = 0
74 |     tk = torch.topk(A.to(device), k, dim=dim)
75 |     return tk[0].cpu(), tk[1].cpu()
76 | 
77 | 
78 | def bottomk(A, k, dim=1, device='cpu'):
79 |     if len(A.shape) == 1:
80 |         dim = 0
81 |     # tk = torch.topk(A * -1, k, dim=dim)
82 |     # see parameter https://pytorch.org/docs/stable/generated/torch.topk.html
83 |     tk = torch.topk(A.to(device), k, dim=dim, largest=False)
84 |     return tk[0].cpu(), tk[1].cpu()
85 | 
86 | 
87 | def bottomk_cpu(A, k, dim=1):
88 |     if len(A.shape) == 1:
89 |         dim = 0
90 |     # tk = torch.topk(A * -1, k, dim=dim)
91 |     # see parameter https://pytorch.org/docs/stable/generated/torch.topk.html
92 |     tk = torch.topk(A, k, dim=dim, largest=False)
93 |     return tk[0], tk[1]
94 | 
95 | 
96 | def bottomk_low_prec(A, k, dim=1, mode='half', sort_value=False, device='cpu'):
97 |     # in lower precision
98 |     if mode == 'half':
99 |         # do conversion first
100 |         A_GPU = A.half().to(device)
101 | 
102 |     else:
103 |         A_GPU = A.float().to(device)
104 | 
105 |     bottomk_dist, bottomk_indices = bottomk(A_GPU, k + 1)
106 | 
107 |     # get all the ambiguous indices with 2-element assumption
108 |     amb_indices_p1 = torch.where(bottomk_dist[:, k] <= bottomk_dist[:, k - 1])[
109 |         0]
110 |     amb_indices_m1 = \
111 |         torch.where(bottomk_dist[:, k - 2] >= bottomk_dist[:, k - 1])[0]
112 | 
113 |     # there might be repetition, so we need to find the unique elements only
114 |     amb_indices = torch.unique(torch.cat((amb_indices_p1, amb_indices_m1)))
115 | 
116 |     print("ambiguous indices", len(amb_indices))
117 | 
118 |     # recal_cdist = cdist_dist[amb_indices, :].double()
119 |     A_GPU_recal = A[amb_indices, :].to(device)
120 | 
121 |     _, bottomk_indices[amb_indices, :k] = bottomk(A_GPU_recal, k)
122 | 
123 |     # drop the last bit k+1
124 |     bottomk_indices = bottomk_indices[:, :k].cpu()
125 | 
126 |     # select by indices for bottom distance
127 |     # https://stackoverflow.com/questions/58523290/select-mask-different-column-index-in-every-row
128 |     bottomk_dist = A.gather(1, bottomk_indices)
129 | 
130 |     if sort_value:
131 |         bottomk_dist_sorted, bottomk_indices_argsort = torch.sort(bottomk_dist,
132 |                                                                   dim=dim)
133 |         bottomk_indices_sorted = bottomk_indices.gather(1,
134 |                                                         bottomk_indices_argsort)
135 |         return bottomk_dist_sorted, bottomk_indices_sorted
136 |     else:
137 |         return bottomk_dist, bottomk_indices
138 | 
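# Note on the *_low_prec helpers (above and below): the input is first cast
# to fp16/fp32 and the (k + 1) extreme values per row are taken on the GPU;
# rows where the k-th value ties its neighbors under reduced precision are
# treated as ambiguous, and only those rows are recomputed at the original
# precision before the final values are gathered from the full-precision
# input tensor.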
139 | 
140 | def topk_low_prec(A, k, dim=1, mode='half', sort_value=False, device='cpu'):
141 |     # in lower precision
142 |     if mode == 'half':
143 |         # do conversion first
144 |         A_GPU = A.half().to(device)
145 | 
146 |     else:
147 |         A_GPU = A.float().to(device)
148 | 
149 |     # print(A_GPU)
150 | 
151 |     topk_dist, topk_indices = topk(A_GPU, k + 1)
152 | 
153 |     # topk(A, k+1)
154 |     # print(A)
155 |     # get all the ambiguous indices with 2-element assumption
156 |     amb_indices_p1 = torch.where(topk_dist[:, k] >= topk_dist[:, k - 1])[0]
157 |     amb_indices_m1 = torch.where(topk_dist[:, k - 2] <= topk_dist[:, k - 1])[0]
158 | 
159 |     # there might be repetition, so we need to find the unique elements only
160 |     amb_indices = torch.unique(torch.cat((amb_indices_p1, amb_indices_m1)))
161 | 
162 |     print("ambiguous indices", len(amb_indices))
163 | 
164 |     A_GPU_recal = A[amb_indices, :].to(device)
165 |     # recal_cdist = cdist_dist[amb_indices, :].double()
166 |     _, topk_indices[amb_indices, :k] = topk(A_GPU_recal, k)
167 | 
168 |     # drop the last bit k+1
169 |     topk_indices = topk_indices[:, :k].cpu()
170 | 
171 |     # select by indices for top distance
172 |     # https://stackoverflow.com/questions/58523290/select-mask-different-column-index-in-every-row
173 | 
174 |     topk_dist = A.gather(1, topk_indices)
175 | 
176 |     if sort_value:
177 |         topk_dist_sorted, topk_indices_argsort = torch.sort(topk_dist, dim=dim,
178 |                                                             descending=True)
179 |         topk_indices_sorted = topk_indices.gather(1, topk_indices_argsort)
180 |         return topk_dist_sorted, topk_indices_sorted
181 |     else:
182 |         return topk_dist, topk_indices
183 | 
184 | 
185 | def intersec1d(t1_orig, t2_orig, assume_unique=False, device='cpu'):
186 |     t1_orig = t1_orig.to(device)
187 |     t2_orig = t2_orig.to(device)
188 |     # adapted from https://github.com/numpy/numpy/blob/v1.19.0/numpy/lib/arraysetops.py#L347-L441
189 |     if assume_unique:
190 |         aux = torch.cat((t1_orig, t2_orig))
191 |     else:
192 |         t1 = torch.unique(t1_orig)
193 |         t2 = torch.unique(t2_orig)
194 |         aux = torch.cat((t1, t2))
195 | 
196 |     aux = torch.sort(aux)[0]
197 | 
198 |     mask = aux[1:] == aux[:-1]
199 |     int1d = aux[:-1][mask]
200 |     # print(t1)
201 |     # for i in int1d:
202 |     #     print('t1', (i==t1_orig).nonzero())
203 |     #     print('t2', (i==t2_orig).nonzero())
204 | 
205 |     return int1d.cpu()
206 | 
207 | 
208 | def intersecmulti(A, B, assume_unique=False):
209 |     assert (A.shape[0] == B.shape[0])
210 |     n_samples = A.shape[0]
211 | 
212 |     intersec = []
213 |     intersec_count = []
214 |     for i in range(n_samples):
215 |         intersec.append(intersec1d(A[i, :], B[i, :]))
216 |         intersec_count.append(len(intersec[-1]))
217 |     return intersec, intersec_count
218 | 
219 | 
220 | def post_check_intersection1d(t1, t2, intersect):
221 |     for i in intersect:
222 |         if i not in t1 or i not in t2:
223 |             raise AssertionError('intersection error')
224 | 
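# Usage sketch for intersec1d (illustrative values): the result is the
# sorted, unique intersection of the two input tensors, e.g.
#   >>> intersec1d(torch.tensor([1, 2, 3, 4]), torch.tensor([3, 4, 5]))
#   tensor([3, 4])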
225 | 
226 | def ecdf_multiple(X, device='cpu'):
227 |     """Get the ECDF results per feature. The GPU version is way faster than
228 |     the CPU version.
229 | 
230 |     Parameters
231 |     ----------
232 |     X : tensor of shape (n_samples, n_features)
233 |         The input samples.
234 | 
235 |     device : string of device (default='cpu')
236 | 
237 |     Returns
238 |     -------
239 | 
240 |     ECDF : tensor of shape (n_samples, n_features)
241 |     """
242 |     argx_tensor = torch.argsort(X.to(device), dim=0)
243 |     y_tensor = torch.linspace(1 / X.shape[0], 1, X.shape[0]).to(device)
244 |     return y_tensor[argx_tensor].cpu()
245 | 
246 | 
247 | def svd_randomized(M, k=10, device='cpu'):
248 |     # http://gregorygundersen.com/blog/2019/01/17/randomized-svd/
249 |     # http://algorithm-interest-group.me/assets/slides/randomized_SVD.pdf
250 |     n_samples, n_dims = M.shape[0], M.shape[1]
251 |     P = torch.randn([n_dims, k]).float().to(device)
252 |     M_P = torch.mm(M, P)
253 |     Q, _ = torch.qr(M_P)
254 |     B = torch.mm(Q.T, M)
255 |     U, S, V = torch.svd(B)
256 |     U = torch.mm(Q, U)
257 | 
258 |     return U, S, V
259 | 
260 | 
261 | def histt(a, bins=10, density=True, device='cpu'):
262 |     def diff(a):
263 |         # https://discuss.pytorch.org/t/equivalent-function-like-numpy-diff-in-pytorch/35327
264 |         return a[1:] - a[:-1]
265 | 
266 |     # https://github.com/numpy/numpy/blob/v1.19.0/numpy/lib/histograms.py#L677-L928
267 |     # for i in range(a.shape[1]):
268 |     hist = torch.histc(a, bins=bins)
269 |     # normalize histogram to sum to 1
270 |     # hist = torch.true_divide(hist, hist.sum())
271 |     bin_edges = torch.linspace(a.min(), a.max(), steps=bins + 1).to(device)
272 |     if density:
273 |         hist_sum = hist.sum()
274 |         db = diff(bin_edges).to(device)
275 |         return torch.true_divide(hist, db) / hist_sum, bin_edges
276 | 
277 |     else:
278 |         return hist, bin_edges
279 | 
--------------------------------------------------------------------------------
/pytod/models/basic_operators_batch.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Thu Jan 14 21:44:34 2021
4 | """
5 | 
6 | import torch
7 | 
8 | from .basic_operators import topk, bottomk, intersec1d, cdist
9 | from ..utils.utility import get_batch_index
10 | 
11 | 
12 | def cdist_batch(A, B, p=2.0, batch_size=None, device='cpu'):
13 |     """Batch version of cdist.
14 | 
15 |     Parameters
16 |     ----------
17 |     A : tensor of shape (n_samples, n_features)
18 |     B : tensor of shape (n_distance, n_features), or None to reuse ``A``
19 |     p : float, the norm degree (default=2.0)
20 |     batch_size : int or None; if None, compute without batching
21 | 
22 |     Returns
23 |     -------
24 |     cdist_mat : CPU tensor of shape (n_samples, n_distance)
25 |     """
26 |     if B is None:
27 |         B = A
28 | 
29 |     # batching is not needed
30 |     if batch_size is None or batch_size >= A.shape[0]:
31 |         return torch.cdist(A.to(device), B.to(device), p=p)
32 | 
33 |     n_samples, n_features = A.shape[0], A.shape[1]
34 |     n_distance = B.shape[0]
35 | 
36 |     batch_index_A = get_batch_index(n_samples, batch_size)
37 |     batch_index_B = get_batch_index(n_distance, batch_size)
38 |     # print(batch_index_A)
39 |     # print(batch_index_B)
40 | 
41 |     # this is a cpu tensor to save space
42 |     cdist_mat = torch.zeros([n_samples, n_distance])
43 | 
44 |     for i, index_A in enumerate(batch_index_A):
45 |         for j, index_B in enumerate(batch_index_B):
46 |             cdist_mat[index_A[0]:index_A[1], index_B[0]:index_B[1]] = \
47 |                 cdist(A[index_A[0]:index_A[1], :],
48 |                       B[index_B[0]:index_B[1], :],
49 |                       device=device).cpu()
50 |             # cdist_s(A[index_A[0]:index_A[1], :].to(device),
51 |             #         B[index_B[0]:index_B[1], :].to(device)
52 |             #         ).cpu()
53 |     return cdist_mat
54 | 
55 | 
56 | 
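# Usage sketch (illustrative sizes): a 100k x 100k fp64 distance matrix
# would not fit on most GPUs, so cdist_batch assembles it on the CPU from
# (batch_size x batch_size) GPU tiles, e.g.:
#   A = torch.randn(100000, 20)
#   D = cdist_batch(A, None, batch_size=10000, device='cuda:0')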
57 | def topk_batch(A, k, dim=1, batch_size=None, device='cpu'):
58 |     if batch_size is None:
59 |         print("original")
60 |         return topk(A.to(device), k, dim)
61 |     else:
62 |         n_samples = A.shape[0]
63 |         batch_index = get_batch_index(n_samples, batch_size)
64 |         index_mat = torch.zeros([n_samples, k])
65 |         value_mat = torch.zeros([n_samples, k])
66 | 
67 |         for i, index in enumerate(batch_index):
68 |             print('batch', i)
69 |             tk = topk(A[index[0]:index[1], :].to(device), k, dim=dim)
70 |             value_mat[index[0]:index[1], :], index_mat[index[0]:index[1], :] = \
71 |                 tk[0], tk[1]
72 | 
73 |         return value_mat, index_mat
74 | 
75 | 
76 | def bottomk_batch(A, k, dim=1, batch_size=None, device='cpu'):
77 |     # half (fp16) can be a choice here
78 |     if batch_size is None:
79 |         print("original")
80 |         return bottomk(A.to(device), k, dim)
81 | 
82 |     else:
83 |         n_samples = A.shape[0]
84 |         batch_index = get_batch_index(n_samples, batch_size)
85 |         index_mat = torch.zeros([n_samples, k])
86 |         value_mat = torch.zeros([n_samples, k])
87 | 
88 |         for i, index in enumerate(batch_index):
89 |             print('batch', i)
90 |             tk = bottomk(A[index[0]:index[1], :].to(device), k, dim=dim)
91 |             value_mat[index[0]:index[1], :], index_mat[index[0]:index[1], :] = \
92 |                 tk[0], tk[1]
93 | 
94 |         return value_mat, index_mat
95 | 
96 | 
97 | def intersec1d_batch(t1, t2, batch_size=100000, device='cpu'):
98 |     if batch_size >= len(t1) or batch_size >= len(t2):
99 |         return intersec1d(t1, t2)
100 | 
101 |     batch_index_A = get_batch_index(len(t1), batch_size)
102 |     batch_index_B = get_batch_index(len(t2), batch_size)
103 |     print(batch_index_A)
104 |     print(batch_index_B)
105 | 
106 |     # use cuda for fast computation
107 |     candidate_set = torch.tensor([]).to(device)
108 |     for i, index_A in enumerate(batch_index_A):
109 |         for j, index_B in enumerate(batch_index_B):
110 |             candidate_set = torch.cat((candidate_set,
111 |                                        intersec1d(t1[index_A[0]:index_A[1]],
112 |                                                   t2[index_B[0]:index_B[1]])),
113 |                                       dim=0)
114 |     return torch.unique(candidate_set).cpu()
115 | 
--------------------------------------------------------------------------------
/pytod/models/ecod.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Empirical Cumulative Distribution Functions (ECOD) for Outlier Detection
3 | """
4 | # Author: Yue Zhao
5 | # License: BSD 2 clause
6 | 
7 | import numpy as np
8 | import torch
9 | 
10 | from .base import BaseDetector
11 | from .basic_operators import ecdf_multiple
12 | 
13 | 
14 | class ECOD(BaseDetector):
15 |     """ECOD class for Unsupervised Outlier Detection Using Empirical
16 |     Cumulative Distribution Functions (ECOD)
17 |     ECOD is a parameter-free, highly interpretable outlier detection algorithm
18 |     based on empirical CDF functions.
19 |     See :cite:`Li2021ecod` for details.
20 | 
21 |     Parameters
22 |     ----------
23 |     contamination : float in (0., 0.5), optional (default=0.1)
24 |         The amount of contamination of the data set, i.e.
25 |         the proportion of outliers in the data set. Used when fitting to
26 |         define the threshold on the decision function.
27 | 
28 | 
29 |     Attributes
30 |     ----------
31 |     decision_scores_ : numpy array of shape (n_samples,)
32 |         The outlier scores of the training data.
33 |         The higher, the more abnormal. Outliers tend to have higher
34 |         scores. This value is available once the detector is
35 |         fitted.
36 |     threshold_ : float
37 |         The threshold is based on ``contamination``. It is the
38 |         ``n_samples * contamination`` most abnormal samples in
39 |         ``decision_scores_``. The threshold is calculated for generating
40 |         binary outlier labels.
41 |     labels_ : int, either 0 or 1
42 |         The binary labels of the training data. 0 stands for inliers
43 |         and 1 for outliers/anomalies. It is generated by applying
44 |         ``threshold_`` on ``decision_scores_``.
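
    Examples
    --------
    A minimal usage sketch (mirroring examples/ecod_example.py)::

        import torch
        from pytod.models.ecod import ECOD
        clf = ECOD(device='cpu')
        clf.fit(torch.randn(1000, 20))
        scores = clf.decision_scores_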
45 | """ 46 | 47 | def __init__(self, contamination=0.1, n_neighbors=5, batch_size=None, 48 | device='cuda:0'): 49 | super(ECOD, self).__init__(contamination=contamination) 50 | self.n_neighbors = n_neighbors 51 | self.device = device 52 | 53 | def fit(self, X, y=None, return_time=False): 54 | """Fit detector. y is ignored in unsupervised methods. 55 | 56 | Parameters 57 | ---------- 58 | X : numpy array of shape (n_samples, n_features) 59 | The input samples. 60 | 61 | y : Ignored 62 | Not used, present for API consistency by convention. 63 | 64 | return_time : boolean (default=True) 65 | If True, set self.gpu_time to the measured GPU time. 66 | 67 | Returns 68 | ------- 69 | self : object 70 | Fitted estimator. 71 | """ 72 | # todo: add one for pytorch tensor 73 | # X = check_array(X) 74 | self._set_n_classes(y) 75 | 76 | if self.device != 'cpu' and return_time: 77 | start = torch.cuda.Event(enable_timing=True) 78 | end = torch.cuda.Event(enable_timing=True) 79 | start.record() 80 | 81 | # density estimation via ECDF 82 | self.U_l = ecdf_multiple(X, device=self.device) 83 | self.U_r = ecdf_multiple(-X, device=self.device) 84 | 85 | if self.device != 'cpu' and return_time: 86 | end.record() 87 | torch.cuda.synchronize() 88 | 89 | # take the negative log 90 | self.U_l = -1 * torch.log(self.U_l) 91 | self.U_r = -1 * torch.log(self.U_r) 92 | 93 | # aggregate and generate outlier scores 94 | self.O = torch.maximum(self.U_l, self.U_r) 95 | self.decision_scores_ = torch.sum(self.O, dim=1).cpu().numpy() * -1 96 | 97 | self._process_decision_scores() 98 | 99 | # return GPU time in seconds 100 | if self.device != 'cpu' and return_time: 101 | self.gpu_time = start.elapsed_time(end) / 1000 102 | 103 | return self 104 | 105 | def decision_function(self, X): 106 | """Predict raw anomaly score of X using the fitted detector. 107 | For consistency, outliers are assigned with larger anomaly scores. 108 | Parameters 109 | ---------- 110 | X : numpy array of shape (n_samples, n_features) 111 | The training input samples. Sparse matrices are accepted only 112 | if they are supported by the base estimator. 113 | Returns 114 | ------- 115 | anomaly_scores : numpy array of shape (n_samples,) 116 | The anomaly score of the input samples. 
117 | """ 118 | # use multi-thread execution 119 | if hasattr(self, 'X_train'): 120 | original_size = X.shape[0] 121 | X = np.concatenate((self.X_train, X), axis=0) 122 | 123 | # return decision_scores_.ravel() 124 | -------------------------------------------------------------------------------- /pytod/models/functional_operators.py: -------------------------------------------------------------------------------- 1 | from torch import cdist 2 | 3 | from pytod.models.basic_operators import bottomk 4 | 5 | 6 | def knn_full(A, B, k=5, p=2.0, device=None): 7 | """Get kNN in the non-batch way 8 | 9 | Parameters 10 | ---------- 11 | A 12 | B 13 | k 14 | p 15 | device 16 | 17 | Returns 18 | ------- 19 | 20 | """ 21 | dist_c = cdist(A.to(device), B.to(device), p=p) 22 | btk_d, btk_i = bottomk(dist_c, k=k, device=device) 23 | return btk_d.cpu(), btk_i.cpu() 24 | 25 | 26 | -------------------------------------------------------------------------------- /pytod/models/hbos.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Histogram-based Outlier Detection (HBOS) 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from .base import BaseDetector 11 | from .basic_operators import histt 12 | 13 | 14 | class HBOS(BaseDetector): 15 | """Histogram-based outlier detection (HBOS) is an efficient unsupervised 16 | method. It assumes the feature independence and calculates the degree 17 | of outlyingness by building histograms. See :cite:`goldstein2012histogram` 18 | for details. 19 | 20 | 21 | Parameters 22 | ---------- 23 | n_bins : int or string, optional (default=10) 24 | The number of bins. "auto" uses the birge-rozenblac method for 25 | automatic selection of the optimal number of bins for each feature. 26 | 27 | alpha : float in (0, 1), optional (default=0.1) 28 | The regularizer for preventing overflow. 29 | 30 | contamination : float in (0., 0.5), optional (default=0.1) 31 | The amount of contamination of the data set, 32 | i.e. the proportion of outliers in the data set. Used when fitting to 33 | define the threshold on the decision function. 34 | 35 | batch_size : integer, optional (default = None) 36 | Number of samples to process per batch. 37 | 38 | device : str, optional (default = 'cpu') 39 | Valid device id, e.g., 'cuda:0' or 'cpu' 40 | 41 | Attributes 42 | ---------- 43 | bin_edges_ : numpy array of shape (n_bins + 1, n_features ) 44 | The edges of the bins. 45 | 46 | hist_ : numpy array of shape (n_bins, n_features) 47 | The density of each histogram. 48 | 49 | decision_scores_ : numpy array of shape (n_samples,) 50 | The outlier scores of the training data. 51 | The higher, the more abnormal. Outliers tend to have higher 52 | scores. This value is available once the detector is fitted. 53 | 54 | threshold_ : float 55 | The threshold is based on ``contamination``. It is the 56 | ``n_samples * contamination`` most abnormal samples in 57 | ``decision_scores_``. The threshold is calculated for generating 58 | binary outlier labels. 59 | 60 | labels_ : int, either 0 or 1 61 | The binary labels of the training data. 0 stands for inliers 62 | and 1 for outliers/anomalies. It is generated by applying 63 | ``threshold_`` on ``decision_scores_``. 
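
    Informally, with per-feature density histograms hist_1, ..., hist_d and
    regularizer ``alpha``, the score computed in ``fit`` below is

        HBOS(x) = - sum_i log2( hist_i(x_i) + alpha ),

    so samples that fall into low-density bins receive higher scores.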
64 | """ 65 | 66 | def __init__(self, contamination=0.1, n_bins=10, alpha=0.1, 67 | device='cuda:0'): 68 | super(HBOS, self).__init__(contamination=contamination) 69 | self.n_bins = n_bins 70 | self.alpha = alpha 71 | self.device = device 72 | 73 | def fit(self, X, y=None, return_time=False): 74 | """Fit detector. y is ignored in unsupervised methods. 75 | 76 | Parameters 77 | ---------- 78 | X : numpy array of shape (n_samples, n_features) 79 | The input samples. 80 | 81 | y : Ignored 82 | Not used, present for API consistency by convention. 83 | 84 | return_time : boolean (default=True) 85 | If True, set self.gpu_time to the measured GPU time. 86 | 87 | Returns 88 | ------- 89 | self : object 90 | Fitted estimator. 91 | """ 92 | # todo: add one for pytorch tensor 93 | # X = check_array(X) 94 | self._set_n_classes(y) 95 | 96 | if self.device != 'cpu' and return_time: 97 | start = torch.cuda.Event(enable_timing=True) 98 | end = torch.cuda.Event(enable_timing=True) 99 | start.record() 100 | 101 | X = X.to(self.device) 102 | n_samples, n_features = X.shape[0], X.shape[1] 103 | 104 | # initialize containers for calculation 105 | hist_ = torch.zeros([self.n_bins, n_features]).to(self.device) 106 | bin_edges = torch.zeros([self.n_bins + 1, n_features]).to(self.device) 107 | outlier_scores = torch.zeros([n_samples, n_features]).to(self.device) 108 | 109 | for i in range(n_features): 110 | hist_[:, i], bin_edges[:, i] = histt(X[:, i], bins=self.n_bins, 111 | device=self.device) 112 | 113 | hist_ = hist_.contiguous() 114 | bin_edges = bin_edges.contiguous() 115 | 116 | # conduct feature-wise binning 117 | for i in range(n_features): 118 | bin_inds = torch.bucketize(X[:, i], bin_edges[:, i]) 119 | out_score_i = torch.log2(hist_[:, i] + self.alpha) 120 | 121 | bin_inds[bin_inds == 0] = 1 122 | bin_inds[bin_inds == self.n_bins + 1] = self.n_bins 123 | outlier_scores[:, i] = out_score_i[bin_inds - 1] 124 | 125 | if self.device != 'cpu' and return_time: 126 | end.record() 127 | torch.cuda.synchronize() 128 | 129 | self.decision_scores_ = ( 130 | torch.sum(outlier_scores, dim=1) * -1).cpu().numpy() 131 | 132 | self._process_decision_scores() 133 | 134 | # return GPU time in seconds 135 | if self.device != 'cpu' and return_time: 136 | self.gpu_time = start.elapsed_time(end) / 1000 137 | 138 | return self 139 | 140 | def decision_function(self, X): 141 | """Predict raw anomaly score of X using the fitted detector. 142 | For consistency, outliers are assigned with larger anomaly scores. 143 | Parameters 144 | ---------- 145 | X : numpy array of shape (n_samples, n_features) 146 | The training input samples. Sparse matrices are accepted only 147 | if they are supported by the base estimator. 148 | Returns 149 | ------- 150 | anomaly_scores : numpy array of shape (n_samples,) 151 | The anomaly score of the input samples. 
152 |         """
153 |         # use multi-thread execution
154 |         if hasattr(self, 'X_train'):
155 |             original_size = X.shape[0]
156 |             X = np.concatenate((self.X_train, X), axis=0)
157 | 
158 |         # return decision_scores_.ravel()
--------------------------------------------------------------------------------
/pytod/models/knn.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """k-Nearest Neighbors Detector (kNN)
3 | """
4 | # Author: Yue Zhao
5 | # License: BSD 2 clause
6 | 
7 | import numpy as np
8 | import torch
9 | 
10 | from .base import BaseDetector
11 | from .intermediate_layers import knn_batch
12 | 
13 | 
14 | class KNN(BaseDetector):
15 |     # noinspection PyPep8
16 |     """kNN class for outlier detection.
17 |     For an observation, its distance to its kth nearest neighbor could be
18 |     viewed as the outlying score. It could be viewed as a way to measure
19 |     the density. See :cite:`ramaswamy2000efficient,angiulli2002fast` for
20 |     details.
21 | 
22 |     This implementation uses the distance to the kth neighbor as the
23 |     outlier score (the "largest" variant in PyOD); the "mean" and
24 |     "median" variants, which aggregate all k neighbor distances, are
25 |     not exposed here.
26 | 
27 |     Parameters
28 |     ----------
29 |     contamination : float in (0., 0.5), optional (default=0.1)
30 |         The amount of contamination of the data set,
31 |         i.e. the proportion of outliers in the data set. Used when fitting to
32 |         define the threshold on the decision function.
33 | 
34 |     n_neighbors : int, optional (default=5)
35 |         Number of neighbors to use by default for `kneighbors` queries.
36 |         If n_neighbors is larger than the number of samples provided,
37 |         all samples will be used.
38 | 
39 |     batch_size : integer, optional (default = None)
40 |         Number of samples to process per batch.
41 | 
42 |     device : str, optional (default = 'cuda:0')
43 |         Valid device id, e.g., 'cuda:0' or 'cpu'
44 | 
45 |     Attributes
46 |     ----------
47 |     decision_scores_ : numpy array of shape (n_samples,)
48 |         The outlier scores of the training data.
49 |         The higher, the more abnormal. Outliers tend to have higher
50 |         scores. This value is available once the detector is
51 |         fitted.
52 | 
53 |     threshold_ : float
54 |         The threshold is based on ``contamination``. It is the
55 |         ``n_samples * contamination`` most abnormal samples in
56 |         ``decision_scores_``. The threshold is calculated for generating
57 |         binary outlier labels.
58 | 
59 |     labels_ : int, either 0 or 1
60 |         The binary labels of the training data. 0 stands for inliers
61 |         and 1 for outliers/anomalies. It is generated by applying
62 |         ``threshold_`` on ``decision_scores_``.
63 |     """
64 | 
65 |     def __init__(self, contamination=0.1, n_neighbors=5, batch_size=None,
66 |                  device='cuda:0'):
67 |         super(KNN, self).__init__(contamination=contamination)
68 |         self.n_neighbors = n_neighbors
69 |         self.batch_size = batch_size
70 |         self.device = device
71 | 
72 |     def fit(self, X, y=None, return_time=False):
73 |         """Fit detector. y is ignored in unsupervised methods.
74 | 
75 |         Parameters
76 |         ----------
77 |         X : numpy array of shape (n_samples, n_features)
78 |             The input samples.
79 | 
80 |         y : Ignored
81 |             Not used, present for API consistency by convention.
82 | 
83 |         return_time : boolean (default=False)
84 |             If True, set self.gpu_time to the measured GPU time.
85 | 
86 |         Returns
87 |         -------
88 |         self : object
89 |             Fitted estimator.
90 |         """
91 |         # todo: add one for pytorch tensor
92 |         # X = check_array(X)
93 |         self._set_n_classes(y)
94 | 
95 |         if self.device != 'cpu' and return_time:
96 |             start = torch.cuda.Event(enable_timing=True)
97 |             end = torch.cuda.Event(enable_timing=True)
98 |             start.record()
99 | 
100 |         knn_dist, _ = knn_batch(X, X, self.n_neighbors + 1,
101 |                                 batch_size=self.batch_size,
102 |                                 device=self.device)
103 | 
104 |         if self.device != 'cpu' and return_time:
105 |             end.record()
106 |             torch.cuda.synchronize()
107 | 
108 |         self.decision_scores_ = knn_dist[:, -1].cpu().numpy()
109 |         self._process_decision_scores()
110 | 
111 |         # return GPU time in seconds
112 |         if self.device != 'cpu' and return_time:
113 |             self.gpu_time = start.elapsed_time(end) / 1000
114 | 
115 |         return self
116 | 
117 |     def decision_function(self, X):
118 |         """Predict raw anomaly score of X using the fitted detector.
119 |         For consistency, outliers are assigned with larger anomaly scores.
120 |         Parameters
121 |         ----------
122 |         X : numpy array of shape (n_samples, n_features)
123 |             The training input samples. Sparse matrices are accepted only
124 |             if they are supported by the base estimator.
125 |         Returns
126 |         -------
127 |         anomaly_scores : numpy array of shape (n_samples,)
128 |             The anomaly score of the input samples.
129 |         """
130 |         # use multi-thread execution
131 |         if hasattr(self, 'X_train'):
132 |             original_size = X.shape[0]
133 |             X = np.concatenate((self.X_train, X), axis=0)
134 | 
135 |         # return decision_scores_.ravel()
--------------------------------------------------------------------------------
/pytod/models/lof.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Local Outlier Factor (LOF). Implemented with PyTorch operators.
3 | """
4 | # Author: Yue Zhao
5 | # License: BSD 2 clause
6 | 
7 | 
8 | import numpy as np
9 | import scipy.stats as sp_stats
10 | import torch
11 | 
12 | from .base import BaseDetector
13 | from .intermediate_layers import knn_batch
14 | 
15 | 
16 | class LOF(BaseDetector):
17 |     """Unsupervised Outlier Detection using Local Outlier Factor (LOF),
18 |     implemented with PyTorch operators.
19 | 
20 |     The anomaly score of each sample is called Local Outlier Factor.
21 |     It measures the local deviation of density of a given sample with
22 |     respect to its neighbors.
23 |     It is local in that the anomaly score depends on how isolated the object
24 |     is with respect to the surrounding neighborhood.
25 |     More precisely, locality is given by k-nearest neighbors, whose distance
26 |     is used to estimate the local density.
27 |     By comparing the local density of a sample to the local densities of
28 |     its neighbors, one can identify samples that have a substantially lower
29 |     density than their neighbors. These are considered outliers.
30 |     See :cite:`breunig2000lof` for details.
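
    Informally, with N_k(p) the k nearest neighbors of p and
    reach-dist_k(p, o) = max(k-distance(o), d(p, o)):

        lrd_k(p) = 1 / mean_{o in N_k(p)} reach-dist_k(p, o)
        LOF_k(p) = mean_{o in N_k(p)} lrd_k(o) / lrd_k(p),

    i.e., p's average reachability distance divided by the harmonic mean of
    its neighbors' average reachability distances, which is the form that
    ``fit`` computes below.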
    """

    def __init__(self, contamination=0.1, n_neighbors=5, batch_size=None,
                 device='cuda:0'):
        super(LOF, self).__init__(contamination=contamination)
        self.n_neighbors = n_neighbors
        self.batch_size = batch_size
        self.device = device

    def fit(self, X, y=None, return_time=False):
        """Fit detector. y is ignored in unsupervised methods.

        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The input samples.

        y : Ignored
            Not used, present for API consistency by convention.

        return_time : boolean (default=False)
            If True, set self.gpu_time to the measured GPU time.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # todo: add one for pytorch tensor
        # X = check_array(X)
        self._set_n_classes(y)

        if self.device != 'cpu' and return_time:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

        # find the k nearest neighbors of all samples
        knn_dist, knn_inds = knn_batch(X, X, self.n_neighbors + 1,
                                       batch_size=self.batch_size,
                                       device=self.device)
        # drop the first column: each sample's zero distance to itself
        knn_dist, knn_inds = knn_dist[:, 1:], knn_inds[:, 1:]

        if self.device != 'cpu' and return_time:
            end.record()
            torch.cuda.synchronize()

        # flattened neighbor indices and distances
        knn_inds_flat = torch.flatten(knn_inds).long()
        knn_dist_flat = torch.flatten(knn_dist)

        # for each sample, find its kNN's *kth* neighbor's distance;
        # -1 selects the kth distance
        knn_kth_dist = torch.index_select(knn_dist, 0, knn_inds_flat)[:, -1]
        # the matching indices (currently unused, kept from the original)
        knn_kth_inds = torch.index_select(knn_inds, 0, knn_inds_flat)[:, -1]

        # the reachability distance is the elementwise maximum of the raw
        # distance and the kth distance; find where the raw value is smaller
        raw_smaller = torch.where(knn_dist_flat < knn_kth_dist)[0]

        # overwrite those entries in place, which saves one variable
        knn_dist_flat[raw_smaller] = knn_kth_dist[raw_smaller]
        # print(knn_dist_flat[:10])

        # average reachability distance per sample, via an [n, k] view
        ar = torch.mean(knn_dist_flat.view(-1, self.n_neighbors), dim=1)

        # the harmonic mean reproduces the exact LOF result
        # todo: the harmonic mean can be written in PyTorch as well
        ar_nn = sp.stats.hmean(
            torch.index_select(ar, 0, knn_inds_flat).view(
                -1, self.n_neighbors).cpu().numpy(),
            axis=1)
        assert (len(ar_nn) == len(ar))

        self.decision_scores_ = ar.cpu().numpy() / ar_nn

        self._process_decision_scores()

        # return GPU time in seconds
        if self.device != 'cpu' and return_time:
            self.gpu_time = start.elapsed_time(end) / 1000
        return self
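
    # For reference, fit computes the classic LOF quantities:
    #   reach_dist_k(a, b) = max(k_dist(b), dist(a, b))
    #   lrd(a) = 1 / mean_{b in kNN(a)} reach_dist_k(a, b)
    #   LOF(a) = mean_{b in kNN(a)} lrd(b) / lrd(a)
    # which matches ar / ar_nn above: the harmonic mean of the neighbors'
    # average reachability distances inverts the lrd ratio.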
    def decision_function(self, X):
        """Predict raw anomaly score of X using the fitted detector.
        For consistency, outliers are assigned with larger anomaly scores.
        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The training input samples. Sparse matrices are accepted only
            if they are supported by the base estimator.
        Returns
        -------
        anomaly_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.
        """
        # minimal completion of the original stub: rerun the LOF computation
        # on the training data plus X (when available) and return the scores
        # of the appended samples
        if hasattr(self, 'X_train'):
            original_size = X.shape[0]
            X = np.concatenate((self.X_train, X), axis=0)

        detector = LOF(contamination=self.contamination,
                       n_neighbors=self.n_neighbors,
                       batch_size=self.batch_size,
                       device=self.device).fit(X)
        decision_scores_ = detector.decision_scores_
        if hasattr(self, 'X_train'):
            decision_scores_ = decision_scores_[-original_size:]
        return decision_scores_.ravel()
--------------------------------------------------------------------------------
/pytod/models/pca.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""Principal Component Analysis (PCA) Outlier Detector
"""
# Author: Yue Zhao
# License: BSD 2 clause


import numpy as np
import torch

from .base import BaseDetector


class PCA(BaseDetector):
    """Principal component analysis (PCA) can be used for detecting outliers.
    PCA is a linear dimensionality reduction that uses a Singular Value
    Decomposition of the data to project it to a lower-dimensional space.

    In this procedure, the covariance matrix of the data is decomposed into
    orthogonal vectors, called eigenvectors, associated with eigenvalues. The
    eigenvectors with high eigenvalues capture most of the variance in the
    data.

    Therefore, a low-dimensional hyperplane constructed by k eigenvectors can
    capture most of the variance in the data. Outliers, however, deviate from
    the normal data points, and the deviation is most pronounced along the
    eigenvectors with small eigenvalues.

    Therefore, outlier scores can be obtained as the sum of the projected
    distances of a sample on all eigenvectors.
    See :cite:`shyu2003novel,aggarwal2015outlier` for details.

    Score(X) = sum of the weighted euclidean distances between each sample
    and the hyperplane constructed by the selected eigenvectors

    Parameters
    ----------
    contamination : float in (0., 0.5), optional (default=0.1)
        The amount of contamination of the data set,
        i.e. the proportion of outliers in the data set. Used when fitting to
        define the threshold on the decision function.

    n_components : int, optional (default=5)
        Number of components to keep. It is passed to ``torch.pca_lowrank``
        as ``q``, so it must be a positive integer no larger than
        min(n_samples, n_features).

    device : str, optional (default='cuda:0')
        Valid device id, e.g., 'cuda:0' or 'cpu'

    Attributes
    ----------
    decision_scores_ : numpy array of shape (n_samples,)
        The outlier scores of the training data.
        The higher, the more abnormal. Outliers tend to have higher
        scores. This value is available once the detector is
        fitted.

    threshold_ : float
        The threshold is based on ``contamination``. It is the
        ``n_samples * contamination`` most abnormal samples in
        ``decision_scores_``. The threshold is calculated for generating
        binary outlier labels.

    labels_ : int, either 0 or 1
        The binary labels of the training data. 0 stands for inliers
        and 1 for outliers/anomalies. It is generated by applying
        ``threshold_`` on ``decision_scores_``.
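
    Examples
    --------
    A minimal usage sketch (hedged: it mirrors the test suite, relies on the
    synthetic-data helper in ``pytod.utils.data``, and assumes you pass
    ``device='cpu'`` when no GPU is visible)::

        >>> from pytod.models.pca import PCA
        >>> from pytod.utils.data import generate_data
        >>> X_train, y_train = generate_data(n_train=2000, n_features=10,
        ...                                  train_only=True, random_state=42)
        >>> clf = PCA(n_components=5, device='cpu')
        >>> clf.fit(X_train)            # doctest: +SKIP
        >>> labels = clf.labels_        # 0 = inlier, 1 = outlier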
    """

    def __init__(self, contamination=0.1, n_components=5, device='cuda:0'):
        super(PCA, self).__init__(contamination=contamination)
        self.n_components = n_components
        self.device = device

    def fit(self, X, y=None, return_time=False):
        """Fit detector. y is ignored in unsupervised methods.

        Parameters
        ----------
        X : torch tensor of shape (n_samples, n_features)
            The input samples.

        y : Ignored
            Not used, present for API consistency by convention.

        return_time : boolean (default=False)
            If True, set self.gpu_time to the measured GPU time.

        Returns
        -------
        self : object
            Fitted estimator.
        """
        # todo: add one for pytorch tensor
        # X = check_array(X)
        self._set_n_classes(y)

        if self.device != 'cpu' and return_time:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

        X = X.to(self.device)

        U, S, V = torch.pca_lowrank(X, q=self.n_components)

        X_projected = torch.matmul(X, V)

        # explained variance ratio per principal component; see
        # https://ro-che.info/articles/2017-12-11-pca-explained-variance#:~:text=The%20total%20variance%20is%20the,divide%20by%20the%20total%20variance.
        vars_by_pc = torch.var(X_projected, dim=0)

        explained_var = vars_by_pc / vars_by_pc.sum()

        if self.device != 'cpu' and return_time:
            end.record()
            torch.cuda.synchronize()

        self.decision_scores_ = torch.sum(torch.cdist(X, V.T) / explained_var,
                                          dim=1).cpu().numpy()

        # return GPU time in seconds; only meaningful on a CUDA device
        if self.device != 'cpu' and return_time:
            self.gpu_time = start.elapsed_time(end) / 1000

        self._process_decision_scores()
        return self

    def decision_function(self, X):
        """Predict raw anomaly score of X using the fitted detector.
        For consistency, outliers are assigned with larger anomaly scores.
        Parameters
        ----------
        X : numpy array of shape (n_samples, n_features)
            The training input samples. Sparse matrices are accepted only
            if they are supported by the base estimator.
        Returns
        -------
        anomaly_scores : numpy array of shape (n_samples,)
            The anomaly score of the input samples.
144 | """ 145 | # use multi-thread execution 146 | if hasattr(self, 'X_train'): 147 | original_size = X.shape[0] 148 | X = np.concatenate((self.X_train, X), axis=0) 149 | 150 | # return decision_scores_.ravel() 151 | -------------------------------------------------------------------------------- /pytod/models/quantization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Benchmark of all implemented algorithms 3 | """ 4 | 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import sys 10 | 11 | import torch 12 | 13 | # temporary solution for relative imports in case pyod is not installed 14 | # if pyod is installed, no need to use the following line 15 | sys.path.append( 16 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 17 | # supress warnings for clean output 18 | import warnings 19 | 20 | warnings.filterwarnings("ignore") 21 | 22 | import numpy as np 23 | from scipy.io import loadmat 24 | 25 | # from pyod.models.feature_bagging import FeatureBagging 26 | 27 | from pyod.utils.data import evaluate_print 28 | 29 | from pyod.utils.utility import standardizer 30 | 31 | from basic_operators import topk, bottomk, bottomk_low_prec, topk_low_prec 32 | 33 | from mpmath import mp, mpf 34 | 35 | machine_eps = mpf(2 ** -53) 36 | 37 | 38 | def get_bounded_error(max_value, dimension, machine_eps=np.finfo(float).eps, 39 | two_sided=True): 40 | mp.dps = 100 41 | factor = (1 + machine_eps) ** (mp.log(dimension) + 2) - 1 42 | if two_sided: 43 | return float(2 * (4 * dimension * (max_value ** 2) * factor)) 44 | else: 45 | return float(4 * dimension * (max_value ** 2) * factor) 46 | 47 | 48 | # print(get_bounded_error(1, 1000000)) 49 | # error_bound = float(get_bounded_error(1, 1000000)) 50 | 51 | # TODO: add neural networks, LOCI, SOS, COF, SOD 52 | 53 | # Define data file and read X and y 54 | mat_file_list = [ 55 | # 'annthyroid.mat', 56 | 'arrhythmia.mat', 57 | # 'breastw.mat', 58 | # 'glass.mat', 59 | # 'ionosphere.mat', 60 | # 'letter.mat', 61 | # 'lympho.mat', 62 | # 'mammography.mat', 63 | # 'mnist.mat', 64 | # 'musk.mat', 65 | 66 | # 'optdigits.mat', 67 | # 'pendigits.mat', 68 | # 'pima.mat', 69 | # 'satellite.mat', 70 | # 'satimage-2.mat', 71 | # # 'shuttle.mat', 72 | # # 'smtp_n.mat', 73 | # 'speech.mat', 74 | # 'thyroid.mat', 75 | # 'vertebral.mat', 76 | # 'vowels.mat', 77 | # 'wbc.mat', 78 | # 'wine.mat', 79 | ] 80 | 81 | mat_file = 'speech.mat' 82 | mat = loadmat(os.path.join("datasets", "ODDS", mat_file)) 83 | 84 | X = mat['X'] 85 | y = mat['y'].ravel() 86 | 87 | n_samples, n_features = X.shape[0], X.shape[1] 88 | 89 | outliers_fraction = np.count_nonzero(y) / len(y) 90 | outliers_percentage = round(outliers_fraction * 100, ndigits=4) 91 | 92 | # scaler = MinMaxScaler(feature_range=((1,2))) 93 | 94 | # X_transform = scaler.fit_transform(X) 95 | # a = rankdata(X, axis=0) 96 | # b = rankdata(X_transform, axis=0) 97 | 98 | X = standardizer(X) 99 | error_bound = get_bounded_error(np.max(X), n_features) 100 | print(error_bound) 101 | 102 | k = 10 103 | # X_train = torch.tensor(X).half().cuda() 104 | X_train = torch.tensor(X).float() 105 | # X_train = torch.tensor(X).double().cuda() 106 | 107 | 108 | cdist_dist = torch.cdist(X_train, X_train, p=2) 109 | 110 | bottomk_dist, bottomk_indices = bottomk(cdist_dist, k) 111 | bottomk_dist1, bottomk_indices1 = bottomk_low_prec(cdist_dist, k) 112 | 113 | # bottomk_dist_sorted, bottomk_indices_argsort = 
torch.sort(bottomk_dist1, dim=1) 114 | # bottomk_indices_sorted = bottomk_indices1.gather(1, bottomk_indices_argsort) 115 | print() 116 | print('bottomk is not sorted...') 117 | # we can only ensure the top k 118 | print(torch.sum((bottomk_dist[:, k - 1] != bottomk_dist1[:, k - 1]).int())) 119 | print( 120 | torch.sum((bottomk_indices[:, k - 1] != bottomk_indices1[:, k - 1]).int())) 121 | 122 | # we can only ensure the top k 123 | print(torch.sum((bottomk_dist != bottomk_dist1).int())) 124 | print(torch.sum((bottomk_indices != bottomk_indices1).int())) 125 | 126 | bottomk_dist2, bottomk_indices2 = bottomk_low_prec(cdist_dist, k, 127 | sort_value=True) 128 | print() 129 | print('bottomk is sorted...') 130 | # we ensure topk 131 | print(torch.sum((bottomk_dist[:, k - 1] != bottomk_dist2[:, k - 1]).int())) 132 | print( 133 | torch.sum((bottomk_indices[:, k - 1] != bottomk_indices2[:, k - 1]).int())) 134 | 135 | # we can ensure all 136 | print(torch.sum((bottomk_dist != bottomk_dist2).int())) 137 | print(torch.sum((bottomk_indices != bottomk_indices2).int())) 138 | 139 | # %% 140 | 141 | print() 142 | print('topk is not sorted...') 143 | 144 | topk_dist, topk_indices = topk(cdist_dist, k) 145 | topk_dist1, topk_indices1 = topk_low_prec(cdist_dist, k) 146 | 147 | # we can only ensure the top k 148 | print(torch.sum((topk_dist[:, k - 1] != topk_dist1[:, k - 1]).int())) 149 | print(torch.sum((topk_indices[:, k - 1] != topk_indices1[:, k - 1]).int())) 150 | 151 | print(torch.sum((topk_dist != topk_dist1).int())) 152 | print(torch.sum((topk_indices != topk_indices1).int())) 153 | 154 | topk_dist2, topk_indices2 = topk_low_prec(cdist_dist, k, sort_value=True) 155 | print() 156 | print('topk is sorted...') 157 | print(torch.sum((topk_dist[:, k - 1] != topk_dist2[:, k - 1]).int())) 158 | print(torch.sum((topk_indices[:, k - 1] != topk_indices2[:, k - 1]).int())) 159 | 160 | print(torch.sum((topk_dist != topk_dist2).int())) 161 | print(torch.sum((topk_indices != topk_indices2).int())) 162 | 163 | # here we flip the order 164 | decision_scores = bottomk_dist[:, -1] 165 | 166 | evaluate_print('knn', y, decision_scores.cpu()) 167 | 168 | # #%% 169 | # from basic_operators import topk, intersec1d 170 | # from pytorch_memlab import LineProfiler 171 | # from pytorch_memlab import MemReporter 172 | # import time 173 | 174 | 175 | # # t1 = torch.randint(low=0, high=20000000, size=[20000000]) 176 | # # t2 = torch.randint(low=5000000, high=25000000, size=[20000000]) 177 | 178 | # t1 = torch.rand(size=[50000000]) 179 | # t2 = torch.rand(size=[50000000]) 180 | 181 | 182 | # t1, t2 = t1.half().cuda(), t2.half().cuda() 183 | # # t1, t2 = t1.float().cuda(), t2.float().cuda() 184 | # # t1, t2 = t1.double().cuda(), t2.double().cuda() 185 | 186 | # def w(A, B): 187 | # return intersec1d(A, B) 188 | 189 | # with LineProfiler(w) as prof: 190 | # # distance_mat = batch_cdist(X_train_norm, X_train_norm, batch_size=5000) 191 | # start = time.time() 192 | # a = w(t1, t2) 193 | # end = time.time() 194 | # print(end - start) 195 | 196 | # print(prof.display()) 197 | 198 | 199 | # #%% 200 | # from basic_operators import topk 201 | # from pytorch_memlab import LineProfiler 202 | # from pytorch_memlab import MemReporter 203 | # import time 204 | 205 | # def Standardizer(X_train, mean=None, std=None, return_mean_std=False): 206 | 207 | # if mean is None: 208 | # mean = torch.mean(X_train, axis=0) 209 | # std = torch.std(X_train, axis=0) 210 | # # print(mean.shape, std.shape) 211 | # assert (mean.shape[0] == X_train.shape[1]) 212 | 
# assert (std.shape[0] == X_train.shape[1]) 213 | 214 | 215 | # X_train_norm = (X_train-mean)/std 216 | # assert(X_train_norm.shape == X_train.shape) 217 | 218 | # if return_mean_std: 219 | # return X_train_norm, mean, std 220 | # else: 221 | # return X_train_norm 222 | 223 | # contamination = 0.1 # percentage of outliers 224 | # n_train = 200000 # number of training points 225 | # n_test = 1000 # number of testing points 226 | # n_features = 2000 227 | 228 | # # Generate sample data 229 | # X_train, y_train, X_test, y_test = \ 230 | # generate_data(n_train=n_train, 231 | # n_test=n_test, 232 | # n_features=n_features, 233 | # contamination=contamination, 234 | # random_state=42) 235 | 236 | # k = 5 237 | 238 | 239 | # X_train = torch.tensor(X_train) 240 | # X_test = torch.tensor(X_test) 241 | 242 | # # X_train_norm, X_train_mean, X_train_std = Standardizer(X_train, return_mean_std=True) 243 | # # X_test_norm = Standardizer(X_test, mean=X_train_mean, std=X_train_std) 244 | 245 | 246 | # # X_train_norm = X_train.half().cuda() 247 | # # X_train_norm = X_train.float().cuda() 248 | # X_train_norm = X_train.double().cuda() 249 | # print(X_train_norm.type()) 250 | 251 | 252 | # def w(A, k): 253 | # return torch.topk(A, k) 254 | 255 | # with LineProfiler(w) as prof: 256 | # # distance_mat = batch_cdist(X_train_norm, X_train_norm, batch_size=5000) 257 | # start = time.time() 258 | # a,b = w(X_train_norm, k) 259 | # end = time.time() 260 | # print(end - start) 261 | 262 | # print(prof.display()) 263 | 264 | # %% 265 | -------------------------------------------------------------------------------- /pytod/models/sklearn_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Utility function copied over from sklearn/base.py 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | import numpy as np 11 | import six 12 | from joblib.parallel import cpu_count 13 | 14 | 15 | def _get_n_jobs(n_jobs): 16 | """Get number of jobs for the computation. 17 | See sklearn/utils/__init__.py for more information. 18 | 19 | This function reimplements the logic of joblib to determine the actual 20 | number of jobs depending on the cpu count. If -1 all CPUs are used. 21 | If 1 is given, no parallel computing code is used at all, which is useful 22 | for debugging. For n_jobs below -1, (n_cpus + 1 + n_jobs) are used. 23 | Thus for n_jobs = -2, all CPUs but one are used. 24 | Parameters 25 | ---------- 26 | n_jobs : int 27 | Number of jobs stated in joblib convention. 28 | Returns 29 | ------- 30 | n_jobs : int 31 | The actual number of jobs as positive integer. 32 | """ 33 | if n_jobs < 0: 34 | return max(cpu_count() + 1 + n_jobs, 1) 35 | elif n_jobs == 0: 36 | raise ValueError('Parameter n_jobs == 0 has no meaning.') 37 | else: 38 | return n_jobs 39 | 40 | 41 | def _partition_estimators(n_estimators, n_jobs): 42 | """Private function used to partition estimators between jobs. 43 | See sklearn/ensemble/base.py for more information. 
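
    Examples
    --------
    For instance, splitting 10 estimators across 3 jobs balances the
    remainder onto the first job::

        >>> _partition_estimators(10, n_jobs=3)
        (3, [4, 3, 3], [0, 4, 7, 10])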
44 | """ 45 | # Compute the number of jobs 46 | n_jobs = min(_get_n_jobs(n_jobs), n_estimators) 47 | 48 | # Partition estimators between jobs 49 | n_estimators_per_job = (n_estimators // n_jobs) * np.ones(n_jobs, 50 | dtype=np.int) 51 | n_estimators_per_job[:n_estimators % n_jobs] += 1 52 | starts = np.cumsum(n_estimators_per_job) 53 | 54 | return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist() 55 | 56 | 57 | def _pprint(params, offset=0, printer=repr): 58 | # noinspection PyPep8 59 | """Pretty print the dictionary 'params' 60 | 61 | See http://scikit-learn.org/stable/modules/generated/sklearn.base.BaseEstimator.html 62 | and sklearn/base.py for more information. 63 | 64 | :param params: The dictionary to pretty print 65 | :type params: dict 66 | 67 | :param offset: The offset in characters to add at the begin of each line. 68 | :type offset: int 69 | 70 | :param printer: The function to convert entries to strings, typically 71 | the builtin str or repr 72 | :type printer: callable 73 | 74 | :return: None 75 | """ 76 | 77 | # Do a multi-line justified repr: 78 | options = np.get_printoptions() 79 | np.set_printoptions(precision=5, threshold=64, edgeitems=2) 80 | params_list = list() 81 | this_line_length = offset 82 | line_sep = ',\n' + (1 + offset // 2) * ' ' 83 | for i, (k, v) in enumerate(sorted(six.iteritems(params))): 84 | if type(v) is float: 85 | # use str for representing floating point numbers 86 | # this way we get consistent representation across 87 | # architectures and versions. 88 | this_repr = '%s=%s' % (k, str(v)) 89 | else: 90 | # use repr of the rest 91 | this_repr = '%s=%s' % (k, printer(v)) 92 | if len(this_repr) > 500: 93 | this_repr = this_repr[:300] + '...' + this_repr[-100:] 94 | if i > 0: 95 | if this_line_length + len(this_repr) >= 75 or '\n' in this_repr: 96 | params_list.append(line_sep) 97 | this_line_length = len(line_sep) 98 | else: 99 | params_list.append(', ') 100 | this_line_length += 2 101 | params_list.append(this_repr) 102 | this_line_length += len(this_repr) 103 | 104 | np.set_printoptions(**options) 105 | lines = ''.join(params_list) 106 | # Strip trailing space to avoid nightmare in doctests 107 | lines = '\n'.join(l.rstrip(' ') for l in lines.split('\n')) 108 | return lines 109 | -------------------------------------------------------------------------------- /pytod/test/test_abod.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import unittest 9 | 10 | # noinspection PyProtectedMember 11 | from numpy.testing import assert_equal 12 | 13 | # temporary solution for relative imports in case pyod is not installed 14 | # if pyod is installed, no need to use the following line 15 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 16 | 17 | from pytod.models.abod import ABOD 18 | from pytod.utils.data import generate_data 19 | from pytod.utils.utility import validate_device 20 | 21 | 22 | class TestABOD(unittest.TestCase): 23 | def setUp(self): 24 | self.n_train = 200 25 | self.n_test = 100 26 | self.contamination = 0.1 27 | self.roc_floor = 0.8 28 | self.X_train, self.y_train, self.X_test, self.y_test = generate_data( 29 | n_train=self.n_train, n_test=self.n_test, 30 | contamination=self.contamination, random_state=42) 31 | 32 | device = 'cpu' 33 | self.clf = ABOD(contamination=self.contamination, device=device) 34 | 
self.clf.fit(self.X_train) 35 | 36 | device = validate_device(0) 37 | self.clf = ABOD(contamination=self.contamination, device=device) 38 | self.clf.fit(self.X_train) 39 | 40 | def test_parameters(self): 41 | assert (hasattr(self.clf, 'decision_scores_') and 42 | self.clf.decision_scores_ is not None) 43 | assert (hasattr(self.clf, 'labels_') and 44 | self.clf.labels_ is not None) 45 | assert (hasattr(self.clf, 'threshold_') and 46 | self.clf.threshold_ is not None) 47 | assert (hasattr(self.clf, '_mu') and 48 | self.clf._mu is not None) 49 | assert (hasattr(self.clf, '_sigma') and 50 | self.clf._sigma is not None) 51 | 52 | def test_train_scores(self): 53 | assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) 54 | 55 | # def test_prediction_scores(self): 56 | # pred_scores = self.clf.decision_function(self.X_test) 57 | # 58 | # # check score shapes 59 | # assert_equal(pred_scores.shape[0], self.X_test.shape[0]) 60 | # 61 | # # check performance 62 | # assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor) 63 | # 64 | # def test_prediction_labels(self): 65 | # pred_labels = self.clf.predict(self.X_test) 66 | # assert_equal(pred_labels.shape, self.y_test.shape) 67 | # 68 | # def test_prediction_proba(self): 69 | # pred_proba = self.clf.predict_proba(self.X_test) 70 | # assert (pred_proba.min() >= 0) 71 | # assert (pred_proba.max() <= 1) 72 | # 73 | # def test_prediction_proba_linear(self): 74 | # pred_proba = self.clf.predict_proba(self.X_test, method='linear') 75 | # assert (pred_proba.min() >= 0) 76 | # assert (pred_proba.max() <= 1) 77 | # 78 | # def test_prediction_proba_unify(self): 79 | # pred_proba = self.clf.predict_proba(self.X_test, method='unify') 80 | # assert (pred_proba.min() >= 0) 81 | # assert (pred_proba.max() <= 1) 82 | # 83 | # def test_prediction_proba_parameter(self): 84 | # with assert_raises(ValueError): 85 | # self.clf.predict_proba(self.X_test, method='something') 86 | # 87 | # def test_prediction_labels_confidence(self): 88 | # pred_labels, confidence = self.clf.predict(self.X_test, 89 | # return_confidence=True) 90 | # assert_equal(pred_labels.shape, self.y_test.shape) 91 | # assert_equal(confidence.shape, self.y_test.shape) 92 | # assert (confidence.min() >= 0) 93 | # assert (confidence.max() <= 1) 94 | # 95 | # def test_prediction_proba_linear_confidence(self): 96 | # pred_proba, confidence = self.clf.predict_proba(self.X_test, 97 | # method='linear', 98 | # return_confidence=True) 99 | # assert (pred_proba.min() >= 0) 100 | # assert (pred_proba.max() <= 1) 101 | # 102 | # assert_equal(confidence.shape, self.y_test.shape) 103 | # assert (confidence.min() >= 0) 104 | # assert (confidence.max() <= 1) 105 | # 106 | # def test_fit_predict(self): 107 | # pred_labels = self.clf.fit_predict(self.X_train) 108 | # assert_equal(pred_labels.shape, self.y_train.shape) 109 | # 110 | # def test_fit_predict_score(self): 111 | # self.clf.fit_predict_score(self.X_test, self.y_test) 112 | # self.clf.fit_predict_score(self.X_test, self.y_test, 113 | # scoring='roc_auc_score') 114 | # self.clf.fit_predict_score(self.X_test, self.y_test, 115 | # scoring='prc_n_score') 116 | # with assert_raises(NotImplementedError): 117 | # self.clf.fit_predict_score(self.X_test, self.y_test, 118 | # scoring='something') 119 | # 120 | # def test_predict_rank(self): 121 | # pred_socres = self.clf.decision_function(self.X_test) 122 | # pred_ranks = self.clf._predict_rank(self.X_test) 123 | # 124 | # # assert the order is reserved 125 | # assert_allclose(rankdata(pred_ranks), 
rankdata(pred_socres), atol=2) 126 | # assert_array_less(pred_ranks, self.X_train.shape[0] + 1) 127 | # assert_array_less(-0.1, pred_ranks) 128 | # 129 | # def test_predict_rank_normalized(self): 130 | # pred_socres = self.clf.decision_function(self.X_test) 131 | # pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) 132 | # 133 | # # assert the order is reserved 134 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 135 | # assert_array_less(pred_ranks, 1.01) 136 | # assert_array_less(-0.1, pred_ranks) 137 | # 138 | # def test_model_clone(self): 139 | # clone_clf = clone(self.clf) 140 | 141 | def tearDown(self): 142 | pass 143 | 144 | 145 | if __name__ == '__main__': 146 | unittest.main() 147 | -------------------------------------------------------------------------------- /pytod/test/test_base.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import unittest 8 | 9 | import numpy as np 10 | from numpy.testing import assert_equal 11 | from numpy.testing import assert_raises 12 | 13 | # temporary solution for relative imports in case pytod is not installed 14 | # if pytod is installed, no need to use the following line 15 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 16 | 17 | from pytod.models.base import BaseDetector 18 | 19 | 20 | # Check sklearn\tests\test_base 21 | # A few test classes 22 | # noinspection PyMissingConstructor,PyPep8Naming 23 | class MyEstimator(BaseDetector): 24 | 25 | def __init__(self, l1=0, empty=None): 26 | self.l1 = l1 27 | self.empty = empty 28 | 29 | def fit(self, X, y=None): 30 | pass 31 | 32 | def decision_function(self, X): 33 | pass 34 | 35 | 36 | # noinspection PyMissingConstructor 37 | class K(BaseDetector): 38 | def __init__(self, c=None, d=None): 39 | self.c = c 40 | self.d = d 41 | 42 | def fit(self, X, y=None): 43 | pass 44 | 45 | def decision_function(self, X): 46 | pass 47 | 48 | 49 | # noinspection PyMissingConstructor 50 | class T(BaseDetector): 51 | def __init__(self, a=None, b=None): 52 | self.a = a 53 | self.b = b 54 | 55 | def fit(self, X, y=None): 56 | pass 57 | 58 | def decision_function(self, X): 59 | pass 60 | 61 | 62 | # noinspection PyMissingConstructor 63 | class ModifyInitParams(BaseDetector): 64 | """Deprecated behavior. 65 | Equal parameters but with a type cast. 
66 | Doesn't fulfill a is a 67 | """ 68 | 69 | def __init__(self, a=np.array([0])): 70 | self.a = a.copy() 71 | 72 | def fit(self, X, y=None): 73 | pass 74 | 75 | def decision_function(self, X): 76 | pass 77 | 78 | 79 | # noinspection PyMissingConstructor 80 | class VargEstimator(BaseDetector): 81 | """scikit-learn estimators shouldn't have vargs.""" 82 | 83 | def __init__(self, *vargs): 84 | pass 85 | 86 | def fit(self, X, y=None): 87 | pass 88 | 89 | def decision_function(self, X): 90 | pass 91 | 92 | 93 | class Dummy1(BaseDetector): 94 | def __init__(self, contamination=0.1): 95 | super(Dummy1, self).__init__(contamination=contamination) 96 | 97 | def decision_function(self, X): 98 | pass 99 | 100 | def fit(self, X, y=None): 101 | pass 102 | 103 | 104 | class Dummy2(BaseDetector): 105 | def __init__(self, contamination=0.1): 106 | super(Dummy2, self).__init__(contamination=contamination) 107 | 108 | def decision_function(self, X): 109 | pass 110 | 111 | def fit(self, X, y=None): 112 | return X 113 | 114 | 115 | class Dummy3(BaseDetector): 116 | def __init__(self, contamination=0.1): 117 | super(Dummy3, self).__init__(contamination=contamination) 118 | 119 | def decision_function(self, X): 120 | pass 121 | 122 | def fit(self, X, y=None): 123 | self.labels_ = X 124 | 125 | 126 | class TestBASE(unittest.TestCase): 127 | def setUp(self): 128 | self.n_train = 100 129 | self.n_test = 50 130 | self.contamination = 0.1 131 | self.roc_floor = 0.6 132 | 133 | def test_init(self): 134 | """ 135 | Test base class initialization 136 | 137 | :return: 138 | """ 139 | self.dummy_clf = Dummy1() 140 | assert_equal(self.dummy_clf.contamination, 0.1) 141 | 142 | self.dummy_clf = Dummy1(contamination=0.2) 143 | assert_equal(self.dummy_clf.contamination, 0.2) 144 | 145 | with assert_raises(ValueError): 146 | Dummy1(contamination=0.51) 147 | 148 | with assert_raises(ValueError): 149 | Dummy1(contamination=0) 150 | 151 | with assert_raises(ValueError): 152 | Dummy1(contamination=-0.5) 153 | 154 | def test_fit(self): 155 | self.dummy_clf = Dummy2() 156 | assert_equal(self.dummy_clf.fit(0), 0) 157 | 158 | def test_fit_predict(self): 159 | # TODO: add more testcases 160 | 161 | self.dummy_clf = Dummy3() 162 | 163 | assert_equal(self.dummy_clf.fit_predict(0), 0) 164 | 165 | def test_predict_proba(self): 166 | # TODO: create uniform testcases 167 | pass 168 | 169 | def test_predict_confidence(self): 170 | # TODO: create uniform testcases 171 | pass 172 | 173 | def test_rank(self): 174 | # TODO: create uniform testcases 175 | pass 176 | 177 | def test_repr(self): 178 | # Smoke test the repr of the base estimator. 
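        # the exact length check on some_est below relies on _pprint
        # truncating any parameter repr longer than 500 characters to a
        # 300-character head plus a 100-character tail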
179 | my_estimator = MyEstimator() 180 | repr(my_estimator) 181 | test = T(K(), K()) 182 | assert_equal( 183 | repr(test), 184 | "T(a=K(c=None, d=None), b=K(c=None, d=None))" 185 | ) 186 | 187 | some_est = T(a=["long_params"] * 1000) 188 | assert_equal(len(repr(some_est)), 415) 189 | 190 | def test_str(self): 191 | # Smoke test the str of the base estimator 192 | my_estimator = MyEstimator() 193 | str(my_estimator) 194 | 195 | def test_get_params(self): 196 | test = T(K(), K()) 197 | 198 | assert ('a__d' in test.get_params(deep=True)) 199 | assert ('a__d' not in test.get_params(deep=False)) 200 | 201 | test.set_params(a__d=2) 202 | assert (test.a.d == 2) 203 | assert_raises(ValueError, test.set_params, a__a=2) 204 | 205 | def tearDown(self): 206 | pass 207 | 208 | 209 | if __name__ == '__main__': 210 | unittest.main() 211 | -------------------------------------------------------------------------------- /pytod/test/test_basic_operators.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os 6 | import sys 7 | import unittest 8 | 9 | import numpy as np 10 | import torch 11 | from numpy.testing import assert_equal 12 | from numpy.testing import assert_raises 13 | 14 | # temporary solution for relative imports in case pytod is not installed 15 | # if pytod is installed, no need to use the following line 16 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 17 | 18 | from pytod.utils.data import generate_data 19 | from pytod.models.basic_operators import cdist 20 | from pytod.models.basic_operators import topk 21 | from pytod.utils.utility import validate_device 22 | 23 | 24 | class TestCDIST(unittest.TestCase): 25 | 26 | def setUp(self): 27 | self.X = torch.Tensor([[1, 1], [2, 2], [3, 3]]) 28 | self.device = validate_device(0) 29 | 30 | def test_calc(self): 31 | dist = cdist(self.X, self.X, p=2, device=self.device) 32 | assert (dist.shape[0] - dist.shape[1] == 0) 33 | assert (torch.diagonal(dist).sum() == 0) 34 | 35 | class TestTOPK(unittest.TestCase): 36 | 37 | def setUp(self): 38 | self.X = torch.Tensor([[1, 1], [2, 2], [3, 3]]) 39 | self.device = validate_device(0) 40 | self.dist = cdist(self.X, self.X, p=2, device=self.device) 41 | 42 | def test_calc(self): 43 | topk_val, topk_ind = topk(self.dist, k=1, device=self.device) 44 | # print(topk_ind) 45 | # print(topk_ind.cpu().numpy().tolist()) 46 | assert (topk_ind.cpu().numpy().tolist() == [[2], [0], [0]]) 47 | # print(topk_val) 48 | # print(np.round(topk_val.cpu().numpy(), decimals=4).tolist()) 49 | assert (np.round(topk_val.cpu().numpy(), decimals=4).tolist() == [[2.828399896621704], [1.414199948310852], [2.828399896621704]]) 50 | 51 | -------------------------------------------------------------------------------- /pytod/test/test_ecod.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import unittest 9 | 10 | # noinspection PyProtectedMember 11 | from numpy.testing import assert_equal 12 | 13 | # temporary solution for relative imports in case pyod is not installed 14 | # if pyod is installed, no need to use the following line 15 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 16 | 17 | from pytod.models.ecod import ECOD 18 | from 
pytod.utils.data import generate_data 19 | from pytod.utils.utility import validate_device 20 | 21 | 22 | class TestECOD(unittest.TestCase): 23 | def setUp(self): 24 | self.n_train = 200 25 | self.n_test = 100 26 | self.contamination = 0.1 27 | self.roc_floor = 0.8 28 | self.X_train, self.y_train, self.X_test, self.y_test = generate_data( 29 | n_train=self.n_train, n_test=self.n_test, 30 | contamination=self.contamination, random_state=42) 31 | 32 | device = 'cpu' 33 | self.clf = ECOD(contamination=self.contamination, device=device) 34 | self.clf.fit(self.X_train) 35 | 36 | device = validate_device(0) 37 | self.clf = ECOD(contamination=self.contamination, device=device) 38 | self.clf.fit(self.X_train) 39 | 40 | def test_parameters(self): 41 | assert (hasattr(self.clf, 'decision_scores_') and 42 | self.clf.decision_scores_ is not None) 43 | assert (hasattr(self.clf, 'labels_') and 44 | self.clf.labels_ is not None) 45 | assert (hasattr(self.clf, 'threshold_') and 46 | self.clf.threshold_ is not None) 47 | assert (hasattr(self.clf, '_mu') and 48 | self.clf._mu is not None) 49 | assert (hasattr(self.clf, '_sigma') and 50 | self.clf._sigma is not None) 51 | 52 | def test_train_scores(self): 53 | assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) 54 | 55 | # def test_prediction_scores(self): 56 | # pred_scores = self.clf.decision_function(self.X_test) 57 | # 58 | # # check score shapes 59 | # assert_equal(pred_scores.shape[0], self.X_test.shape[0]) 60 | # 61 | # # check performance 62 | # assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor) 63 | # 64 | # def test_prediction_labels(self): 65 | # pred_labels = self.clf.predict(self.X_test) 66 | # assert_equal(pred_labels.shape, self.y_test.shape) 67 | # 68 | # def test_prediction_proba(self): 69 | # pred_proba = self.clf.predict_proba(self.X_test) 70 | # assert (pred_proba.min() >= 0) 71 | # assert (pred_proba.max() <= 1) 72 | # 73 | # def test_prediction_proba_linear(self): 74 | # pred_proba = self.clf.predict_proba(self.X_test, method='linear') 75 | # assert (pred_proba.min() >= 0) 76 | # assert (pred_proba.max() <= 1) 77 | # 78 | # def test_prediction_proba_unify(self): 79 | # pred_proba = self.clf.predict_proba(self.X_test, method='unify') 80 | # assert (pred_proba.min() >= 0) 81 | # assert (pred_proba.max() <= 1) 82 | # 83 | # def test_prediction_proba_parameter(self): 84 | # with assert_raises(ValueError): 85 | # self.clf.predict_proba(self.X_test, method='something') 86 | # 87 | # def test_prediction_labels_confidence(self): 88 | # pred_labels, confidence = self.clf.predict(self.X_test, 89 | # return_confidence=True) 90 | # assert_equal(pred_labels.shape, self.y_test.shape) 91 | # assert_equal(confidence.shape, self.y_test.shape) 92 | # assert (confidence.min() >= 0) 93 | # assert (confidence.max() <= 1) 94 | # 95 | # def test_prediction_proba_linear_confidence(self): 96 | # pred_proba, confidence = self.clf.predict_proba(self.X_test, 97 | # method='linear', 98 | # return_confidence=True) 99 | # assert (pred_proba.min() >= 0) 100 | # assert (pred_proba.max() <= 1) 101 | # 102 | # assert_equal(confidence.shape, self.y_test.shape) 103 | # assert (confidence.min() >= 0) 104 | # assert (confidence.max() <= 1) 105 | # 106 | # def test_fit_predict(self): 107 | # pred_labels = self.clf.fit_predict(self.X_train) 108 | # assert_equal(pred_labels.shape, self.y_train.shape) 109 | # 110 | # def test_fit_predict_score(self): 111 | # self.clf.fit_predict_score(self.X_test, self.y_test) 112 | # 
self.clf.fit_predict_score(self.X_test, self.y_test, 113 | # scoring='roc_auc_score') 114 | # self.clf.fit_predict_score(self.X_test, self.y_test, 115 | # scoring='prc_n_score') 116 | # with assert_raises(NotImplementedError): 117 | # self.clf.fit_predict_score(self.X_test, self.y_test, 118 | # scoring='something') 119 | # 120 | # def test_predict_rank(self): 121 | # pred_socres = self.clf.decision_function(self.X_test) 122 | # pred_ranks = self.clf._predict_rank(self.X_test) 123 | # 124 | # # assert the order is reserved 125 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 126 | # assert_array_less(pred_ranks, self.X_train.shape[0] + 1) 127 | # assert_array_less(-0.1, pred_ranks) 128 | # 129 | # def test_predict_rank_normalized(self): 130 | # pred_socres = self.clf.decision_function(self.X_test) 131 | # pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) 132 | # 133 | # # assert the order is reserved 134 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 135 | # assert_array_less(pred_ranks, 1.01) 136 | # assert_array_less(-0.1, pred_ranks) 137 | # 138 | # def test_model_clone(self): 139 | # clone_clf = clone(self.clf) 140 | 141 | def tearDown(self): 142 | pass 143 | 144 | 145 | if __name__ == '__main__': 146 | unittest.main() 147 | -------------------------------------------------------------------------------- /pytod/test/test_hbos.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import unittest 9 | 10 | # noinspection PyProtectedMember 11 | from numpy.testing import assert_equal 12 | 13 | # temporary solution for relative imports in case pyod is not installed 14 | # if pyod is installed, no need to use the following line 15 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 16 | 17 | from pytod.models.hbos import HBOS 18 | from pytod.utils.data import generate_data 19 | from pytod.utils.utility import validate_device 20 | 21 | class TestHBOS(unittest.TestCase): 22 | def setUp(self): 23 | self.n_train = 200 24 | self.n_test = 100 25 | self.contamination = 0.1 26 | self.roc_floor = 0.8 27 | self.X_train, self.y_train, self.X_test, self.y_test = generate_data( 28 | n_train=self.n_train, n_test=self.n_test, 29 | contamination=self.contamination, random_state=42) 30 | 31 | device = 'cpu' 32 | self.clf = HBOS(contamination=self.contamination, device=device) 33 | self.clf.fit(self.X_train) 34 | 35 | device = validate_device(0) 36 | self.clf = HBOS(contamination=self.contamination, device=device) 37 | self.clf.fit(self.X_train) 38 | 39 | 40 | def test_parameters(self): 41 | assert (hasattr(self.clf, 'decision_scores_') and 42 | self.clf.decision_scores_ is not None) 43 | assert (hasattr(self.clf, 'labels_') and 44 | self.clf.labels_ is not None) 45 | assert (hasattr(self.clf, 'threshold_') and 46 | self.clf.threshold_ is not None) 47 | assert (hasattr(self.clf, '_mu') and 48 | self.clf._mu is not None) 49 | assert (hasattr(self.clf, '_sigma') and 50 | self.clf._sigma is not None) 51 | 52 | def test_train_scores(self): 53 | assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) 54 | 55 | # def test_prediction_scores(self): 56 | # pred_scores = self.clf.decision_function(self.X_test) 57 | # 58 | # # check score shapes 59 | # assert_equal(pred_scores.shape[0], self.X_test.shape[0]) 60 | # 61 | # # check performance 
62 | # assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor) 63 | # 64 | # def test_prediction_labels(self): 65 | # pred_labels = self.clf.predict(self.X_test) 66 | # assert_equal(pred_labels.shape, self.y_test.shape) 67 | # 68 | # def test_prediction_proba(self): 69 | # pred_proba = self.clf.predict_proba(self.X_test) 70 | # assert (pred_proba.min() >= 0) 71 | # assert (pred_proba.max() <= 1) 72 | # 73 | # def test_prediction_proba_linear(self): 74 | # pred_proba = self.clf.predict_proba(self.X_test, method='linear') 75 | # assert (pred_proba.min() >= 0) 76 | # assert (pred_proba.max() <= 1) 77 | # 78 | # def test_prediction_proba_unify(self): 79 | # pred_proba = self.clf.predict_proba(self.X_test, method='unify') 80 | # assert (pred_proba.min() >= 0) 81 | # assert (pred_proba.max() <= 1) 82 | # 83 | # def test_prediction_proba_parameter(self): 84 | # with assert_raises(ValueError): 85 | # self.clf.predict_proba(self.X_test, method='something') 86 | # 87 | # def test_prediction_labels_confidence(self): 88 | # pred_labels, confidence = self.clf.predict(self.X_test, 89 | # return_confidence=True) 90 | # assert_equal(pred_labels.shape, self.y_test.shape) 91 | # assert_equal(confidence.shape, self.y_test.shape) 92 | # assert (confidence.min() >= 0) 93 | # assert (confidence.max() <= 1) 94 | # 95 | # def test_prediction_proba_linear_confidence(self): 96 | # pred_proba, confidence = self.clf.predict_proba(self.X_test, 97 | # method='linear', 98 | # return_confidence=True) 99 | # assert (pred_proba.min() >= 0) 100 | # assert (pred_proba.max() <= 1) 101 | # 102 | # assert_equal(confidence.shape, self.y_test.shape) 103 | # assert (confidence.min() >= 0) 104 | # assert (confidence.max() <= 1) 105 | # 106 | # def test_fit_predict(self): 107 | # pred_labels = self.clf.fit_predict(self.X_train) 108 | # assert_equal(pred_labels.shape, self.y_train.shape) 109 | # 110 | # def test_fit_predict_score(self): 111 | # self.clf.fit_predict_score(self.X_test, self.y_test) 112 | # self.clf.fit_predict_score(self.X_test, self.y_test, 113 | # scoring='roc_auc_score') 114 | # self.clf.fit_predict_score(self.X_test, self.y_test, 115 | # scoring='prc_n_score') 116 | # with assert_raises(NotImplementedError): 117 | # self.clf.fit_predict_score(self.X_test, self.y_test, 118 | # scoring='something') 119 | # 120 | # def test_predict_rank(self): 121 | # pred_socres = self.clf.decision_function(self.X_test) 122 | # pred_ranks = self.clf._predict_rank(self.X_test) 123 | # 124 | # # assert the order is reserved 125 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 126 | # assert_array_less(pred_ranks, self.X_train.shape[0] + 1) 127 | # assert_array_less(-0.1, pred_ranks) 128 | # 129 | # def test_predict_rank_normalized(self): 130 | # pred_socres = self.clf.decision_function(self.X_test) 131 | # pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) 132 | # 133 | # # assert the order is reserved 134 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 135 | # assert_array_less(pred_ranks, 1.01) 136 | # assert_array_less(-0.1, pred_ranks) 137 | # 138 | # def test_model_clone(self): 139 | # clone_clf = clone(self.clf) 140 | 141 | def tearDown(self): 142 | pass 143 | 144 | 145 | if __name__ == '__main__': 146 | unittest.main() 147 | -------------------------------------------------------------------------------- /pytod/test/test_knn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from 
__future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import unittest 9 | 10 | # noinspection PyProtectedMember 11 | from numpy.testing import assert_equal 12 | 13 | # temporary solution for relative imports in case pyod is not installed 14 | # if pyod is installed, no need to use the following line 15 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 16 | 17 | from pytod.models.knn import KNN 18 | from pytod.utils.data import generate_data 19 | from pytod.utils.utility import validate_device 20 | 21 | class TestKnn(unittest.TestCase): 22 | def setUp(self): 23 | self.n_train = 200 24 | self.n_test = 100 25 | self.contamination = 0.1 26 | self.roc_floor = 0.8 27 | self.X_train, self.y_train, self.X_test, self.y_test = generate_data( 28 | n_train=self.n_train, n_test=self.n_test, 29 | contamination=self.contamination, random_state=42) 30 | 31 | device = 'cpu' 32 | self.clf = KNN(contamination=self.contamination, device=device) 33 | self.clf.fit(self.X_train) 34 | 35 | device = validate_device(0) 36 | self.clf = KNN(contamination=self.contamination, device=device) 37 | self.clf.fit(self.X_train) 38 | 39 | 40 | def test_parameters(self): 41 | assert (hasattr(self.clf, 'decision_scores_') and 42 | self.clf.decision_scores_ is not None) 43 | assert (hasattr(self.clf, 'labels_') and 44 | self.clf.labels_ is not None) 45 | assert (hasattr(self.clf, 'threshold_') and 46 | self.clf.threshold_ is not None) 47 | assert (hasattr(self.clf, '_mu') and 48 | self.clf._mu is not None) 49 | assert (hasattr(self.clf, '_sigma') and 50 | self.clf._sigma is not None) 51 | 52 | def test_train_scores(self): 53 | assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) 54 | 55 | # def test_prediction_scores(self): 56 | # pred_scores = self.clf.decision_function(self.X_test) 57 | # 58 | # # check score shapes 59 | # assert_equal(pred_scores.shape[0], self.X_test.shape[0]) 60 | # 61 | # # check performance 62 | # assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor) 63 | # 64 | # def test_prediction_labels(self): 65 | # pred_labels = self.clf.predict(self.X_test) 66 | # assert_equal(pred_labels.shape, self.y_test.shape) 67 | # 68 | # def test_prediction_proba(self): 69 | # pred_proba = self.clf.predict_proba(self.X_test) 70 | # assert (pred_proba.min() >= 0) 71 | # assert (pred_proba.max() <= 1) 72 | # 73 | # def test_prediction_proba_linear(self): 74 | # pred_proba = self.clf.predict_proba(self.X_test, method='linear') 75 | # assert (pred_proba.min() >= 0) 76 | # assert (pred_proba.max() <= 1) 77 | # 78 | # def test_prediction_proba_unify(self): 79 | # pred_proba = self.clf.predict_proba(self.X_test, method='unify') 80 | # assert (pred_proba.min() >= 0) 81 | # assert (pred_proba.max() <= 1) 82 | # 83 | # def test_prediction_proba_parameter(self): 84 | # with assert_raises(ValueError): 85 | # self.clf.predict_proba(self.X_test, method='something') 86 | # 87 | # def test_prediction_labels_confidence(self): 88 | # pred_labels, confidence = self.clf.predict(self.X_test, 89 | # return_confidence=True) 90 | # assert_equal(pred_labels.shape, self.y_test.shape) 91 | # assert_equal(confidence.shape, self.y_test.shape) 92 | # assert (confidence.min() >= 0) 93 | # assert (confidence.max() <= 1) 94 | # 95 | # def test_prediction_proba_linear_confidence(self): 96 | # pred_proba, confidence = self.clf.predict_proba(self.X_test, 97 | # method='linear', 98 | # return_confidence=True) 99 | # assert (pred_proba.min() >= 0) 100 
| # assert (pred_proba.max() <= 1) 101 | # 102 | # assert_equal(confidence.shape, self.y_test.shape) 103 | # assert (confidence.min() >= 0) 104 | # assert (confidence.max() <= 1) 105 | # 106 | # def test_fit_predict(self): 107 | # pred_labels = self.clf.fit_predict(self.X_train) 108 | # assert_equal(pred_labels.shape, self.y_train.shape) 109 | # 110 | # def test_fit_predict_score(self): 111 | # self.clf.fit_predict_score(self.X_test, self.y_test) 112 | # self.clf.fit_predict_score(self.X_test, self.y_test, 113 | # scoring='roc_auc_score') 114 | # self.clf.fit_predict_score(self.X_test, self.y_test, 115 | # scoring='prc_n_score') 116 | # with assert_raises(NotImplementedError): 117 | # self.clf.fit_predict_score(self.X_test, self.y_test, 118 | # scoring='something') 119 | # 120 | # def test_predict_rank(self): 121 | # pred_socres = self.clf.decision_function(self.X_test) 122 | # pred_ranks = self.clf._predict_rank(self.X_test) 123 | # 124 | # # assert the order is reserved 125 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 126 | # assert_array_less(pred_ranks, self.X_train.shape[0] + 1) 127 | # assert_array_less(-0.1, pred_ranks) 128 | # 129 | # def test_predict_rank_normalized(self): 130 | # pred_socres = self.clf.decision_function(self.X_test) 131 | # pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) 132 | # 133 | # # assert the order is reserved 134 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 135 | # assert_array_less(pred_ranks, 1.01) 136 | # assert_array_less(-0.1, pred_ranks) 137 | # 138 | # def test_model_clone(self): 139 | # clone_clf = clone(self.clf) 140 | 141 | def tearDown(self): 142 | pass 143 | 144 | 145 | if __name__ == '__main__': 146 | unittest.main() 147 | -------------------------------------------------------------------------------- /pytod/test/test_lof.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import unittest 9 | 10 | # noinspection PyProtectedMember 11 | from numpy.testing import assert_equal 12 | 13 | # temporary solution for relative imports in case pyod is not installed 14 | # if pyod is installed, no need to use the following line 15 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 16 | 17 | from pytod.models.lof import LOF 18 | from pytod.utils.data import generate_data 19 | from pytod.utils.utility import validate_device 20 | 21 | class TestLOF(unittest.TestCase): 22 | def setUp(self): 23 | self.n_train = 200 24 | self.n_test = 100 25 | self.contamination = 0.1 26 | self.roc_floor = 0.8 27 | self.X_train, self.y_train, self.X_test, self.y_test = generate_data( 28 | n_train=self.n_train, n_test=self.n_test, 29 | contamination=self.contamination, random_state=42) 30 | 31 | device = 'cpu' 32 | self.clf = LOF(contamination=self.contamination, device=device) 33 | self.clf.fit(self.X_train) 34 | 35 | device = validate_device(0) 36 | self.clf = LOF(contamination=self.contamination, device=device) 37 | self.clf.fit(self.X_train) 38 | 39 | 40 | def test_parameters(self): 41 | assert (hasattr(self.clf, 'decision_scores_') and 42 | self.clf.decision_scores_ is not None) 43 | assert (hasattr(self.clf, 'labels_') and 44 | self.clf.labels_ is not None) 45 | assert (hasattr(self.clf, 'threshold_') and 46 | self.clf.threshold_ is not None) 47 | assert (hasattr(self.clf, '_mu') and 48 | 
self.clf._mu is not None) 49 | assert (hasattr(self.clf, '_sigma') and 50 | self.clf._sigma is not None) 51 | 52 | def test_train_scores(self): 53 | assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) 54 | 55 | # def test_prediction_scores(self): 56 | # pred_scores = self.clf.decision_function(self.X_test) 57 | # 58 | # # check score shapes 59 | # assert_equal(pred_scores.shape[0], self.X_test.shape[0]) 60 | # 61 | # # check performance 62 | # assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor) 63 | # 64 | # def test_prediction_labels(self): 65 | # pred_labels = self.clf.predict(self.X_test) 66 | # assert_equal(pred_labels.shape, self.y_test.shape) 67 | # 68 | # def test_prediction_proba(self): 69 | # pred_proba = self.clf.predict_proba(self.X_test) 70 | # assert (pred_proba.min() >= 0) 71 | # assert (pred_proba.max() <= 1) 72 | # 73 | # def test_prediction_proba_linear(self): 74 | # pred_proba = self.clf.predict_proba(self.X_test, method='linear') 75 | # assert (pred_proba.min() >= 0) 76 | # assert (pred_proba.max() <= 1) 77 | # 78 | # def test_prediction_proba_unify(self): 79 | # pred_proba = self.clf.predict_proba(self.X_test, method='unify') 80 | # assert (pred_proba.min() >= 0) 81 | # assert (pred_proba.max() <= 1) 82 | # 83 | # def test_prediction_proba_parameter(self): 84 | # with assert_raises(ValueError): 85 | # self.clf.predict_proba(self.X_test, method='something') 86 | # 87 | # def test_prediction_labels_confidence(self): 88 | # pred_labels, confidence = self.clf.predict(self.X_test, 89 | # return_confidence=True) 90 | # assert_equal(pred_labels.shape, self.y_test.shape) 91 | # assert_equal(confidence.shape, self.y_test.shape) 92 | # assert (confidence.min() >= 0) 93 | # assert (confidence.max() <= 1) 94 | # 95 | # def test_prediction_proba_linear_confidence(self): 96 | # pred_proba, confidence = self.clf.predict_proba(self.X_test, 97 | # method='linear', 98 | # return_confidence=True) 99 | # assert (pred_proba.min() >= 0) 100 | # assert (pred_proba.max() <= 1) 101 | # 102 | # assert_equal(confidence.shape, self.y_test.shape) 103 | # assert (confidence.min() >= 0) 104 | # assert (confidence.max() <= 1) 105 | # 106 | # def test_fit_predict(self): 107 | # pred_labels = self.clf.fit_predict(self.X_train) 108 | # assert_equal(pred_labels.shape, self.y_train.shape) 109 | # 110 | # def test_fit_predict_score(self): 111 | # self.clf.fit_predict_score(self.X_test, self.y_test) 112 | # self.clf.fit_predict_score(self.X_test, self.y_test, 113 | # scoring='roc_auc_score') 114 | # self.clf.fit_predict_score(self.X_test, self.y_test, 115 | # scoring='prc_n_score') 116 | # with assert_raises(NotImplementedError): 117 | # self.clf.fit_predict_score(self.X_test, self.y_test, 118 | # scoring='something') 119 | # 120 | # def test_predict_rank(self): 121 | # pred_socres = self.clf.decision_function(self.X_test) 122 | # pred_ranks = self.clf._predict_rank(self.X_test) 123 | # 124 | # # assert the order is reserved 125 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 126 | # assert_array_less(pred_ranks, self.X_train.shape[0] + 1) 127 | # assert_array_less(-0.1, pred_ranks) 128 | # 129 | # def test_predict_rank_normalized(self): 130 | # pred_socres = self.clf.decision_function(self.X_test) 131 | # pred_ranks = self.clf._predict_rank(self.X_test, normalized=True) 132 | # 133 | # # assert the order is reserved 134 | # assert_allclose(rankdata(pred_ranks), rankdata(pred_socres), atol=2) 135 | # assert_array_less(pred_ranks, 1.01) 136 | # 
assert_array_less(-0.1, pred_ranks) 137 | # 138 | # def test_model_clone(self): 139 | # clone_clf = clone(self.clf) 140 | 141 | def tearDown(self): 142 | pass 143 | 144 | 145 | if __name__ == '__main__': 146 | unittest.main() 147 | -------------------------------------------------------------------------------- /pytod/test/test_pca.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os 7 | import sys 8 | import unittest 9 | 10 | # noinspection PyProtectedMember 11 | from numpy.testing import assert_equal 12 | 13 | # temporary solution for relative imports in case pyod is not installed 14 | # if pyod is installed, no need to use the following line 15 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 16 | 17 | from pytod.models.pca import PCA 18 | from pytod.utils.data import generate_data 19 | from pytod.utils.utility import validate_device 20 | 21 | 22 | class TestPCA(unittest.TestCase): 23 | def setUp(self): 24 | self.n_train = 2000 25 | self.n_test = 100 26 | self.contamination = 0.1 27 | self.roc_floor = 0.8 28 | self.X_train, self.y_train, self.X_test, self.y_test = generate_data( 29 | n_train=self.n_train, n_test=self.n_test, n_features=10, 30 | contamination=self.contamination, random_state=42) 31 | 32 | device = 'cpu' 33 | self.clf = PCA(contamination=self.contamination, device=device) 34 | self.clf.fit(self.X_train) 35 | 36 | device = validate_device(0) 37 | self.clf = PCA(contamination=self.contamination, device=device) 38 | self.clf.fit(self.X_train) 39 | 40 | def test_parameters(self): 41 | assert (hasattr(self.clf, 'decision_scores_') and 42 | self.clf.decision_scores_ is not None) 43 | assert (hasattr(self.clf, 'labels_') and 44 | self.clf.labels_ is not None) 45 | assert (hasattr(self.clf, 'threshold_') and 46 | self.clf.threshold_ is not None) 47 | assert (hasattr(self.clf, '_mu') and 48 | self.clf._mu is not None) 49 | assert (hasattr(self.clf, '_sigma') and 50 | self.clf._sigma is not None) 51 | 52 | def test_train_scores(self): 53 | assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0]) 54 | 55 | # def test_prediction_scores(self): 56 | # pred_scores = self.clf.decision_function(self.X_test) 57 | # 58 | # # check score shapes 59 | # assert_equal(pred_scores.shape[0], self.X_test.shape[0]) 60 | # 61 | # # check performance 62 | # assert (roc_auc_score(self.y_test, pred_scores) >= self.roc_floor) 63 | # 64 | # def test_prediction_labels(self): 65 | # pred_labels = self.clf.predict(self.X_test) 66 | # assert_equal(pred_labels.shape, self.y_test.shape) 67 | # 68 | # def test_prediction_proba(self): 69 | # pred_proba = self.clf.predict_proba(self.X_test) 70 | # assert (pred_proba.min() >= 0) 71 | # assert (pred_proba.max() <= 1) 72 | # 73 | # def test_prediction_proba_linear(self): 74 | # pred_proba = self.clf.predict_proba(self.X_test, method='linear') 75 | # assert (pred_proba.min() >= 0) 76 | # assert (pred_proba.max() <= 1) 77 | # 78 | # def test_prediction_proba_unify(self): 79 | # pred_proba = self.clf.predict_proba(self.X_test, method='unify') 80 | # assert (pred_proba.min() >= 0) 81 | # assert (pred_proba.max() <= 1) 82 | # 83 | # def test_prediction_proba_parameter(self): 84 | # with assert_raises(ValueError): 85 | # self.clf.predict_proba(self.X_test, method='something') 86 | # 87 | # def test_prediction_labels_confidence(self): 88 | # pred_labels, confidence = 
self.clf.predict(self.X_test,
    #                                          return_confidence=True)
    #     assert_equal(pred_labels.shape, self.y_test.shape)
    #     assert_equal(confidence.shape, self.y_test.shape)
    #     assert (confidence.min() >= 0)
    #     assert (confidence.max() <= 1)
    #
    # def test_prediction_proba_linear_confidence(self):
    #     pred_proba, confidence = self.clf.predict_proba(self.X_test,
    #                                                     method='linear',
    #                                                     return_confidence=True)
    #     assert (pred_proba.min() >= 0)
    #     assert (pred_proba.max() <= 1)
    #
    #     assert_equal(confidence.shape, self.y_test.shape)
    #     assert (confidence.min() >= 0)
    #     assert (confidence.max() <= 1)
    #
    # def test_fit_predict(self):
    #     pred_labels = self.clf.fit_predict(self.X_train)
    #     assert_equal(pred_labels.shape, self.y_train.shape)
    #
    # def test_fit_predict_score(self):
    #     self.clf.fit_predict_score(self.X_test, self.y_test)
    #     self.clf.fit_predict_score(self.X_test, self.y_test,
    #                                scoring='roc_auc_score')
    #     self.clf.fit_predict_score(self.X_test, self.y_test,
    #                                scoring='prc_n_score')
    #     with assert_raises(NotImplementedError):
    #         self.clf.fit_predict_score(self.X_test, self.y_test,
    #                                    scoring='something')
    #
    # def test_predict_rank(self):
    #     pred_scores = self.clf.decision_function(self.X_test)
    #     pred_ranks = self.clf._predict_rank(self.X_test)
    #
    #     # assert the order is preserved
    #     assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=2)
    #     assert_array_less(pred_ranks, self.X_train.shape[0] + 1)
    #     assert_array_less(-0.1, pred_ranks)
    #
    # def test_predict_rank_normalized(self):
    #     pred_scores = self.clf.decision_function(self.X_test)
    #     pred_ranks = self.clf._predict_rank(self.X_test, normalized=True)
    #
    #     # assert the order is preserved
    #     assert_allclose(rankdata(pred_ranks), rankdata(pred_scores), atol=2)
    #     assert_array_less(pred_ranks, 1.01)
    #     assert_array_less(-0.1, pred_ranks)
    #
    # def test_model_clone(self):
    #     clone_clf = clone(self.clf)

    def tearDown(self):
        pass


if __name__ == '__main__':
    unittest.main()

--------------------------------------------------------------------------------
/pytod/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/pytod/utils/__init__.py
--------------------------------------------------------------------------------
/pytod/utils/data.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from pyod.utils.data import evaluate_print as evaluate_print_np
from pyod.utils.data import generate_data as generate_data_pyod
from pyod.utils.utility import precision_n_scores
from sklearn.metrics import roc_auc_score
from sklearn.utils import check_consistent_length
from sklearn.utils import column_or_1d


def generate_data(n_train=1000, n_test=500, n_features=2, contamination=0.1,
                  train_only=False, offset=10,
                  random_state=None, n_nan=0, n_inf=0):
    """Utility function to generate synthesized data.
    Normal data is generated by a multivariate Gaussian distribution and
    outliers are generated by a uniform distribution.
    "X_train, y_train, X_test, y_test" are returned.
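    A minimal usage sketch (the outputs are torch tensors converted from
    the arrays produced by PyOD's generator, in the order shown in the
    function body below):

        X_train, y_train, X_test, y_test = generate_data(
            n_train=1000, n_test=500, n_features=2,
            contamination=0.1, random_state=42)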
18 | 19 | Parameters 20 | ---------- 21 | n_train : int, (default=1000) 22 | The number of training points to generate. 23 | 24 | n_test : int, (default=500) 25 | The number of test points to generate. 26 | 27 | n_features : int, optional (default=2) 28 | The number of features (dimensions). 29 | 30 | contamination : float in (0., 0.5), optional (default=0.1) 31 | The amount of contamination of the data set, i.e. 32 | the proportion of outliers in the data set. Used when fitting to 33 | define the threshold on the decision function. 34 | 35 | train_only : bool, optional (default=False) 36 | If true, generate train data only. 37 | 38 | offset : int, optional (default=10) 39 | Adjust the value range of Gaussian and Uniform. 40 | 41 | random_state : int, RandomState instance or None, optional (default=None) 42 | If int, random_state is the seed used by the random number generator; 43 | If RandomState instance, random_state is the random number generator; 44 | If None, the random number generator is the RandomState instance used 45 | by `np.random`. 46 | 47 | n_nan : int 48 | The number of values that are missing (np.NaN). Defaults to zero. 49 | 50 | n_inf : int 51 | The number of values that are infinite. (np.infty). Defaults to zero. 52 | 53 | Returns 54 | ------- 55 | X_train : numpy array of shape (n_train, n_features) 56 | Training data. 57 | 58 | X_test : numpy array of shape (n_test, n_features) 59 | Test data. 60 | 61 | y_train : numpy array of shape (n_train,) 62 | Training ground truth. 63 | 64 | y_test : numpy array of shape (n_test,) 65 | Test ground truth. 66 | 67 | """ 68 | if train_only: 69 | X_train, y_train = generate_data_pyod(n_train, n_test, n_features, 70 | contamination, 71 | train_only, offset, 'new', 72 | random_state, n_nan, n_inf) 73 | return torch.from_numpy(X_train), torch.from_numpy(y_train) 74 | else: 75 | X_train, X_test, y_train, y_test = generate_data_pyod(n_train, n_test, 76 | n_features, 77 | contamination, 78 | train_only, 79 | offset, 'new', 80 | random_state, 81 | n_nan, n_inf) 82 | 83 | return torch.from_numpy(X_train), torch.from_numpy( 84 | y_train), torch.from_numpy(X_test), torch.from_numpy(y_test) 85 | 86 | 87 | def evaluate_print(clf_name, y, y_pred): 88 | """Utility function for evaluating and printing the results for examples. 89 | Default metrics include ROC and Precision @ n 90 | 91 | Parameters 92 | ---------- 93 | clf_name : str 94 | The name of the detector. 95 | 96 | y : list or numpy array of shape (n_samples,) 97 | The ground truth. Binary (0: inliers, 1: outliers). 98 | 99 | y_pred : list or numpy array of shape (n_samples,) 100 | The raw outlier scores as returned by a fitted model. 101 | 102 | """ 103 | 104 | if torch.is_tensor(y): 105 | y = y.cpu().numpy() 106 | y_pred = y_pred.cpu().numpy() 107 | return evaluate_print_np(clf_name, y, y_pred) 108 | 109 | 110 | def get_roc(y, y_pred): 111 | """Utility function for evaluating the results for examples. 112 | Default metrics include ROC 113 | 114 | Parameters 115 | ---------- 116 | y : list or numpy array of shape (n_samples,) 117 | The ground truth. Binary (0: inliers, 1: outliers). 118 | 119 | y_pred : list or numpy array of shape (n_samples,) 120 | The raw outlier scores as returned by a fitted model. 
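    Returns
    -------
    roc : float
        The ROC AUC score of y_pred against y, rounded to 4 decimal
        places.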
121 | 122 | """ 123 | y = column_or_1d(y) 124 | y_pred = column_or_1d(y_pred) 125 | check_consistent_length(y, y_pred) 126 | 127 | return np.round(roc_auc_score(y, y_pred), decimals=4) 128 | 129 | 130 | def get_prn(y, y_pred): 131 | """Utility function for evaluating the results for examples. 132 | Default metrics include P@N 133 | 134 | Parameters 135 | ---------- 136 | y : list or numpy array of shape (n_samples,) 137 | The ground truth. Binary (0: inliers, 1: outliers). 138 | 139 | y_pred : list or numpy array of shape (n_samples,) 140 | The raw outlier scores as returned by a fitted model. 141 | 142 | """ 143 | y = column_or_1d(y) 144 | y_pred = column_or_1d(y_pred) 145 | check_consistent_length(y, y_pred) 146 | 147 | return np.round(precision_n_scores(y, y_pred), decimals=4) 148 | 149 | 150 | def Standardizer(X_train, mean=None, std=None, return_mean_std=False): 151 | """standardize the data with zero mean and unit variance. 152 | 153 | Parameters 154 | ---------- 155 | X_train 156 | mean 157 | std 158 | return_mean_std 159 | 160 | Returns 161 | ------- 162 | 163 | """ 164 | if mean is None: 165 | mean = torch.mean(X_train, axis=0) 166 | std = torch.std(X_train, axis=0) 167 | 168 | assert (mean.shape[0] == X_train.shape[1]) 169 | assert (std.shape[0] == X_train.shape[1]) 170 | 171 | X_train_norm = (X_train - mean) / std 172 | assert (X_train_norm.shape == X_train.shape) 173 | 174 | if return_mean_std: 175 | return X_train_norm, mean, std 176 | else: 177 | return X_train_norm 178 | -------------------------------------------------------------------------------- /pytod/utils/utility.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """A set of utility functions to support outlier detection. 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | import warnings 8 | 9 | import numpy as np 10 | import torch 11 | from numpy import percentile 12 | from pyod.utils.utility import check_parameter 13 | from sklearn.metrics import precision_score 14 | from sklearn.utils import check_consistent_length 15 | from sklearn.utils import column_or_1d 16 | 17 | 18 | def validate_device(gpu_id): 19 | """Validate the input device id (GPU id) is valid on the given 20 | machine. If no GPU is presented, return 'cpu'. 21 | Parameters 22 | ---------- 23 | gpu_id : int 24 | GPU id to be used. The function will validate the usability 25 | of the GPU. If failed, return device as 'cpu'. 26 | Returns 27 | ------- 28 | device_id : str 29 | Valid device id, e.g., 'cuda:0' or 'cpu' 30 | """ 31 | # if it is cpu 32 | if gpu_id == -1: 33 | return 'cpu' 34 | 35 | # cast to int for checking 36 | gpu_id = int(gpu_id) 37 | 38 | # if gpu is available 39 | if torch.cuda.is_available(): 40 | # check if gpu id is between 0 and the total number of GPUs 41 | check_parameter(gpu_id, 0, torch.cuda.device_count(), 42 | param_name='gpu id', include_left=True, 43 | include_right=False) 44 | device_id = 'cuda:{}'.format(gpu_id) 45 | else: 46 | if gpu_id != 'cpu': 47 | warnings.warn('The cuda is not available. 
Falling back to cpu.')
        device_id = 'cpu'

    return device_id


def Standardizer(X_train, mean=None, std=None, return_mean_std=False):
    """Standardize X_train to zero mean and unit variance; reuse the given
    mean and std when provided (same behavior as the version in
    pytod.utils.data).
    """
    if mean is None:
        mean = torch.mean(X_train, axis=0)
        std = torch.std(X_train, axis=0)
    # print(mean.shape, std.shape)
    assert (mean.shape[0] == X_train.shape[1])
    assert (std.shape[0] == X_train.shape[1])

    X_train_norm = (X_train - mean) / std
    assert (X_train_norm.shape == X_train.shape)

    if return_mean_std:
        return X_train_norm, mean, std
    else:
        return X_train_norm


def get_batch_index(n_samples, batch_size):
    """Split the range [0, n_samples) into equally sized chunks and return
    the (start, end) index pairs.

    Parameters
    ----------
    n_samples : int
        The total number of samples to split.

    batch_size : int
        The number of samples per batch.

    Returns
    -------
    index_tracker : list of tuples
        The (left, right) index pairs, one per batch; the last pair covers
        the remainder when batch_size does not divide n_samples.

    """

    if n_samples <= batch_size:
        return [(0, n_samples)]

    index_tracker = []
    # floor division; the remainder batch is appended below
    n_batches = n_samples // batch_size
    # print('n_batches', n_batches)
    tracker = 0
    left_index, right_index = 0, 0
    for i in range(n_batches):
        left_index = tracker * batch_size
        right_index = left_index + batch_size
        tracker += 1
        # print(left_index, right_index)
        index_tracker.append((left_index, right_index))

    if n_samples % batch_size != 0:
        left_index = right_index
        right_index = n_samples
        # print(left_index, right_index)
        index_tracker.append((left_index, right_index))
    return index_tracker


def get_label_n(y, y_pred, n=None):
    """Function to turn raw outlier scores into binary labels by assigning
    1 to the top n outlier scores.

    Parameters
    ----------
    y : list or numpy array of shape (n_samples,)
        The ground truth. Binary (0: inliers, 1: outliers).

    y_pred : list or numpy array of shape (n_samples,)
        The raw outlier scores as returned by a fitted model.

    n : int, optional (default=None)
        The number of outliers. If not given, it is inferred from the
        ground truth as the number of true outliers.

    Returns
    -------
    labels : numpy array of shape (n_samples,)
        Binary labels; 0 for normal points and 1 for outliers.

    Examples
    --------
    >>> from pytod.utils.utility import get_label_n
    >>> y = [0, 1, 1, 0, 0]
    >>> y_pred = [0.1, 0.5, 0.3, 0.2, 0.7]
    >>> get_label_n(y, y_pred)
    array([0, 1, 0, 0, 1])

    """

    # enforce formats of inputs
    y = column_or_1d(y)
    y_pred = column_or_1d(y_pred)

    check_consistent_length(y, y_pred)
    y_len = len(y)  # the length of targets

    # calculate the percentage of outliers
    if n is not None:
        outliers_fraction = n / y_len
    else:
        outliers_fraction = np.count_nonzero(y) / y_len

    threshold = percentile(y_pred, 100 * (1 - outliers_fraction))
    y_pred = (y_pred > threshold).astype('int')

    return y_pred


def precision_n_scores(y, y_pred, n=None):
    """Utility function to calculate precision @ rank n.

    Parameters
    ----------
    y : list or numpy array of shape (n_samples,)
        The ground truth. Binary (0: inliers, 1: outliers).

    y_pred : list or numpy array of shape (n_samples,)
        The raw outlier scores as returned by a fitted model.

    n : int, optional (default=None)
        The number of outliers. If not given, it is inferred from the
        ground truth as the number of true outliers.
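    Examples
    --------
    >>> from pytod.utils.utility import precision_n_scores
    >>> y = [0, 1, 1, 0, 0]
    >>> y_pred = [0.1, 0.5, 0.3, 0.2, 0.7]
    >>> precision_n_scores(y, y_pred)
    0.5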
168 | 169 | Returns 170 | ------- 171 | precision_at_rank_n : float 172 | Precision at rank n score. 173 | 174 | """ 175 | 176 | # turn raw prediction decision scores into binary labels 177 | y_pred = get_label_n(y, y_pred, n) 178 | 179 | # enforce formats of y and labels_ 180 | y = column_or_1d(y) 181 | y_pred = column_or_1d(y_pred) 182 | 183 | return precision_score(y, y_pred) 184 | -------------------------------------------------------------------------------- /pytod/version.py: -------------------------------------------------------------------------------- 1 | """ 2 | ``pytod`` is a GPU-based system for fast outlier detection 3 | """ 4 | # Based on NiLearn package 5 | # License: simplified BSD 6 | 7 | # PEP0440 compatible formatted version, see: 8 | # https://www.python.org/dev/peps/pep-0440/ 9 | # 10 | # Generic release markers: 11 | # X.Y 12 | # X.Y.Z # For bug fix releases 13 | # 14 | # Admissible pre-release markers: 15 | # X.YaN # Alpha release 16 | # X.YbN # Beta release 17 | # X.YrcN # Release Candidate 18 | # X.Y # Final release 19 | # 20 | # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer. 21 | # 'X.Y.dev0' is the canonical version of 'X.Y.dev' 22 | # 23 | __version__ = '0.0.3' # pragma: no cover 24 | -------------------------------------------------------------------------------- /reproducibility/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/__init__.py -------------------------------------------------------------------------------- /reproducibility/additional_scripts/multi-knn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Tue Apr 27 14:24:54 2021 4 | 5 | @author: yuezh 6 | """ 7 | import time 8 | import torch 9 | import torch.multiprocessing as mp 10 | import numpy as np 11 | 12 | from basic_operators import bottomk 13 | # from basic_operators_batch import cdist_batch 14 | from basic_operators_batch import bottomk_batch 15 | from utility import get_batch_index 16 | 17 | def bottomk(A, k, dim=1): 18 | if len(A.shape) == 1: 19 | dim = 0 20 | # tk = torch.topk(A * -1, k, dim=dim) 21 | # see parameter https://pytorch.org/docs/stable/generated/torch.topk.html 22 | tk = torch.topk(A, k, dim=dim, largest=False) 23 | return tk[0].cpu(), tk[1].cpu() 24 | 25 | # https://discuss.pytorch.org/t/using-torch-tensor-over-multiprocessing-queue-process-fails/2847/12 26 | 27 | 28 | 29 | 30 | def cdist_batch_k(A, B, p=2.0, gpu_id=0, batch_size=None): 31 | # def cdist_batch(A, B=None, batch_size=None): 32 | # TODO: whether to half can be a parameter 33 | # TODO: should pass other possible hyperparameters to torch.cdist 34 | 35 | if B is None: 36 | B = A 37 | 38 | n_samples, n_features = A.shape[0], A.shape[1] 39 | n_distance = B.shape[0] 40 | 41 | # batch is not needed 42 | if batch_size is None or batch_size >= n_samples: 43 | print('direct cal') 44 | kd, ki = bottomk(torch.cdist(A.to(gpu_id), B.to(gpu_id), p=p), k=10) 45 | return kd, ki 46 | 47 | batch_index_A = get_batch_index(n_samples, batch_size) 48 | batch_index_B = get_batch_index(n_distance, batch_size) 49 | print(batch_index_A) 50 | print(batch_index_B) 51 | 52 | # this is a cpu tensor to save space 53 | # cdist_mat = torch.zeros([n_samples, n_distance]) 54 | 55 | # print(gpu_id, gpu_id, gpu_id, gpu_id) 56 | kd_list = [] 57 | for i, index_A in enumerate(batch_index_A): 58 | 
for j, index_B in enumerate(batch_index_B): 59 | for t in range(3): 60 | kd, ki = bottomk(torch.cdist(A[index_A[0]:index_A[1], :].to(gpu_id), 61 | B[index_B[0]:index_B[1], :].to(gpu_id), p=p), 10) 62 | 63 | # kd, ki = bottomk(torch.cdist(A[index_A[0]:index_A[1], :].cuda(gpu_id), 64 | # B[index_B[0]:index_B[1], :].cuda(gpu_id), p=p), 65 | # 10) 66 | # kd, ki = bottomk(torch.cdist(A[index_A[0]:index_A[1], :].cuda(gpu_id), 67 | # B[index_B[0]:index_B[1], :].cuda(gpu_id), p=p), 68 | # 10) 69 | 70 | # kd, ki = bottomk(torch.cdist(A[index_A[0]:index_A[1], :].cuda(gpu_id), 71 | # B[index_B[0]:index_B[1], :].cuda(gpu_id), p=p), 72 | # 10) 73 | 74 | # kd, ki = bottomk(torch.cdist(A[index_A[0]:index_A[1], :].cuda(gpu_id), 75 | # B[index_B[0]:index_B[1], :].cuda(gpu_id), p=p), 76 | # 10) 77 | kd_list.append((kd, ki)) 78 | 79 | # cdist_mat[index_A[0]:index_A[1], index_B[0]:index_B[1]] = \ 80 | # torch.cdist(A[index_A[0]:index_A[1], :].cuda(gpu_id), 81 | # B[index_B[0]:index_B[1], :].cuda(gpu_id), 82 | # p=p).cpu() 83 | # print(gpu_id, kd_list) 84 | # return cdist_mat 85 | return kd_list 86 | 87 | 88 | # return bottomk_batch(cdist_mat, k=10, batch_size=batch_size) 89 | 90 | 91 | 92 | def cdist_per_GPU(x, pval, list_indexs, k, gpu_id, GPU_batch): 93 | 94 | print('something') 95 | print(list_indexs) 96 | for i in list_indexs: 97 | # print(gpu_id, i) 98 | print('On GPU', gpu_id, i[0][0], i[0][1], i[1][0], i[1][1]) 99 | 100 | # # a = x[i[0][0]:i[0][1], :].to(gpu_id) 101 | # # batch_size = i[1][1] - i[1][0] 102 | # # batch_inds = torch.arange(i[1][0], i[1][1]).repeat(batch_size, 1) 103 | kd_list = cdist_batch_k(x[i[0][0]:i[0][1], :], 104 | x[i[1][0]:i[1][1], :], 105 | gpu_id=gpu_id, 106 | batch_size=GPU_batch) 107 | 108 | # kd_list = knn_batch_gpu(x[i[0][0]:i[0][1], :], 109 | # x[i[1][0]:i[1][1], :], 110 | # gpu_id=gpu_id, 111 | # batch_size=GPU_batch) 112 | 113 | # print('cdist_mat_batch shape', cdist_mat_batch.shape) 114 | 115 | # bk = bottomk_batch(cdist_mat_batch, k, batch_size=GPU_batch) 116 | 117 | # bk = batch_cdist(x[i[0][0]:i[0][1], :], 118 | # x[i[1][0]:i[1][1], :], 119 | # batch_size=GPU_batch) 120 | # bk = bottomk(cdist_mat_batch, k) 121 | 122 | # print('bk!', bk[0].shape, bk[1].shape) 123 | # print('b ind', batch_inds.shape) 124 | # print('aa', batch_inds.gather(1, bk[1].long()).shape) 125 | # pval.append((i, bk[0], batch_inds.gather(1, bk[1].long()))) 126 | 127 | # pval.append((i, batch_cdist(x[i[0][0]:i[0][1], :], 128 | # x[i[1][0]:i[1][1], :], 129 | # batch_size=GPU_batch).cpu())) 130 | 131 | pval.append((gpu_id, kd_list)) 132 | 133 | if __name__ == '__main__': 134 | mp.set_start_method('spawn') 135 | 136 | # data generation 137 | n_processes = 8 138 | n_samples =4000000 139 | n_dimensions = 100 140 | k =10 141 | 142 | # decide global and local size 143 | global_batch_size= 250000 144 | GPU_batch = 50000 145 | 146 | X = torch.randn(n_samples,n_dimensions) 147 | batch_index_A = get_batch_index(n_samples, global_batch_size) 148 | 149 | # retrieve all index 150 | all_index = [] 151 | for i in batch_index_A: 152 | for j in batch_index_A: 153 | all_index.append((i, j)) 154 | n_tasks_per_gpu = int(len(all_index)/n_processes) 155 | 156 | ab = [] 157 | 158 | start = time.time() 159 | 160 | with mp.Manager() as mgr: 161 | processes = [] 162 | pval = mgr.list() 163 | 164 | for i in range(n_processes): 165 | processes.append(mp.Process(target=cdist_per_GPU, 166 | args=(X, 167 | pval, 168 | all_index[i*n_tasks_per_gpu:(i+1)*n_tasks_per_gpu], 169 | k, 170 | i, 171 | GPU_batch))) 172 | for p in processes: 
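            # all workers are started here first; the separate join loop
            # below then waits for them, so the GPU processes run
            # concurrently rather than one after another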
173 | p.start() 174 | 175 | for p in processes: 176 | p.join() 177 | 178 | print() 179 | print() 180 | print('cdist from {a} gpus'.format(a=n_processes)) 181 | for k, pv in enumerate(pval): 182 | print(k, len(pv)) 183 | ab.append(pv) 184 | 185 | 186 | kdist_mat = torch.zeros([n_samples, int(n_samples/GPU_batch)*10]) 187 | kind_mat = torch.zeros([n_samples, int(n_samples/GPU_batch)*10]) 188 | for k, pv in enumerate(ab): 189 | index_left = all_index[pv[0]][0] 190 | index_right = all_index[pv[0]][1] 191 | 192 | start_index = int(index_right[0]/GPU_batch) 193 | sub_index_len = len(pv[1]) 194 | 195 | # sub_mat = torch.zeros(global_batch_size*global_batch_size*GPU_batch, 10) 196 | # print('submat', sub_mat.shape) 197 | single_len = int(np.sqrt(sub_index_len)) 198 | 199 | for i in range(single_len): 200 | 201 | if index_right[0] !=0: 202 | for j in range(single_len, 2*single_len): 203 | print(index_left[0], index_left[0]+(i+1)*GPU_batch, 10*j, 10*(j+1)) 204 | kdist_mat[index_left[0]+i*GPU_batch:index_left[0]+(i+1)*GPU_batch, 10*j:10*(j+1)] = pv[1][i*single_len+j-single_len][0] 205 | kind_mat[index_left[0]+i*GPU_batch:index_left[0]+(i+1)*GPU_batch, 10*j:10*(j+1)] = pv[1][i*single_len+j-single_len][1] 206 | else: 207 | for j in range(single_len): 208 | print(index_left[0], index_left[0]+(i+1)*GPU_batch, 10*j, 10*(j+1)) 209 | kdist_mat[index_left[0]+i*GPU_batch:index_left[0]+(i+1)*GPU_batch, 10*j:10*(j+1)] = pv[1][i*single_len+j][0] 210 | kind_mat[index_left[0]+i*GPU_batch:index_left[0]+(i+1)*GPU_batch, 10*j:10*(j+1)] = pv[1][i*single_len+j][1] 211 | # for l in range(sub_index_len): 212 | # print(l*GPU_batch, (l+1)*GPU_batch) 213 | # print('s', pv[1][l][0].shape) 214 | # # sub_mat[l*GPU_batch:(l+1)*GPU_batch, :] = pv[1][l][0] 215 | # print(sub_mat.shape) 216 | print(kdist_mat, kdist_mat.shape) 217 | 218 | knn_dist, knn_inds = bottomk(kdist_mat.cuda(), k=10) 219 | # need a final gather since kind_mat is for specific mat. 
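    # a hedged sketch of that final gather, mirroring
    # get_knn_from_intermediate in single-knn.py: knn_inds holds column
    # positions into kdist_mat, while kind_mat holds the matching global
    # sample ids, so the true neighbor ids could be recovered with:
    # knn_neighbor_ids = kind_mat.gather(1, knn_inds.long())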
    end = time.time()
    print(end - start)

--------------------------------------------------------------------------------
/reproducibility/additional_scripts/numpy_vs_torch_batch.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Wed Dec 30 18:11:07 2020

@author: yuezh
"""

import time

import numpy as np
import torch
from torch import cdist

# check torch version

print(torch.__version__)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device, torch.cuda.get_device_name(torch.cuda.current_device()))

# disable autograd
torch.set_grad_enabled(False)


def get_batch_index(n_samples, batch_size):
    index_tracker = []

    # floor division; the remainder batch is appended below
    n_batches = n_samples // batch_size
    print('n_batches', n_batches)
    tracker = 0
    for i in range(n_batches):
        left_index = tracker * batch_size
        right_index = left_index + batch_size
        tracker += 1
        # print(left_index, right_index)
        index_tracker.append((left_index, right_index))

    if n_samples % batch_size != 0:
        left_index = right_index
        right_index = n_samples
        # print(left_index, right_index)
        index_tracker.append((left_index, right_index))
    return index_tracker


# %%
torch.cuda.empty_cache()

A = torch.randn(50000, 1000).cuda().half()
B = torch.randn(50000, 1000).cuda().half()


def batch_cdist(A, B, batch_size=25000):
    # def cdist_batch(A, B=None, batch_size=None):

    # A should be able to be batched
    # B should be fixed

    if B is None:
        B = A

    n_samples, n_features = A.shape[0], A.shape[1]
    n_distance = B.shape[0]

    batch_index_A = get_batch_index(n_samples, batch_size)
    batch_index_B = get_batch_index(n_distance, batch_size)
    print(batch_index_A)
    print(batch_index_B)

    cdist_mat = torch.zeros([n_samples, n_distance]).half()
    # cdist_mat = np.zeros([n_samples, n_distance])

    for i, index_A in enumerate(batch_index_A):
        for j, index_B in enumerate(batch_index_B):
            cdist_mat[index_A[0]:index_A[1], index_B[0]:index_B[1]] = \
                cdist(A[index_A[0]:index_A[1], :].cuda().half(),
                      B[index_B[0]:index_B[1], :].cuda().half())

    print(cdist_mat)
    print()
    return cdist_mat


start = time.time()
w = batch_cdist(A, B)
end = time.time()
print(end - start)

torch.cuda.empty_cache()

start = time.time()
cdist_mat_raw = cdist(A, B)
print(cdist_mat_raw)
end = time.time()
print(end - start)
# %% numpy time
from scipy.spatial.distance import cdist

C = torch.randn(10000, 1000).half().cpu().numpy()
D = torch.randn(10000, 1000).half().cpu().numpy()
start = time.time()
cdist_mat_raw = cdist(C, D)
end = time.time()
print(end - start)

C = torch.randn(10, 2).cuda().half()
print(C)
print(torch.norm(C, dim=1))

# %%
import numpy as np
import torch
from torch import cdist

torch.cuda.empty_cache()

a = torch.randn(50000, 200).cuda().half()
b = torch.randn(20000, 200).cuda().half()


def cdist_s(a, b):
    norm_a = torch.norm(a, dim=1).reshape(a.shape[0], 1)
    norm_b = torch.norm(b, dim=1).reshape(1, b.shape[0])

    w = norm_a ** 2 + 
norm_b ** 2 - 2 * torch.matmul(a, b.T) 124 | return torch.sqrt(w) 125 | 126 | 127 | print(a) 128 | print(b) 129 | 130 | w = cdist_s(a, b) 131 | print(w) 132 | # p = cdist(a,b) 133 | # print(p) 134 | -------------------------------------------------------------------------------- /reproducibility/additional_scripts/quantization.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Benchmark of all implemented algorithms 3 | """ 4 | 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import os 9 | import sys 10 | from time import time 11 | import itertools 12 | import torch 13 | 14 | # temporary solution for relative imports in case pyod is not installed 15 | # if pyod is installed, no need to use the following line 16 | sys.path.append( 17 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 18 | # supress warnings for clean output 19 | import warnings 20 | 21 | warnings.filterwarnings("ignore") 22 | 23 | import numpy as np 24 | import pandas as pd 25 | from sklearn.model_selection import train_test_split 26 | from scipy.io import loadmat 27 | 28 | from pyod.models.abod import ABOD 29 | from pyod.models.cblof import CBLOF 30 | # from pyod.models.feature_bagging import FeatureBagging 31 | from pyod.models.hbos import HBOS 32 | from pyod.models.iforest import IForest 33 | from pyod.models.knn import KNN 34 | from pyod.models.lmdd import LMDD 35 | from pyod.models.loci import LOCI 36 | from pyod.models.loda import LODA 37 | from pyod.models.lof import LOF 38 | from pyod.models.mcd import MCD 39 | from pyod.models.ocsvm import OCSVM 40 | from pyod.models.pca import PCA 41 | from pyod.models.cof import COF 42 | from pyod.models.sod import SOD 43 | 44 | from pyod.utils.data import generate_data 45 | from pyod.utils.data import evaluate_print 46 | 47 | from pyod.utils.utility import standardizer 48 | from pyod.utils.utility import precision_n_scores 49 | from sklearn.metrics import roc_auc_score 50 | from sklearn.metrics import average_precision_score 51 | from sklearn.preprocessing import MinMaxScaler 52 | import arff 53 | 54 | from scipy.stats import rankdata 55 | from basic_operators import topk, bottomk, bottomk_low_prec, topk_low_prec 56 | 57 | from mpmath import mp, mpf 58 | 59 | machine_eps = mpf(2**-53) 60 | 61 | def get_bounded_error(max_value, dimension, machine_eps=np.finfo(float).eps, two_sided=True): 62 | mp.dps = 100 63 | factor = (1+machine_eps)**(mp.log(dimension)+2)-1 64 | if two_sided: 65 | return float(2*(4*dimension*(max_value**2)*factor)) 66 | else: 67 | return float(4*dimension*(max_value**2)*factor) 68 | 69 | # print(get_bounded_error(1, 1000000)) 70 | # error_bound = float(get_bounded_error(1, 1000000)) 71 | 72 | # TODO: add neural networks, LOCI, SOS, COF, SOD 73 | 74 | # Define data file and read X and y 75 | mat_file_list = [ 76 | # 'annthyroid.mat', 77 | 'arrhythmia.mat', 78 | # 'breastw.mat', 79 | # 'glass.mat', 80 | # 'ionosphere.mat', 81 | # 'letter.mat', 82 | # 'lympho.mat', 83 | # 'mammography.mat', 84 | # 'mnist.mat', 85 | # 'musk.mat', 86 | 87 | # 'optdigits.mat', 88 | # 'pendigits.mat', 89 | # 'pima.mat', 90 | # 'satellite.mat', 91 | # 'satimage-2.mat', 92 | # # 'shuttle.mat', 93 | # # 'smtp_n.mat', 94 | # 'speech.mat', 95 | # 'thyroid.mat', 96 | # 'vertebral.mat', 97 | # 'vowels.mat', 98 | # 'wbc.mat', 99 | # 'wine.mat', 100 | ] 101 | 102 | mat_file = 'speech.mat' 103 | mat = loadmat(os.path.join("datasets", "ODDS", mat_file)) 104 | 105 | X = 
mat['X'] 106 | y = mat['y'].ravel() 107 | 108 | n_samples, n_features = X.shape[0], X.shape[1] 109 | 110 | outliers_fraction = np.count_nonzero(y) / len(y) 111 | outliers_percentage = round(outliers_fraction * 100, ndigits=4) 112 | 113 | # scaler = MinMaxScaler(feature_range=((1,2))) 114 | 115 | # X_transform = scaler.fit_transform(X) 116 | # a = rankdata(X, axis=0) 117 | # b = rankdata(X_transform, axis=0) 118 | 119 | X = standardizer(X) 120 | error_bound = get_bounded_error(np.max(X), n_features) 121 | print(error_bound) 122 | 123 | k = 10 124 | # X_train = torch.tensor(X).half().cuda() 125 | X_train = torch.tensor(X).float() 126 | # X_train = torch.tensor(X).double().cuda() 127 | 128 | 129 | cdist_dist = torch.cdist(X_train, X_train, p=2) 130 | 131 | bottomk_dist, bottomk_indices = bottomk(cdist_dist, k) 132 | bottomk_dist1, bottomk_indices1 = bottomk_low_prec(cdist_dist, k) 133 | 134 | # bottomk_dist_sorted, bottomk_indices_argsort = torch.sort(bottomk_dist1, dim=1) 135 | # bottomk_indices_sorted = bottomk_indices1.gather(1, bottomk_indices_argsort) 136 | print() 137 | print('bottomk is not sorted...') 138 | # we can only ensure the top k 139 | print(torch.sum((bottomk_dist[:, k-1] !=bottomk_dist1[:, k-1]).int())) 140 | print(torch.sum((bottomk_indices[:, k-1]!=bottomk_indices1[:, k-1]).int())) 141 | 142 | # we can only ensure the top k 143 | print(torch.sum((bottomk_dist !=bottomk_dist1).int())) 144 | print(torch.sum((bottomk_indices!=bottomk_indices1).int())) 145 | 146 | 147 | bottomk_dist2, bottomk_indices2 = bottomk_low_prec(cdist_dist, k, sort_value=True) 148 | print() 149 | print('bottomk is sorted...') 150 | # we ensure topk 151 | print(torch.sum((bottomk_dist[:, k-1] !=bottomk_dist2[:, k-1]).int())) 152 | print(torch.sum((bottomk_indices[:, k-1]!=bottomk_indices2[:, k-1]).int())) 153 | 154 | # we can ensure all 155 | print(torch.sum((bottomk_dist !=bottomk_dist2).int())) 156 | print(torch.sum((bottomk_indices!=bottomk_indices2).int())) 157 | 158 | #%% 159 | 160 | print() 161 | print('topk is not sorted...') 162 | 163 | topk_dist, topk_indices = topk(cdist_dist, k) 164 | topk_dist1, topk_indices1 = topk_low_prec(cdist_dist, k) 165 | 166 | 167 | # we can only ensure the top k 168 | print(torch.sum((topk_dist[:, k-1] !=topk_dist1[:, k-1]).int())) 169 | print(torch.sum((topk_indices[:, k-1]!=topk_indices1[:, k-1]).int())) 170 | 171 | print(torch.sum((topk_dist !=topk_dist1).int())) 172 | print(torch.sum((topk_indices!=topk_indices1).int())) 173 | 174 | topk_dist2, topk_indices2 = topk_low_prec(cdist_dist, k, sort_value=True) 175 | print() 176 | print('topk is sorted...') 177 | print(torch.sum((topk_dist[:, k-1] !=topk_dist2[:, k-1]).int())) 178 | print(torch.sum((topk_indices[:, k-1]!=topk_indices2[:, k-1]).int())) 179 | 180 | print(torch.sum((topk_dist !=topk_dist2).int())) 181 | print(torch.sum((topk_indices!=topk_indices2).int())) 182 | 183 | 184 | 185 | # here we flip the order 186 | decision_scores = bottomk_dist[:, -1] 187 | 188 | 189 | evaluate_print('knn', y, decision_scores.cpu()) 190 | 191 | # #%% 192 | # from basic_operators import topk, intersec1d 193 | # from pytorch_memlab import LineProfiler 194 | # from pytorch_memlab import MemReporter 195 | # import time 196 | 197 | 198 | # # t1 = torch.randint(low=0, high=20000000, size=[20000000]) 199 | # # t2 = torch.randint(low=5000000, high=25000000, size=[20000000]) 200 | 201 | # t1 = torch.rand(size=[50000000]) 202 | # t2 = torch.rand(size=[50000000]) 203 | 204 | 205 | # t1, t2 = t1.half().cuda(), t2.half().cuda() 206 | # # 
t1, t2 = t1.float().cuda(), t2.float().cuda()
# # t1, t2 = t1.double().cuda(), t2.double().cuda()

# def w(A, B):
#     return intersec1d(A, B)

# with LineProfiler(w) as prof:
#     # distance_mat = batch_cdist(X_train_norm, X_train_norm, batch_size=5000)
#     start = time.time()
#     a = w(t1, t2)
#     end = time.time()
#     print(end - start)

# print(prof.display())


# #%%
# from basic_operators import topk
# from pytorch_memlab import LineProfiler
# from pytorch_memlab import MemReporter
# import time

# def Standardizer(X_train, mean=None, std=None, return_mean_std=False):

#     if mean is None:
#         mean = torch.mean(X_train, axis=0)
#         std = torch.std(X_train, axis=0)
#     # print(mean.shape, std.shape)
#     assert (mean.shape[0] == X_train.shape[1])
#     assert (std.shape[0] == X_train.shape[1])


#     X_train_norm = (X_train-mean)/std
#     assert(X_train_norm.shape == X_train.shape)

#     if return_mean_std:
#         return X_train_norm, mean, std
#     else:
#         return X_train_norm

# contamination = 0.1  # percentage of outliers
# n_train = 200000  # number of training points
# n_test = 1000  # number of testing points
# n_features = 2000

# # Generate sample data
# X_train, y_train, X_test, y_test = \
#     generate_data(n_train=n_train,
#                   n_test=n_test,
#                   n_features=n_features,
#                   contamination=contamination,
#                   random_state=42)

# k = 5


# X_train = torch.tensor(X_train)
# X_test = torch.tensor(X_test)

# # X_train_norm, X_train_mean, X_train_std = Standardizer(X_train, return_mean_std=True)
# # X_test_norm = Standardizer(X_test, mean=X_train_mean, std=X_train_std)



# # X_train_norm = X_train.half().cuda()
# # X_train_norm = X_train.float().cuda()
# X_train_norm = X_train.double().cuda()
# print(X_train_norm.type())


# def w(A, k):
#     return torch.topk(A, k)

# with LineProfiler(w) as prof:
#     # distance_mat = batch_cdist(X_train_norm, X_train_norm, batch_size=5000)
#     start = time.time()
#     a, b = w(X_train_norm, k)
#     end = time.time()
#     print(end - start)

# print(prof.display())

#%%

--------------------------------------------------------------------------------
/reproducibility/additional_scripts/readme.MD:
--------------------------------------------------------------------------------
We are cleaning up these scripts for easier standalone use; the primary
results are reproducible with "compare_real_data.py" and
"compare_synthetic.py".
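For example, run them from the "reproducibility" directory (the scripts
resolve the ODDS .mat files relative to it, and compare_real_data.py
appends its measurements to results.txt):

    python compare_real_data.py
    python compare_synthetic.py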
-------------------------------------------------------------------------------- /reproducibility/additional_scripts/single-knn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Apr 3 14:37:27 2021 4 | 5 | @author: yuezh 6 | """ 7 | import time 8 | import numpy as np 9 | import torch 10 | 11 | from basic_operators import bottomk, cdist 12 | from utility import get_batch_index, Standardizer 13 | 14 | 15 | def batch_cdist(A, B, p=2.0, batch_size=None): 16 | # def cdist_batch(A, B=None, batch_size=None): 17 | # TODO: whether to half can be a parameter 18 | # TODO: should pass other possible hyperparameters to torch.cdist 19 | 20 | # batch is not needed 21 | if batch_size is None: 22 | return torch.cdist(A.cuda(), B.cuda(), p=p) 23 | #todo: what if n_samples is smaller than batch size. need an if/else check 24 | 25 | if B is None: 26 | B = A 27 | 28 | n_samples, n_features = A.shape[0], A.shape[1] 29 | n_distance = B.shape[0] 30 | 31 | batch_index_A = get_batch_index(n_samples, batch_size) 32 | batch_index_B = get_batch_index(n_distance, batch_size) 33 | print(batch_index_A) 34 | print(batch_index_B) 35 | 36 | # this is a cpu tensor to save space 37 | cdist_mat = torch.zeros([n_samples, n_distance]) 38 | 39 | for i, index_A in enumerate(batch_index_A): 40 | for j, index_B in enumerate(batch_index_B): 41 | cdist_mat[index_A[0]:index_A[1], index_B[0]:index_B[1]] = \ 42 | torch.cdist(A[index_A[0]:index_A[1], :].cuda(), 43 | B[index_B[0]:index_B[1], :].cuda(), 44 | p=p).cpu() 45 | return cdist_mat 46 | 47 | def knn_batch_intermediate(A, B, k=5, p=2.0, batch_size=None): 48 | # this is the map step 49 | n_samples, n_features = A.shape[0], A.shape[1] 50 | n_distance = B.shape[0] 51 | 52 | batch_index_A = get_batch_index(n_samples, batch_size) 53 | batch_index_B = get_batch_index(n_distance, batch_size) 54 | print(batch_index_A) 55 | print(batch_index_B) 56 | 57 | n_batch_A = len(batch_index_A) 58 | n_batch_B = len(batch_index_B) 59 | 60 | # this is a cpu tensor to save space 61 | # cdist_mat = torch.zeros([n_samples, n_distance]) 62 | k_dist_mat = torch.zeros([n_samples, n_batch_B*k]) 63 | k_inds_mat = torch.zeros([n_samples, n_batch_B*k]).int() 64 | 65 | for i, index_A in enumerate(batch_index_A): 66 | for j, index_B in enumerate(batch_index_B): 67 | print(i, j, n_batch_A, n_batch_B) 68 | 69 | # get the dist 70 | cdist_mat_batch = torch.cdist(A[index_A[0]:index_A[1], :].cuda(), 71 | B[index_B[0]:index_B[1], :].cuda(), p=p) 72 | 73 | # important, need to select from the batch index 74 | # otherwise the ind starts from 0 again 75 | batch_inds = torch.arange(index_B[0], index_B[1]).repeat(batch_size, 1) 76 | # print(batch_inds.shape) 77 | 78 | bk = bottomk(cdist_mat_batch, k) 79 | # we need a global indices here 80 | k_dist_mat[i*batch_size:(i+1)*batch_size, j*k:(j+1)*k] = bk[0] 81 | k_inds_mat[i*batch_size:(i+1)*batch_size, j*k:(j+1)*k] = batch_inds.gather(1, bk[1].long()) 82 | 83 | return k_dist_mat, k_inds_mat 84 | 85 | 86 | def get_knn_from_intermediate(intermediate_knn, k): 87 | # this is the reduce step 88 | 89 | # sort distance for index, real knn happens here 90 | sorted_ind = torch.argsort(intermediate_knn[0], dim=1) 91 | 92 | # bottomk_indices.gather(1, bottomk_indices_argsort) 93 | 94 | # selected the first k for each sample 95 | knn_dist = intermediate_knn[0].gather(1, sorted_ind[:, :k]) 96 | knn_inds = intermediate_knn[1].gather(1, sorted_ind[:, :k]) 97 | 98 | return knn_dist, knn_inds 99 | 100 | def 
knn_batch(A, B, k=5, p=2.0, batch_size=None): 101 | intermediate_knn = knn_batch_intermediate(A, B, k, p, batch_size) 102 | return get_knn_from_intermediate(intermediate_knn, k) 103 | 104 | 105 | if __name__ == '__main__': 106 | 107 | n_train = 1000000 # number of training points 108 | n_features = 100 109 | batch_size = 40000 110 | p = 2 111 | k = 10 112 | 113 | # # Generate sample data 114 | # # X_train = torch.randn([n_train, n_features]).half() 115 | # # X_train = torch.randn([n_train, n_features]) 116 | A = torch.randn([n_train, n_features]) 117 | # # X_train_norm = Standardizer(X_train, return_mean_std=False) 118 | B = A 119 | 120 | 121 | 122 | start = time.time() 123 | # intermediate_knn = knn_batch_intermediate(A, B, k, batch_size=batch_size) 124 | # knn_dist, knn_inds = get_knn_from_intermediate(intermediate_knn, k) 125 | 126 | knn_dist, knn_inds = knn_batch(A, B, k, batch_size=batch_size) 127 | 128 | end = time.time() 129 | print(n_train, n_features, end - start) 130 | 131 | -------------------------------------------------------------------------------- /reproducibility/compare_real_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Example of using PyTOD on real-world datasets 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore") 10 | 11 | import os 12 | import sys 13 | import time 14 | 15 | import numpy as np 16 | import torch 17 | from pyod.models.abod import ABOD as PyOD_ABOD 18 | from pyod.models.hbos import HBOS as PyOD_HBOS 19 | from pyod.models.knn import KNN as PyOD_KNN 20 | from pyod.models.lof import LOF as PyOD_LOF 21 | from pyod.models.pca import PCA as PyOD_PCA 22 | from scipy.io import loadmat 23 | 24 | # temporary solution for relative imports in case pyod is not installed 25 | # if pyod is installed, no need to use the following line 26 | sys.path.append( 27 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 28 | 29 | from pytod.models.abod import ABOD 30 | from pytod.models.lof import LOF 31 | from pytod.models.knn import KNN 32 | from pytod.models.pca import PCA 33 | from pytod.models.hbos import HBOS 34 | from pytod.utils.utility import validate_device 35 | from pytod.utils.data import get_roc, get_prn 36 | 37 | # please select multiple data 38 | mat_file_list = [ 39 | # 'annthyroid.mat', 40 | # 'arrhythmia.mat', 41 | # 'breastw.mat', 42 | 'glass.mat', 43 | # 'ionosphere.mat', 44 | # 'letter.mat', 45 | # 'lympho.mat', 46 | # 'mammography.mat', 47 | # 'mnist.mat', 48 | # 'musk.mat', 49 | # 'optdigits.mat', 50 | # 'pendigits.mat', 51 | # 'pima.mat', 52 | # 'satellite.mat', 53 | # 'satimage-2.mat', 54 | # 'shuttle.mat', 55 | # 'smtp_n.mat', 56 | # 'speech.mat', 57 | # 'thyroid.mat', 58 | # 'vertebral.mat', 59 | # 'vowels.mat', 60 | # 'wbc.mat', 61 | # 'wine.mat', 62 | ] 63 | 64 | # load PyOD models 65 | models = { 66 | 'LOF': PyOD_LOF(n_neighbors=20), 67 | 'ABOD': PyOD_ABOD(n_neighbors=20), 68 | 'HBOS': PyOD_HBOS(n_bins=50), 69 | 'KNN': PyOD_KNN(n_neighbors=5), 70 | 'PCA': PyOD_PCA(n_components=5) 71 | } 72 | 73 | for j in range(len(mat_file_list)): 74 | mat_file = mat_file_list[j] 75 | # loading and vectorization 76 | mat = loadmat(os.path.join("datasets", "ODDS", mat_file)) 77 | 78 | X = mat['X'].astype('float') 79 | y = mat['y'].ravel() 80 | X_torch = torch.from_numpy(X).float() 81 | 82 | # initialize the output file 83 | text_file = open("results.txt", "a") 84 | text_file.write( 85 | 'file' + '|' + 
'algorithm' + '|' + 'system' + '|' + 'ROC' + '|' + 'PRN' + '|' + 'Runtime' + '\n') 86 | 87 | for key in models.keys(): 88 | clf = models[key] 89 | 90 | start = time.time() 91 | clf.fit(X) 92 | decision_scores = clf.decision_scores_ 93 | decision_scores = np.nan_to_num(decision_scores) 94 | end = time.time() 95 | 96 | dur = np.round(end - start, decimals=4) 97 | roc = get_roc(y, decision_scores) 98 | prn = get_prn(y, decision_scores) 99 | 100 | print(mat_file, key, roc, prn, dur) 101 | text_file.write( 102 | mat_file + '|' + key + '|' + 'PyOD' + '|' + str(roc) + '|' + str( 103 | prn) + '|' + str(dur) + '\n') 104 | text_file.close() 105 | 106 | # get results from PyTOD 107 | # try to access the GPU, fall back to cpu if no gpu is available 108 | device = validate_device(0) 109 | batch_size = 30000 110 | 111 | text_file = open("results.txt", "a") 112 | key = 'LOF' 113 | start = time.time() 114 | clf = LOF(n_neighbors=20, batch_size=batch_size, device=device) 115 | clf.fit(X_torch) 116 | decision_scores = clf.decision_scores_ 117 | decision_scores = np.nan_to_num(decision_scores) 118 | end = time.time() 119 | 120 | dur = np.round(end - start, decimals=4) 121 | roc = get_roc(y, decision_scores) 122 | prn = get_prn(y, decision_scores) 123 | 124 | print(mat_file, key, roc, prn, dur) 125 | text_file.write( 126 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 127 | prn) + '|' + str(dur) + '\n') 128 | text_file.close() 129 | ########################################################################### 130 | text_file = open("results.txt", "a") 131 | key = 'ABOD' 132 | start = time.time() 133 | clf = ABOD(n_neighbors=20, batch_size=batch_size, device=device) 134 | clf.fit(X_torch) 135 | decision_scores = clf.decision_scores_ 136 | decision_scores = np.nan_to_num(decision_scores) 137 | end = time.time() 138 | 139 | dur = np.round(end - start, decimals=4) 140 | roc = get_roc(y, decision_scores) 141 | prn = get_prn(y, decision_scores) 142 | 143 | print(mat_file, key, roc, prn, dur) 144 | text_file.write( 145 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 146 | prn) + '|' + str(dur) + '\n') 147 | text_file.close() 148 | ########################################################################### 149 | text_file = open("results.txt", "a") 150 | key = 'HBOS' 151 | start = time.time() 152 | clf = HBOS(n_bins=50, alpha=0.1, device=device) 153 | clf.fit(X_torch) 154 | decision_scores = clf.decision_scores_ 155 | decision_scores = np.nan_to_num(decision_scores) 156 | end = time.time() 157 | 158 | dur = np.round(end - start, decimals=4) 159 | roc = get_roc(y, decision_scores) 160 | prn = get_prn(y, decision_scores) 161 | 162 | print(mat_file, key, roc, prn, dur) 163 | text_file.write( 164 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 165 | prn) + '|' + str(dur) + '\n') 166 | text_file.close() 167 | # ############################################################################################# 168 | text_file = open("results.txt", "a") 169 | key = 'KNN' 170 | start = time.time() 171 | clf = KNN(n_neighbors=5, batch_size=batch_size, device=device) 172 | clf.fit(X_torch) 173 | decision_scores = clf.decision_scores_ 174 | decision_scores = np.nan_to_num(decision_scores) 175 | end = time.time() 176 | 177 | dur = np.round(end - start, decimals=4) 178 | roc = get_roc(y, decision_scores) 179 | prn = get_prn(y, decision_scores) 180 | 181 | print(mat_file, key, roc, prn, dur) 182 | text_file.write( 183 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + 
'|' + str( 184 | prn) + '|' + str(dur) + '\n') 185 | text_file.close() 186 | # ############################################################################################# 187 | text_file = open("results.txt", "a") 188 | key = 'PCA' 189 | start = time.time() 190 | clf = PCA(n_components=5) 191 | clf.fit(X_torch) 192 | decision_scores = clf.decision_scores_ 193 | decision_scores = np.nan_to_num(decision_scores) 194 | end = time.time() 195 | 196 | dur = np.round(end - start, decimals=4) 197 | roc = get_roc(y, decision_scores) 198 | prn = get_prn(y, decision_scores) 199 | 200 | print(mat_file, key, roc, prn, dur) 201 | text_file.write( 202 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 203 | prn) + '|' + str(dur) + '\n') 204 | text_file.close() 205 | 206 | print("The results are stored in results.txt.") 207 | -------------------------------------------------------------------------------- /reproducibility/compare_real_data_adbench.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Example of using PyTOD on real-world datasets 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore") 10 | 11 | import os 12 | import sys 13 | import time 14 | 15 | import numpy as np 16 | import torch 17 | from pyod.models.abod import ABOD as PyOD_ABOD 18 | from pyod.models.hbos import HBOS as PyOD_HBOS 19 | from pyod.models.knn import KNN as PyOD_KNN 20 | from pyod.models.lof import LOF as PyOD_LOF 21 | from pyod.models.pca import PCA as PyOD_PCA 22 | from pyod.utils.utility import precision_n_scores 23 | from scipy.io import loadmat 24 | from sklearn.metrics import roc_auc_score 25 | 26 | # temporary solution for relative imports in case pyod is not installed 27 | # if pyod is installed, no need to use the following line 28 | sys.path.append( 29 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 30 | 31 | from pytod.models.abod import ABOD 32 | from pytod.models.lof import LOF 33 | from pytod.models.knn import KNN 34 | from pytod.models.pca import PCA 35 | from pytod.models.hbos import HBOS 36 | from pytod.utils.utility import validate_device 37 | from pytod.utils.data import get_roc, get_prn 38 | 39 | # please select multiple data 40 | file_name_list = [ 41 | # '8_celeba' 42 | # '9_census', 43 | # '11_donors', 44 | # '13_fraud', 45 | '16_http', 46 | '33_skin' 47 | ] 48 | 49 | # load PyOD models 50 | models = { 51 | 'LOF': PyOD_LOF(n_neighbors=20), 52 | 'ABOD': PyOD_ABOD(n_neighbors=20), 53 | 'HBOS': PyOD_HBOS(n_bins=50), 54 | 'KNN': PyOD_KNN(n_neighbors=5), 55 | 'PCA': PyOD_PCA(n_components=2) 56 | } 57 | 58 | for j in range(len(file_name_list)): 59 | file_name = file_name_list[j] 60 | # loading and vectorization 61 | file_path = os.path.join("datasets", "adbench", file_name + '.npz') 62 | 63 | data = np.load(file_path, allow_pickle=True) 64 | X, y = data['X'].astype(float), data['y'].astype(int).ravel() 65 | print(X.shape) 66 | 67 | X_torch = torch.from_numpy(X).float() 68 | 69 | # initialize the output file 70 | text_file = open("results.txt", "a") 71 | text_file.write( 72 | 'file' + '|' + 'algorithm' + '|' + 'system' + '|' + 'ROC' + '|' + 'PRN' + '|' + 'Runtime' + '\n') 73 | 74 | for key in models.keys(): 75 | clf = models[key] 76 | 77 | start = time.time() 78 | clf.fit(X) 79 | decision_scores = clf.decision_scores_ 80 | decision_scores = np.nan_to_num(decision_scores) 81 | end = time.time() 82 | 83 | dur = np.round(end - start, 
decimals=4) 84 | roc = get_roc(y, decision_scores) 85 | prn = get_prn(y, decision_scores) 86 | 87 | print(file_name, key, roc, prn, dur) 88 | text_file.write( 89 | file_name + '|' + key + '|' + 'PyOD' + '|' + str(roc) + '|' + str( 90 | prn) + '|' + str(dur) + '\n') 91 | text_file.close() 92 | 93 | # get results from PyTOD 94 | # try to access the GPU, fall back to cpu if no gpu is available 95 | device = validate_device(0) 96 | batch_size = 30000 97 | 98 | text_file = open("results.txt", "a") 99 | key = 'LOF' 100 | start = time.time() 101 | clf = LOF(n_neighbors=20, batch_size=batch_size, device=device) 102 | clf.fit(X_torch) 103 | decision_scores = clf.decision_scores_ 104 | decision_scores = np.nan_to_num(decision_scores) 105 | end = time.time() 106 | 107 | dur = np.round(end - start, decimals=4) 108 | roc = get_roc(y, decision_scores) 109 | prn = get_prn(y, decision_scores) 110 | 111 | print(file_name, key, roc, prn, dur) 112 | text_file.write( 113 | file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 114 | prn) + '|' + str(dur) + '\n') 115 | text_file.close() 116 | ########################################################################### 117 | text_file = open("results.txt", "a") 118 | key = 'ABOD' 119 | start = time.time() 120 | clf = ABOD(n_neighbors=20, batch_size=batch_size, device=device) 121 | clf.fit(X_torch) 122 | decision_scores = clf.decision_scores_ 123 | decision_scores = np.nan_to_num(decision_scores) 124 | end = time.time() 125 | 126 | dur = np.round(end - start, decimals=4) 127 | roc = get_roc(y, decision_scores) 128 | prn = get_prn(y, decision_scores) 129 | 130 | print(file_name, key, roc, prn, dur) 131 | text_file.write( 132 | file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 133 | prn) + '|' + str(dur) + '\n') 134 | text_file.close() 135 | ########################################################################### 136 | text_file = open("results.txt", "a") 137 | key = 'HBOS' 138 | start = time.time() 139 | clf = HBOS(n_bins=50, alpha=0.1, device=device) 140 | clf.fit(X_torch) 141 | decision_scores = clf.decision_scores_ 142 | decision_scores = np.nan_to_num(decision_scores) 143 | end = time.time() 144 | 145 | dur = np.round(end - start, decimals=4) 146 | roc = get_roc(y, decision_scores) 147 | prn = get_prn(y, decision_scores) 148 | 149 | print(file_name, key, roc, prn, dur) 150 | text_file.write( 151 | file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 152 | prn) + '|' + str(dur) + '\n') 153 | text_file.close() 154 | # ############################################################################################# 155 | text_file = open("results.txt", "a") 156 | key = 'KNN' 157 | start = time.time() 158 | clf = KNN(n_neighbors=5, batch_size=batch_size, device=device) 159 | clf.fit(X_torch) 160 | decision_scores = clf.decision_scores_ 161 | decision_scores = np.nan_to_num(decision_scores) 162 | end = time.time() 163 | 164 | dur = np.round(end - start, decimals=4) 165 | roc = get_roc(y, decision_scores) 166 | prn = get_prn(y, decision_scores) 167 | 168 | print(file_name, key, roc, prn, dur) 169 | text_file.write( 170 | file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 171 | prn) + '|' + str(dur) + '\n') 172 | text_file.close() 173 | # ############################################################################################# 174 | text_file = open("results.txt", "a") 175 | key = 'PCA' 176 | start = time.time() 177 | clf = PCA(n_components=2) 178 | clf.fit(X_torch) 179 | decision_scores = 
clf.decision_scores_ 180 | decision_scores = np.nan_to_num(decision_scores) 181 | end = time.time() 182 | 183 | dur = np.round(end - start, decimals=4) 184 | roc = get_roc(y, decision_scores) 185 | prn = get_prn(y, decision_scores) 186 | 187 | print(file_name, key, roc, prn, dur) 188 | text_file.write( 189 | file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 190 | prn) + '|' + str(dur) + '\n') 191 | text_file.close() 192 | 193 | print("The results are stored in results.txt.") 194 | -------------------------------------------------------------------------------- /reproducibility/compare_real_data_quant.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Example of using PyTOD on real-world datasets 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore") 10 | 11 | import os 12 | import sys 13 | import time 14 | 15 | import numpy as np 16 | import torch 17 | from pyod.models.abod import ABOD as PyOD_ABOD 18 | from pyod.models.hbos import HBOS as PyOD_HBOS 19 | from pyod.models.knn import KNN as PyOD_KNN 20 | from pyod.models.lof import LOF as PyOD_LOF 21 | from pyod.models.pca import PCA as PyOD_PCA 22 | from scipy.io import loadmat 23 | 24 | # temporary solution for relative imports in case pyod is not installed 25 | # if pyod is installed, no need to use the following line 26 | sys.path.append( 27 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 28 | 29 | from pytod.models.abod import ABOD 30 | from pytod.models.lof import LOF 31 | from pytod.models.knn import KNN 32 | from pytod.models.pca import PCA 33 | from pytod.models.hbos import HBOS 34 | from pytod.utils.utility import validate_device 35 | from pytod.utils.data import get_roc, get_prn 36 | 37 | # please select multiple data 38 | # please select multiple data 39 | file_name_list = [ 40 | '8_celeba' 41 | # '9_census', 42 | # '11_donors', 43 | # '13_fraud', 44 | ] 45 | 46 | # load PyOD models 47 | models = { 48 | # 'LOF': PyOD_LOF(n_neighbors=20), 49 | # 'ABOD': PyOD_ABOD(n_neighbors=20), 50 | 'HBOS': PyOD_HBOS(n_bins=50), 51 | # 'KNN': PyOD_KNN(n_neighbors=5), 52 | 'PCA': PyOD_PCA(n_components=2) 53 | } 54 | 55 | for j in range(len(file_name_list)): 56 | file_name = file_name_list[j] 57 | # loading and vectorization 58 | file_path = os.path.join("datasets", "adbench", file_name + '.npz') 59 | 60 | data = np.load(file_path, allow_pickle=True) 61 | X, y = data['X'].astype('float64'), data['y'].astype(int).ravel() 62 | print(X.shape) 63 | 64 | X_torch = torch.from_numpy(X).double() 65 | 66 | # initialize the output file 67 | text_file = open("results.txt", "a") 68 | text_file.write( 69 | 'file' + '|' + 'algorithm' + '|' + 'system' + '|' + 'ROC' + '|' + 'PRN' + '|' + 'Runtime' + '\n') 70 | for key in models.keys(): 71 | clf = models[key] 72 | 73 | start = time.time() 74 | clf.fit(X) 75 | decision_scores = clf.decision_scores_ 76 | decision_scores = np.nan_to_num(decision_scores) 77 | end = time.time() 78 | 79 | dur = np.round(end - start, decimals=4) 80 | roc = get_roc(y, decision_scores) 81 | prn = get_prn(y, decision_scores) 82 | 83 | print(file_name, key, roc, prn, dur) 84 | text_file.write( 85 | file_name + '|' + key + '|' + 'PyOD' + '|' + str(roc) + '|' + str( 86 | prn) + '|' + str(dur) + '\n') 87 | text_file.close() 88 | 89 | # get results from PyTOD 90 | # try to access the GPU, fall back to cpu if no gpu is available 91 | device = validate_device(0) 92 | 
92 | batch_size = 30000 93 | 94 | # text_file = open("results.txt", "a") 95 | # key = 'LOF' 96 | # start = time.time() 97 | # clf = LOF(n_neighbors=20, batch_size=batch_size, device=device) 98 | # clf.fit(X_torch) 99 | # decision_scores = clf.decision_scores_ 100 | # decision_scores = np.nan_to_num(decision_scores) 101 | # end = time.time() 102 | # 103 | # dur = np.round(end - start, decimals=4) 104 | # roc = get_roc(y, decision_scores) 105 | # prn = get_prn(y, decision_scores) 106 | # 107 | # print(file_name, key, roc, prn, dur) 108 | # text_file.write( 109 | # file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 110 | # prn) + '|' + str(dur) + '\n') 111 | # text_file.close() 112 | # ########################################################################### 113 | # text_file = open("results.txt", "a") 114 | # key = 'ABOD' 115 | # start = time.time() 116 | # clf = ABOD(n_neighbors=20, batch_size=batch_size, device=device) 117 | # clf.fit(X_torch) 118 | # decision_scores = clf.decision_scores_ 119 | # decision_scores = np.nan_to_num(decision_scores) 120 | # end = time.time() 121 | # 122 | # dur = np.round(end - start, decimals=4) 123 | # roc = get_roc(y, decision_scores) 124 | # prn = get_prn(y, decision_scores) 125 | # 126 | # print(file_name, key, roc, prn, dur) 127 | # text_file.write( 128 | # file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 129 | # prn) + '|' + str(dur) + '\n') 130 | # text_file.close() 131 | ########################################################################### 132 | text_file = open("results.txt", "a") 133 | key = 'HBOS' 134 | start = time.time() 135 | clf = HBOS(n_bins=50, alpha=0.1, device=device) 136 | clf.fit(X_torch) 137 | decision_scores = clf.decision_scores_ 138 | decision_scores = np.nan_to_num(decision_scores) 139 | end = time.time() 140 | 141 | dur = np.round(end - start, decimals=4) 142 | roc = get_roc(y, decision_scores) 143 | prn = get_prn(y, decision_scores) 144 | 145 | print(file_name, key, roc, prn, dur) 146 | text_file.write( 147 | file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 148 | prn) + '|' + str(dur) + '\n') 149 | text_file.close() 150 | # ############################################################################################# 151 | text_file = open("results.txt", "a") 152 | key = 'KNN' 153 | start = time.time() 154 | clf = KNN(n_neighbors=5, batch_size=batch_size, device=device) 155 | clf.fit(X_torch) 156 | decision_scores = clf.decision_scores_ 157 | decision_scores = np.nan_to_num(decision_scores) 158 | end = time.time() 159 | 160 | dur = np.round(end - start, decimals=4) 161 | roc = get_roc(y, decision_scores) 162 | prn = get_prn(y, decision_scores) 163 | 164 | print(file_name, key, roc, prn, dur) 165 | text_file.write( 166 | file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 167 | prn) + '|' + str(dur) + '\n') 168 | text_file.close() 169 | # ############################################################################################# 170 | text_file = open("results.txt", "a") 171 | key = 'PCA' 172 | start = time.time() 173 | clf = PCA(n_components=2, device=device)  # match the PyOD baseline above (n_components=2) 174 | clf.fit(X_torch) 175 | decision_scores = clf.decision_scores_ 176 | decision_scores = np.nan_to_num(decision_scores) 177 | end = time.time() 178 | 179 | dur = np.round(end - start, decimals=4) 180 | roc = get_roc(y, decision_scores) 181 | prn = get_prn(y, decision_scores) 182 | 183 | print(file_name, key, roc, prn, dur) 184 | text_file.write( 185 | file_name + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str(
186 | prn) + '|' + str(dur) + '\n') 187 | text_file.close() 188 | 189 | print("The results are stored in results.txt.") 190 | -------------------------------------------------------------------------------- /reproducibility/compare_synthetic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Example of using PyTOD on synthetic datasets 3 | """ 4 | # Author: Yue Zhao 5 | # License: BSD 2 clause 6 | 7 | import os 8 | import sys 9 | import time 10 | 11 | import numpy as np 12 | import torch 13 | from pyod.models.abod import ABOD as PyOD_ABOD 14 | from pyod.models.hbos import HBOS as PyOD_HBOS 15 | from pyod.models.knn import KNN as PyOD_KNN 16 | from pyod.models.lof import LOF as PyOD_LOF 17 | from pyod.models.pca import PCA as PyOD_PCA 18 | from pyod.utils.data import generate_data 19 | from pyod.utils.utility import precision_n_scores 20 | from sklearn.metrics import roc_auc_score 21 | 22 | # temporary solution for relative imports in case pytod is not installed 23 | # if pytod is installed, no need to use the following line 24 | sys.path.append( 25 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 26 | 27 | from pytod.models.abod import ABOD 28 | from pytod.models.lof import LOF 29 | from pytod.models.knn import KNN 30 | from pytod.models.pca import PCA 31 | from pytod.models.hbos import HBOS 32 | from pytod.utils.utility import validate_device 33 | 34 | 35 | def get_roc(y, y_pred): 36 | from sklearn.utils import column_or_1d 37 | from sklearn.utils import check_consistent_length 38 | y = column_or_1d(y) 39 | y_pred = column_or_1d(y_pred) 40 | check_consistent_length(y, y_pred) 41 | 42 | return np.round(roc_auc_score(y, y_pred), decimals=4) 43 | 44 | 45 | def get_prn(y, y_pred): 46 | from sklearn.utils import column_or_1d 47 | from sklearn.utils import check_consistent_length 48 | y = column_or_1d(y) 49 | y_pred = column_or_1d(y_pred) 50 | check_consistent_length(y, y_pred) 51 | 52 | return np.round(precision_n_scores(y, y_pred), decimals=4) 53 | 54 | 55 | # define the synthetic data here 56 | contamination = 0.1 # percentage of outliers 57 | n_train = 10000 # number of training points 58 | n_features = 200 59 | k = 20 60 | 61 | # Generate sample data 62 | X, y = generate_data(n_train=n_train, 63 | n_features=n_features, 64 | contamination=contamination, 65 | train_only=True, 66 | random_state=42) 67 | 68 | mat_file = str(n_train) + '_' + str(n_features) 69 | X_torch = torch.from_numpy(X).float() 70 | 71 | # load PyOD models 72 | models = { 73 | 'LOF': PyOD_LOF(n_neighbors=20), 74 | 'ABOD': PyOD_ABOD(n_neighbors=20), 75 | 'HBOS': PyOD_HBOS(n_bins=50), 76 | 'KNN': PyOD_KNN(n_neighbors=5), 77 | 'PCA': PyOD_PCA(n_components=5) 78 | } 79 | 80 | text_file = open("results_synthetic.txt", "a") 81 | text_file.write( 82 | 'file' + '|' + 'algorithm' + '|' + 'system' + '|' + 'ROC' + '|' + 'PRN' + '|' + 'Runtime' + '\n') 83 | 84 | for key in models.keys(): 85 | clf = models[key] 86 | 87 | start = time.time() 88 | clf.fit(X) 89 | decision_scores = clf.decision_scores_ 90 | decision_scores = np.nan_to_num(decision_scores) 91 | end = time.time() 92 | 93 | dur = np.round(end - start, decimals=4) 94 | roc = get_roc(y, decision_scores) 95 | prn = get_prn(y, decision_scores) 96 | 97 | print(mat_file, key, roc, prn, dur) 98 | text_file.write( 99 | mat_file + '|' + key + '|' + 'PyOD' + '|' + str(roc) + '|' + str( 100 | prn) + '|' + str(dur) + '\n') 101 | 102 | text_file.close() 103 | 104 | # get results from PyTOD 105 | # try 
to access the GPU, fall back to cpu if no gpu is available 106 | device = validate_device(0) 107 | batch_size = 30000 108 | 109 | text_file = open("results_synthetic.txt", "a") 110 | key = 'LOF' 111 | start = time.time() 112 | clf = LOF(n_neighbors=20, batch_size=batch_size, device=device) 113 | clf.fit(X_torch) 114 | decision_scores = clf.decision_scores_ 115 | decision_scores = np.nan_to_num(decision_scores) 116 | end = time.time() 117 | 118 | dur = np.round(end - start, decimals=4) 119 | roc = get_roc(y, decision_scores) 120 | prn = get_prn(y, decision_scores) 121 | 122 | print(mat_file, key, roc, prn, dur) 123 | text_file.write( 124 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 125 | prn) + '|' + str(dur) + '\n') 126 | text_file.close() 127 | ############################################################################################# 128 | text_file = open("results_synthetic.txt", "a") 129 | key = 'ABOD' 130 | start = time.time() 131 | clf = ABOD(n_neighbors=20, batch_size=batch_size, device=device) 132 | clf.fit(X_torch) 133 | decision_scores = clf.decision_scores_ 134 | decision_scores = np.nan_to_num(decision_scores) 135 | end = time.time() 136 | 137 | dur = np.round(end - start, decimals=4) 138 | roc = get_roc(y, decision_scores) 139 | prn = get_prn(y, decision_scores) 140 | 141 | print(mat_file, key, roc, prn, dur) 142 | text_file.write( 143 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 144 | prn) + '|' + str(dur) + '\n') 145 | text_file.close() 146 | ############################################################################################# 147 | text_file = open("results_synthetic.txt", "a") 148 | key = 'HBOS' 149 | start = time.time() 150 | clf = HBOS(n_bins=50, alpha=0.1, device=device) 151 | clf.fit(X_torch) 152 | decision_scores = clf.decision_scores_ 153 | decision_scores = np.nan_to_num(decision_scores) 154 | end = time.time() 155 | 156 | dur = np.round(end - start, decimals=4) 157 | roc = get_roc(y, decision_scores) 158 | prn = get_prn(y, decision_scores) 159 | 160 | print(mat_file, key, roc, prn, dur) 161 | text_file.write( 162 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 163 | prn) + '|' + str(dur) + '\n') 164 | text_file.close() 165 | # ############################################################################################# 166 | text_file = open("results_synthetic.txt", "a") 167 | key = 'KNN' 168 | start = time.time() 169 | clf = KNN(n_neighbors=5, batch_size=batch_size, device=device) 170 | clf.fit(X_torch) 171 | decision_scores = clf.decision_scores_ 172 | decision_scores = np.nan_to_num(decision_scores) 173 | end = time.time() 174 | 175 | dur = np.round(end - start, decimals=4) 176 | roc = get_roc(y, decision_scores) 177 | prn = get_prn(y, decision_scores) 178 | 179 | print(mat_file, key, roc, prn, dur) 180 | text_file.write( 181 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 182 | prn) + '|' + str(dur) + '\n') 183 | text_file.close() 184 | # ############################################################################################# 185 | text_file = open("results_synthetic.txt", "a") 186 | key = 'PCA' 187 | start = time.time() 188 | clf = PCA(n_components=5, device=device) 189 | clf.fit(X_torch) 190 | decision_scores = clf.decision_scores_ 191 | decision_scores = np.nan_to_num(decision_scores) 192 | end = time.time() 193 | 194 | dur = np.round(end - start, decimals=4) 195 | roc = get_roc(y, decision_scores) 196 | prn = get_prn(y, decision_scores) 197 | 198 | 
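# each run appends one pipe-delimited row per detector to results_synthetic.txt;
# an illustrative row (actual values vary by run and hardware):
#   10000_200|KNN|PyTOD|1.0|1.0|0.011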
print(mat_file, key, roc, prn, dur) 199 | text_file.write( 200 | mat_file + '|' + key + '|' + 'PyTOD' + '|' + str(roc) + '|' + str( 201 | prn) + '|' + str(dur) + '\n') 202 | text_file.close() 203 | 204 | print("The results are stored in results_synthetic.txt.") 205 | -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/annthyroid.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/annthyroid.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/arrhythmia.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/arrhythmia.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/breastw.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/breastw.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/glass.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/glass.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/ionosphere.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/ionosphere.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/letter.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/letter.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/lympho.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/lympho.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/mammography.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/mammography.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/mnist.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/mnist.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/musk.mat: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/musk.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/optdigits.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/optdigits.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/pendigits.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/pendigits.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/pima.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/pima.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/satellite.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/satellite.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/satimage-2.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/satimage-2.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/shuttle.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/shuttle.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/smtp_n.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/smtp_n.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/speech.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/speech.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/thyroid.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/thyroid.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/vertebral.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/vertebral.mat 
-------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/vowels.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/vowels.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/wbc.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/wbc.mat -------------------------------------------------------------------------------- /reproducibility/datasets/ODDS/wine.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/datasets/ODDS/wine.mat -------------------------------------------------------------------------------- /reproducibility/fusion_experiment.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | import os 6 | import sys 7 | import time 8 | 9 | import torch 10 | from pytorch_memlab import LineProfiler 11 | 12 | # temporary solution for relative imports in case pytod is not installed 13 | # if pytod is installed, no need to use the following line 14 | sys.path.append( 15 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 16 | 17 | from pytod.utils.data import Standardizer 18 | from pytod.utils.utility import validate_device 19 | from pytod.models.basic_operators_batch import cdist_batch 20 | from pytod.models.basic_operators_batch import bottomk_batch 21 | 22 | from pytod.models.intermediate_layers import knn_batch 23 | 24 | n_train = 150000 # number of training points 25 | n_features = 200 26 | batch_size = 20000 27 | k = 10 28 | # Generate sample data 29 | X_train = torch.randn([n_train, n_features]).float() 30 | 31 | device = validate_device(0) 32 | 33 | def simple_conct(X_train, batch_size, k, device): 34 | cdist_dist = cdist_batch(X_train, X_train, batch_size=batch_size, 35 | device=device) 36 | bottomk_batch(cdist_dist, k=k, batch_size=batch_size, device=device) 37 | 38 | with LineProfiler(simple_conct) as prof: 39 | start = time.time() 40 | simple_conct(X_train, batch_size, k, device) 41 | print(time.time() - start) 42 | print(prof.display()) 43 | 44 | with LineProfiler(knn_batch) as prof: 45 | start = time.time() 46 | knn_batch(X_train, X_train, batch_size=batch_size, k=k, device=device) 47 | print(time.time() - start) 48 | print(prof.display()) -------------------------------------------------------------------------------- /reproducibility/implement_new.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import warnings 4 | 5 | warnings.filterwarnings("ignore") 6 | 7 | import os 8 | import sys 9 | import time 10 | 11 | import torch 12 | from pytorch_memlab import LineProfiler 13 | 14 | # temporary solution for relative imports in case pytod is not installed 15 | # if pytod is installed, no need to use the following line 16 | sys.path.append( 17 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 18 | 19 | from pytod.utils.data import Standardizer 20 | from pytod.utils.utility import validate_device 21 | from 
pytod.models.intermediate_layers import neighbor_within_range_low_prec, neighbor_within_range, neighbor_within_range_low_prec_float 22 | from pytod.models.basic_operators_batch import cdist_batch 23 | 24 | n_train = 50000 # number of training points 25 | n_features = 200 26 | batch_size = 40000 27 | 28 | # Generate sample data 29 | X_train = torch.randn([n_train, n_features]).double() 30 | X_train_norm = Standardizer(X_train, return_mean_std=False) 31 | device = validate_device(0) 32 | 33 | # start = time.time() 34 | # with LineProfiler(neighbor_within_range) as prof: 35 | # 36 | # clear_pairs = neighbor_within_range(X_train_norm, range_threshold=12, 37 | # device=device) 38 | # # clear_pairs = neighbor_within_range(X_train_norm, range_threshold=12, 39 | # # batch_size=batch_size, device=device) 40 | # print(prof.display()) 41 | # end = time.time() 42 | # print('64-bit time', end - start) 43 | 44 | 45 | # with LineProfiler(cdist_batch) as prof: 46 | # start = time.time() 47 | # distance_mat = cdist_batch(X_train_norm, X_train_norm, 48 | # batch_size=batch_size, device=device) 49 | # 50 | # # identify the index pairs 51 | # clear_indices = torch.nonzero((distance_mat <= 12), 52 | # as_tuple=False) 53 | # end = time.time() 54 | # print('64-bit time', end - start) 55 | # print(prof.display()) 56 | 57 | # start = time.time() 58 | # with LineProfiler(neighbor_within_range_low_prec_float) as prof: 59 | # 60 | # clear_pairs = neighbor_within_range_low_prec_float(X_train_norm, 61 | # range_threshold=12, 62 | # device=device) 63 | # 64 | # # clear_pairs = neighbor_within_range_low_prec_float(X_train_norm, 65 | # # range_threshold=12, 66 | # # batch_size=batch_size, 67 | # # device=device) 68 | # 69 | # print(prof.display()) 70 | # end = time.time() 71 | # print('32-bit time', end - start) 72 | 73 | 74 | start = time.time() 75 | with LineProfiler(neighbor_within_range_low_prec) as prof: 76 | 77 | clear_pairs = neighbor_within_range_low_prec(X_train_norm, 78 | range_threshold=12, 79 | device=device) 80 | # clear_pairs = neighbor_within_range_low_prec(X_train_norm, 81 | # range_threshold=12, 82 | # batch_size=batch_size, 83 | # device=device) 84 | 85 | print(prof.display()) 86 | end = time.time() 87 | print('16-bit time', end - start) -------------------------------------------------------------------------------- /reproducibility/knn_classification.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | import os 6 | import sys 7 | import time 8 | 9 | import torch 10 | from pytorch_memlab import LineProfiler 11 | 12 | # temporary solution for relative imports in case pytod is not installed 13 | # if pytod is installed, no need to use the following line 14 | sys.path.append( 15 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 16 | 17 | from pytod.utils.data import Standardizer 18 | from pytod.utils.utility import validate_device 19 | from pytod.models.basic_operators_batch import cdist_batch 20 | from pytod.models.basic_operators_batch import bottomk_batch 21 | 22 | n_train = 50000 # number of training points 23 | n_test = 50000 # number of test points 24 | n_features = 200 25 | batch_size = 50000 26 | k = 10 27 | # Generate sample data 28 | X_train = torch.randn([n_train, n_features]).float() 29 | y_train = torch.randint(0, 2, (n_train,)).int() 30 | X_test = torch.randn([n_test, n_features]).float() 31 | device = validate_device(0) 32 | 33 | 34 | def knnclf_tod(X_train, y_train, 
X_test, device): 35 | n_train, n_test = X_train.shape[0], X_test.shape[0] 36 | cdist_result = cdist_batch(X_test, X_train, batch_size=batch_size, 37 | device=device) 38 | # cdist_result = cdist_batch(X_train, X_test, batch_size=batch_size, device='cpu') 39 | 40 | bottomk_dist, bottomk_ind = bottomk_batch(cdist_result, k, 41 | batch_size=batch_size, 42 | device=device) 43 | # print(cdist_result, cdist_result.shape) 44 | # print(bottomk_ind.shape, bottomk_dist, time.time() - start) 45 | 46 | y_train_repeat = y_train.repeat(1, n_test).reshape(n_test, n_train) 47 | # print(y_train_repeat) 48 | # print(y_train_repeat.shape) 49 | 50 | knn_results = y_train_repeat.gather(1, bottomk_ind.long()) 51 | knn_vote = torch.sum(knn_results, dim=1) / k 52 | 53 | # get the pred results of kNN by TOD 54 | pred = (knn_vote >= 0.5).int() 55 | return pred 56 | 57 | 58 | start = time.time() 59 | pred = knnclf_tod(X_train, y_train, X_test, device) 60 | print(time.time() - start) 61 | 62 | 63 | from sklearn.neighbors import KNeighborsClassifier 64 | neigh = KNeighborsClassifier(n_neighbors=k) 65 | start = time.time() 66 | neigh.fit(X_train.numpy(), y_train.numpy()) 67 | neigh.predict(X_test.numpy()) 68 | print(time.time() - start) -------------------------------------------------------------------------------- /reproducibility/results.txt: -------------------------------------------------------------------------------- 1 | file|algorithm|system|ROC|PRN|Runtime 2 | mnist.mat|LOF|PyOD|0.6482|0.2429|0.9775 3 | mnist.mat|ABOD|PyOD|0.78|0.3857|12.3105 4 | mnist.mat|HBOS|PyOD|0.624|0.15|1.0356 5 | mnist.mat|KNN|PyOD|0.8032|0.3957|0.9466 6 | mnist.mat|PCA|PyOD|0.8521|0.39|0.0345 7 | mnist.mat|LOF|PyTOD|0.6482|0.2429|1.4475 8 | mnist.mat|ABOD|PyTOD|0.78|0.3829|1.2015 9 | mnist.mat|HBOS|PyTOD|0.5|0.0|0.067 10 | mnist.mat|KNN|PyTOD|0.8032|0.3957|0.008 11 | mnist.mat|PCA|PyTOD|0.2329|0.0171|0.1567 12 | file|algorithm|system|ROC|PRN|Runtime 13 | glass.mat|LOF|PyOD|0.8352|0.2222|0.003 14 | glass.mat|ABOD|PyOD|0.8602|0.1111|1.0512 15 | glass.mat|HBOS|PyOD|0.6228|0.0|1.2453 16 | glass.mat|KNN|PyOD|0.8656|0.1111|0.002 17 | glass.mat|PCA|PyOD|0.5783|0.1111|0.03 18 | glass.mat|LOF|PyTOD|0.8352|0.2222|0.023 19 | glass.mat|ABOD|PyTOD|0.8504|0.1111|0.116 20 | glass.mat|HBOS|PyTOD|0.6233|0.0|0.8412 21 | glass.mat|KNN|PyTOD|0.8656|0.1111|0.003 22 | glass.mat|PCA|PyTOD|0.7144|0.3333|1.1993 23 | file|algorithm|system|ROC|PRN|Runtime 24 | glass.mat|LOF|PyOD|0.8352|0.2222|0.003 25 | glass.mat|ABOD|PyOD|0.8602|0.1111|1.0482 26 | glass.mat|HBOS|PyOD|0.6228|0.0|1.1653 27 | glass.mat|KNN|PyOD|0.8656|0.1111|0.001 28 | glass.mat|PCA|PyOD|0.5783|0.1111|0.004 29 | glass.mat|LOF|PyTOD|0.8352|0.2222|0.004 30 | glass.mat|ABOD|PyTOD|0.8504|0.1111|0.088 31 | glass.mat|HBOS|PyTOD|0.6233|0.0|0.7782 32 | glass.mat|KNN|PyTOD|0.8656|0.1111|0.002 33 | glass.mat|PCA|PyTOD|0.7149|0.3333|0.8012 34 | file|algorithm|system|ROC|PRN|Runtime 35 | glass.mat|LOF|PyOD|0.8352|0.2222|0.002 36 | glass.mat|ABOD|PyOD|0.8602|0.1111|0.9732 37 | glass.mat|HBOS|PyOD|0.6228|0.0|1.1083 38 | glass.mat|KNN|PyOD|0.8656|0.1111|0.001 39 | glass.mat|PCA|PyOD|0.5783|0.1111|0.007 40 | glass.mat|LOF|PyTOD|0.8352|0.2222|0.006 41 | glass.mat|ABOD|PyTOD|0.8504|0.1111|0.194 42 | glass.mat|HBOS|PyTOD|0.6233|0.0|0.8062 43 | glass.mat|KNN|PyTOD|0.8656|0.1111|0.001 44 | glass.mat|PCA|PyTOD|0.7176|0.3333|0.8412 45 | file|algorithm|system|ROC|PRN|Runtime 46 | file|algorithm|system|ROC|PRN|Runtime 47 | 4_breastw|LOF|PyOD|0.3901|0.1381|0.007 48 | file|algorithm|system|ROC|PRN|Runtime 49 | 
4_breastw|LOF|PyOD|0.3901|0.1381|0.006 50 | 4_breastw|ABOD|PyOD|0.7587|0.5523|1.2493 51 | 4_breastw|HBOS|PyOD|0.985|0.9372|0.8522 52 | 4_breastw|KNN|PyOD|0.9765|0.9145|0.004 53 | 4_breastw|PCA|PyOD|0.9564|0.9247|0.017 54 | file|algorithm|system|ROC|PRN|Runtime 55 | 4_breastw|LOF|PyOD|0.3901|0.1381|0.006 56 | 4_breastw|ABOD|PyOD|0.7587|0.5523|1.2583 57 | 4_breastw|HBOS|PyOD|0.985|0.9372|0.8632 58 | 4_breastw|KNN|PyOD|0.9765|0.9145|0.003 59 | 4_breastw|PCA|PyOD|0.9564|0.9247|0.005 60 | 4_breastw|LOF|PyTOD|0.4177|0.1339|0.006 61 | 4_breastw|ABOD|PyTOD|0.3591|0.0|0.2631 62 | 4_breastw|HBOS|PyTOD|0.985|0.9372|0.7032 63 | 4_breastw|KNN|PyTOD|0.9765|0.9145|0.003 64 | 4_breastw|PCA|PyTOD|0.9959|0.954|0.6762 65 | file|algorithm|system|ROC|PRN|Runtime 66 | 13_fraud|LOF|PyOD|0.475|0.0|920.8856 67 | 13_fraud|ABOD|PyOD|0.906|0.0|1237.4646 68 | 13_fraud|HBOS|PyOD|0.9492|0.2805|1.0692 69 | 13_fraud|KNN|PyOD|0.9342|0.1809|912.9075 70 | 13_fraud|PCA|PyOD|0.9502|0.2317|0.6291 71 | 13_fraud|LOF|PyTOD|0.4754|0.0|79.005 72 | 13_fraud|ABOD|PyTOD|0.8206|0.0|98.533 73 | 13_fraud|HBOS|PyTOD|0.9492|0.2805|0.2841 74 | 13_fraud|KNN|PyTOD|0.9343|0.1809|71.3112 75 | 13_fraud|PCA|PyTOD|0.1319|0.002|0.167 76 | file|algorithm|system|ROC|PRN|Runtime 77 | 11_donors|LOF|PyOD|0.5977|0.2044|967.9399 78 | 11_donors|ABOD|PyOD|0.3376|0.0|1287.4338 79 | 11_donors|HBOS|PyOD|0.8049|0.2834|1.0342 80 | 11_donors|KNN|PyOD|0.6117|0.2188|927.2259 81 | 11_donors|PCA|PyOD|0.8138|0.2037|0.6962 82 | 11_donors|LOF|PyTOD|0.6623|0.2016|331.9047 83 | 11_donors|ABOD|PyTOD|0.4516|0.0|375.2702 84 | 11_donors|HBOS|PyTOD|0.7667|0.2723|0.3001 85 | 11_donors|KNN|PyTOD|0.6338|0.2449|331.4133 86 | 11_donors|PCA|PyTOD|0.8302|0.1429|0.163 87 | file|algorithm|system|ROC|PRN|Runtime 88 | 8_celeba|LOF|PyOD|0.4367|0.011|501.1362 89 | 8_celeba|ABOD|PyOD|0.5134|0.0|695.4697 90 | 8_celeba|HBOS|PyOD|0.7568|0.1524|1.0832 91 | 8_celeba|KNN|PyOD|0.5666|0.0237|503.1269 92 | 8_celeba|PCA|PyOD|0.7831|0.1696|0.4501 93 | 8_celeba|LOF|PyTOD|0.4509|0.0|38.3437 94 | 8_celeba|ABOD|PyTOD|0.4794|0.0|56.786 95 | 8_celeba|HBOS|PyTOD|0.7568|0.1524|0.3651 96 | 8_celeba|KNN|PyTOD|0.5666|0.0237|35.2601 97 | 8_celeba|PCA|PyTOD|0.6169|0.0445|0.136 98 | file|algorithm|system|ROC|PRN|Runtime 99 | 9_census|LOF|PyOD|0.5503|0.0183|1655.7565 100 | 9_census|ABOD|PyOD|0.4579|0.0|2305.704 101 | 9_census|HBOS|PyOD|0.6426|0.0634|3.8896 102 | 9_census|KNN|PyOD|0.6465|0.0659|1674.5057 103 | 9_census|PCA|PyOD|0.6624|0.0665|6.8806 104 | 9_census|LOF|PyTOD|0.5775|0.0181|88.1579 105 | 9_census|ABOD|PyTOD|0.3492|0.0|226.0986 106 | 9_census|HBOS|PyTOD|0.6426|0.0636|0.673 107 | 9_census|KNN|PyTOD|0.6463|0.0659|80.426 108 | 9_census|PCA|PyTOD|0.8517|0.3168|0.2101 109 | file|algorithm|system|ROC|PRN|Runtime 110 | 111 | file|algorithm|system|ROC|PRN|Runtime 112 | 16_http|LOF|PyOD|0.3527|0.0032|9.1701 113 | 16_http|ABOD|PyOD|0.8941|0.0|480.4136 114 | 16_http|HBOS|PyOD|0.9858|0.0199|0.9522 115 | 16_http|KNN|PyOD|0.2309|0.0294|6.4835 116 | 16_http|PCA|PyOD|0.9962|0.022|0.206 117 | 16_http|LOF|PyTOD|0.1178|0.0|280.0403 118 | 16_http|ABOD|PyTOD|0.5521|0.0|313.6084 119 | 16_http|HBOS|PyTOD|0.9858|0.0199|0.076 120 | 16_http|KNN|PyTOD|0.2853|0.0294|266.8727 121 | 16_http|PCA|PyTOD|0.9975|0.9792|0.1221 122 | file|algorithm|system|ROC|PRN|Runtime 123 | 33_skin|LOF|PyOD|0.572|0.2972|1.5724 124 | 33_skin|ABOD|PyOD|0.2631|0.0|163.1831 125 | 33_skin|HBOS|PyOD|0.6022|0.1694|0.02 126 | 33_skin|KNN|PyOD|0.5879|0.3008|1.0892 127 | 33_skin|PCA|PyOD|0.1263|0.0032|0.089 128 | 33_skin|LOF|PyTOD|0.6797|0.2934|54.5795 129 | 
33_skin|ABOD|PyTOD|0.4924|0.0|73.5376 130 | 33_skin|HBOS|PyTOD|0.6014|0.1936|0.056 131 | 33_skin|KNN|PyTOD|0.5789|0.2851|54.8439 132 | 33_skin|PCA|PyTOD|0.7537|0.2521|0.011 133 | glass.mat|PCA|PyTOD|0.7122|0.3333|1.9442 134 | 8_celeba|PCA|PyTOD|0.6129|0.0423|1.4413 135 | file|algorithm|system|ROC|PRN|Runtime 136 | 8_celeba|HBOS|PyOD|0.7568|0.1524|1.2213 137 | 8_celeba|HBOS|PyTOD|0.7568|0.1524|0.8852 138 | file|algorithm|system|ROC|PRN|Runtime 139 | 8_celeba|HBOS|PyOD|0.7568|0.1524|1.2233 140 | 8_celeba|PCA|PyOD|0.7594|0.1465|0.4301 141 | 8_celeba|HBOS|PyTOD|0.7568|0.1524|0.8042 142 | 8_celeba|PCA|PyTOD|0.5937|0.0394|0.7122 143 | file|algorithm|system|ROC|PRN|Runtime 144 | 8_celeba|HBOS|PyOD|0.7568|0.1524|1.2333 145 | 8_celeba|PCA|PyOD|0.7594|0.1465|0.4611 146 | 8_celeba|HBOS|PyTOD|0.7568|0.1524|0.9252 147 | 8_celeba|KNN|PyTOD|0.5666|0.0237|41.8073 148 | 8_celeba|PCA|PyTOD|0.5755|0.0376|0.189 149 | file|algorithm|system|ROC|PRN|Runtime 150 | 8_celeba|HBOS|PyOD|0.7568|0.1524|1.3983 151 | 8_celeba|PCA|PyOD|0.7594|0.1465|0.4611 152 | 8_celeba|HBOS|PyTOD|0.7568|0.1524|1.0202 153 | 8_celeba|KNN|PyTOD|0.5666|0.0237|44.1484 154 | 8_celeba|PCA|PyTOD|0.598|0.0336|0.181 155 | -------------------------------------------------------------------------------- /reproducibility/results_synthetic.txt: -------------------------------------------------------------------------------- 1 | file|algorithm|system|ROC|PRN|Runtime 2 | 10000_200|LOF|PyOD|0.1909|0.0|1.5573 3 | 10000_200|ABOD|PyOD|1.0|1.0|15.5342 4 | 10000_200|HBOS|PyOD|1.0|1.0|1.3401 5 | 10000_200|KNN|PyOD|1.0|1.0|1.6634 6 | 10000_200|PCA|PyOD|1.0|1.0|0.0814 7 | 10000_200|LOF|PyTOD|0.1864|0.0|1.4872 8 | 10000_200|ABOD|PyTOD|1.0|1.0|3.1522 9 | 10000_200|HBOS|PyTOD|1.0|1.0|0.1359 10 | 10000_200|KNN|PyTOD|1.0|1.0|0.014 11 | 10000_200|PCA|PyTOD|0.0|0.0|0.1599 12 | file|algorithm|system|ROC|PRN|Runtime 13 | 10000_200|LOF|PyOD|0.1909|0.0|1.5744 14 | 10000_200|ABOD|PyOD|1.0|1.0|14.7663 15 | 10000_200|HBOS|PyOD|1.0|1.0|1.0202 16 | 10000_200|KNN|PyOD|1.0|1.0|1.7304 17 | 10000_200|PCA|PyOD|1.0|1.0|0.073 18 | 10000_200|LOF|PyTOD|0.1909|0.0|1.4723 19 | 10000_200|ABOD|PyTOD|1.0|1.0|1.8164 20 | 10000_200|HBOS|PyTOD|1.0|1.0|0.197 21 | 10000_200|KNN|PyTOD|1.0|1.0|0.011 22 | 10000_200|PCA|PyTOD|0.0|0.0|0.096 23 | file|algorithm|system|ROC|PRN|Runtime 24 | 10000_200|LOF|PyOD|0.1909|0.0|1.5664 25 | 10000_200|ABOD|PyOD|1.0|1.0|14.7754 26 | 10000_200|HBOS|PyOD|1.0|1.0|0.9362 27 | 10000_200|KNN|PyOD|1.0|1.0|1.6594 28 | 10000_200|PCA|PyOD|1.0|1.0|0.074 29 | 10000_200|LOF|PyTOD|0.1909|0.0|1.5644 30 | 10000_200|ABOD|PyTOD|1.0|1.0|1.8634 31 | 10000_200|HBOS|PyTOD|1.0|1.0|0.133 32 | 10000_200|KNN|PyTOD|1.0|1.0|0.011 33 | 10000_200|PCA|PyTOD|0.0|0.0|0.09 34 | file|algorithm|system|ROC|PRN|Runtime 35 | 10000_200|LOF|PyOD|0.1909|0.0|1.6082 36 | 10000_200|ABOD|PyOD|1.0|1.0|14.5055 37 | 10000_200|HBOS|PyOD|1.0|1.0|0.959 38 | 10000_200|KNN|PyOD|1.0|1.0|1.6179 39 | 10000_200|PCA|PyOD|1.0|1.0|0.0789 40 | 10000_200|LOF|PyTOD|0.1909|0.0|1.6418 41 | 10000_200|ABOD|PyTOD|1.0|1.0|1.8843 42 | 10000_200|HBOS|PyTOD|1.0|1.0|0.1375 43 | 10000_200|KNN|PyTOD|1.0|1.0|0.011 44 | 10000_200|PCA|PyTOD|0.0|0.0|0.0934 45 | -------------------------------------------------------------------------------- /reproducibility/time_break.xlsx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yzhao062/pytod/ec43433ad1a0ab939195a5eda0c1a6ab01b96ad2/reproducibility/time_break.xlsx 
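The pipe-delimited results files above are straightforward to load for analysis. A minimal sketch, assuming pandas is available (it is not listed in requirements.txt); the filter is needed because the header line is re-written for every run: import pandas as pd cols = ['file', 'algorithm', 'system', 'ROC', 'PRN', 'Runtime'] df = pd.read_csv('results.txt', sep='|', names=cols, header=None) df = df[df['file'] != 'file']  # drop the repeated header rows df[['ROC', 'PRN', 'Runtime']] = df[['ROC', 'PRN', 'Runtime']].apply(pd.to_numeric, errors='coerce') print(df.groupby(['algorithm', 'system'])[['ROC', 'Runtime']].mean())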
-------------------------------------------------------------------------------- /reproducibility/time_breakdown.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import numpy as np 6 | import torch 7 | from pyod.utils.data import generate_data 8 | 9 | # temporary solution for relative imports in case pytod is not installed 10 | # if pytod is installed, no need to use the following line 11 | sys.path.append( 12 | os.path.abspath(os.path.join(os.path.dirname("__file__"), '..'))) 13 | 14 | from pytod.models.abod import ABOD 15 | from pytod.models.lof import LOF 16 | from pytod.models.knn import KNN 17 | from pytod.models.pca import PCA 18 | from pytod.models.hbos import HBOS 19 | from pytod.utils.utility import validate_device 20 | 21 | # define the synthetic data here 22 | contamination = 0.1 # percentage of outliers 23 | n_train = 100000 # number of training points 24 | n_features = 200 25 | k = 20 26 | batch_size = 30000 27 | 28 | # Generate sample data 29 | X, y = generate_data(n_train=n_train, 30 | n_features=n_features, 31 | contamination=contamination, 32 | train_only=True, 33 | random_state=42) 34 | 35 | mat_file = str(n_train) + '_' + str(n_features) 36 | X_torch = torch.from_numpy(X).float() 37 | 38 | device = validate_device(0) 39 | 40 | def knn_measure(): 41 | key = 'KNN' 42 | start = time.time() 43 | clf = KNN(n_neighbors=20, batch_size=batch_size, device=device) 44 | clf.fit(X_torch, return_time=True) 45 | decision_scores = clf.decision_scores_ 46 | decision_scores = np.nan_to_num(decision_scores) 47 | end = time.time() 48 | print('kNN total time', end - start) 49 | print('kNN GPU time', clf.gpu_time) 50 | 51 | def hbos_measure(): 52 | key = 'HBOS' 53 | start = time.time() 54 | clf = HBOS(n_bins=50, alpha=0.1, device=device) 55 | clf.fit(X_torch, return_time=True) 56 | decision_scores = clf.decision_scores_ 57 | decision_scores = np.nan_to_num(decision_scores) 58 | end = time.time() 59 | print('HBOS total time', end - start) 60 | print('HBOS GPU time', clf.gpu_time) 61 | 62 | def pca_measure(): 63 | key = 'PCA' 64 | start = time.time() 65 | clf = PCA(n_components=5, device=device) 66 | clf.fit(X_torch, return_time=True) 67 | decision_scores = clf.decision_scores_ 68 | decision_scores = np.nan_to_num(decision_scores) 69 | end = time.time() 70 | print('PCA total time', end - start) 71 | print('PCA GPU time', clf.gpu_time) 72 | 73 | def lof_measure(): 74 | key = 'LOF' 75 | start = time.time() 76 | clf = LOF(n_neighbors=20, batch_size=batch_size, device=device) 77 | clf.fit(X_torch, return_time=True) 78 | decision_scores = clf.decision_scores_ 79 | decision_scores = np.nan_to_num(decision_scores) 80 | end = time.time() 81 | print('LOF total time', end - start) 82 | print('LOF GPU time', clf.gpu_time) 83 | 84 | def abod_measure(): 85 | key = 'ABOD' 86 | start = time.time() 87 | clf = ABOD(n_neighbors=20, batch_size=batch_size, device=device) 88 | clf.fit(X_torch, return_time=True) 89 | decision_scores = clf.decision_scores_ 90 | decision_scores = np.nan_to_num(decision_scores) 91 | end = time.time() 92 | print('ABOD total time', end - start) 93 | print('ABOD GPU time', clf.gpu_time) 94 | 95 | # knn_measure() 96 | # hbos_measure() 97 | # pca_measure() 98 | lof_measure() 99 | # abod_measure() 100 | 101 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mpmath 2 | 
numpy>=1.13 3 | pyod>=1.0.4 4 | scikit_learn>=0.21 5 | scipy 6 | torch>=1.7 7 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | # read the contents of README file 4 | from os import path 5 | from io import open # for Python 2 and 3 compatibility 6 | 7 | # get __version__ from version.py 8 | ver_file = path.join('pytod', 'version.py') 9 | with open(ver_file) as f: 10 | exec(f.read()) 11 | 12 | this_directory = path.abspath(path.dirname(__file__)) 13 | 14 | 15 | # read the contents of README.rst 16 | def readme(): 17 | with open(path.join(this_directory, 'README.rst'), encoding='utf-8') as f: 18 | return f.read() 19 | 20 | 21 | # read the contents of requirements.txt 22 | with open(path.join(this_directory, 'requirements.txt'), 23 | encoding='utf-8') as f: 24 | requirements = f.read().splitlines() 25 | 26 | setup( 27 | name='pytod', 28 | version=__version__, 29 | description='Tensor-based outlier detection. A general GPU-accelerated framework.', 30 | long_description=readme(), 31 | long_description_content_type='text/x-rst', 32 | author='Yue Zhao', 33 | author_email='zhaoy@cmu.edu', 34 | url='https://github.com/yzhao062/pytod', 35 | download_url='https://github.com/yzhao062/pytod/archive/master.zip', 36 | keywords=['pytorch', 'tensor operation', 'outlier detection', 'acceleration', 37 | 'data mining', 'machine learning', 'python'], 38 | packages=find_packages(exclude=['test']), 39 | include_package_data=True, 40 | install_requires=requirements, 41 | setup_requires=['setuptools>=38.6.0'], 42 | classifiers=[ 43 | 'Development Status :: 2 - Pre-Alpha', 44 | 'Intended Audience :: Education', 45 | 'Intended Audience :: Financial and Insurance Industry', 46 | 'Intended Audience :: Science/Research', 47 | 'Intended Audience :: Developers', 48 | 'Intended Audience :: Information Technology', 49 | 'License :: OSI Approved :: BSD License', 50 | 'Programming Language :: Python :: 3.5', 51 | 'Programming Language :: Python :: 3.6', 52 | 'Programming Language :: Python :: 3.7', 53 | ], 54 | ) 55 | --------------------------------------------------------------------------------
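The reproducibility scripts above all follow the same estimator pattern: construct a detector, fit it on a torch tensor, and read decision_scores_. A minimal end-to-end sketch of that pattern on stand-in synthetic data (the calls mirror those in compare_synthetic.py; KNN is shown, but LOF, ABOD, HBOS, and PCA are used the same way): import torch from pytod.models.knn import KNN from pytod.utils.utility import validate_device X = torch.randn([10000, 200]).float()  # stand-in data for illustration device = validate_device(0)  # GPU 0 if available, otherwise CPU clf = KNN(n_neighbors=5, batch_size=30000, device=device) clf.fit(X) scores = clf.decision_scores_  # higher scores indicate more outlying points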