├── tests ├── __init__.py ├── data │ ├── classification.fits │ ├── classification_str.h5 │ ├── classification_pandas.h5 │ ├── classification_pandas_str.h5 │ └── classification_str.txt ├── test_kde_predictor.py ├── test_recommenders.py ├── test_regression.py ├── test_labellers.py ├── test_proto_io.py ├── test_database.py ├── test_proto_wrappers.py └── test_predictors.py ├── acton ├── proto │ ├── __init__.py │ ├── acton.proto │ ├── io.py │ ├── wrappers.py │ └── acton_pb2.py ├── __init__.py ├── acton.proto ├── plot.py ├── kde_predictor.py ├── labellers.py ├── cli.py ├── recommenders.py └── acton.py ├── docs ├── design │ ├── acton.pdf │ ├── mlai-synbio.pdf │ ├── mlai-synbio-notes.md │ ├── NewActon.tex │ ├── api_design.tex │ └── survey.rst ├── source │ ├── modules.rst │ ├── acton.proto.rst │ ├── acton.rst │ └── dev.rst ├── index.rst ├── protobuf_spec.ipynb ├── PRESCAL Updating Design.ipynb ├── make.bat ├── Makefile └── conf.py ├── test ├── compile_proto ├── requirements-docs.txt ├── .buildconfig ├── setup.cfg ├── install_protobuf ├── requirements.txt ├── examples ├── simulate_active_learning ├── simulate_active_learning.py ├── simulate_thompson_sampling.py ├── multiclass_classification.py └── classification.py ├── .travis.yml ├── LICENSE ├── .gitignore ├── Makefile ├── setup.py └── README.rst /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /acton/proto/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /acton/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.3.3' 2 | -------------------------------------------------------------------------------- /acton/acton.proto: -------------------------------------------------------------------------------- 1 | -ThompsonSamplingRecommender | TensorPredictor -------------------------------------------------------------------------------- /docs/design/acton.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengsoonong/acton/HEAD/docs/design/acton.pdf -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | acton 2 | ===== 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | acton 8 | -------------------------------------------------------------------------------- /test: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | nosetests 3 | flake8 acton --exclude *_pb2.py --max-line-length=80 4 | -------------------------------------------------------------------------------- /docs/design/mlai-synbio.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengsoonong/acton/HEAD/docs/design/mlai-synbio.pdf -------------------------------------------------------------------------------- /compile_proto: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | protoc -I=acton/proto --python_out=acton/proto acton/proto/acton.proto 3 | -------------------------------------------------------------------------------- /tests/data/classification.fits: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengsoonong/acton/HEAD/tests/data/classification.fits -------------------------------------------------------------------------------- /tests/data/classification_str.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengsoonong/acton/HEAD/tests/data/classification_str.h5 -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | Sphinx==1.4.8 2 | numpydoc==0.6.0 3 | typing>=3.5.2 4 | sphinx_rtd_theme==0.1.9 5 | mock==2.0.0 6 | -------------------------------------------------------------------------------- /tests/data/classification_pandas.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengsoonong/acton/HEAD/tests/data/classification_pandas.h5 -------------------------------------------------------------------------------- /tests/data/classification_pandas_str.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chengsoonong/acton/HEAD/tests/data/classification_pandas_str.h5 -------------------------------------------------------------------------------- /.buildconfig: -------------------------------------------------------------------------------- 1 | [default] 2 | name=Default 3 | runtime=host 4 | config-opts= 5 | run-opts= 6 | prefix=/home/admin-u6015325/.cache/gnome-builder/install/acton/host 7 | app-id= 8 | postbuild= 9 | prebuild= 10 | default=true 11 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.3.3 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:docs/conf.py] 7 | search = version = '{current_version}' 8 | replace = version = '{new_version}' 9 | 10 | [bumpversion:file:acton/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [bdist_wheel] 15 | universal = 0 16 | 17 | [flake8] 18 | exclude = docs 19 | 20 | -------------------------------------------------------------------------------- /install_protobuf: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | if [ ! 
-d "$HOME/protobuf/lib" ]; then 3 | wget https://github.com/google/protobuf/archive/v3.1.0.tar.gz 4 | tar -xzvf v3.1.0.tar.gz 5 | pushd protobuf-3.1.0/ && ./autogen.sh && ./configure --prefix=$HOME/protobuf && make && sudo make install && popd 6 | pushd protobuf-3.1.0/python && python setup.py build && sudo python setup.py install && popd 7 | else 8 | echo "Using cached protobuf directory." 9 | fi 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | h5py>=2.6.0 2 | protobuf>=3.1.0 3 | numpy>=1.11.0 4 | scipy>=0.17.0 5 | scikit-learn>=0.18.1 6 | typing>=3.5.2 7 | astropy>=1.1.2 8 | pip>=8.1.2 9 | bumpversion==0.5.3 10 | wheel==0.29.0 11 | watchdog==0.8.3 12 | flake8==3.2.0 13 | coverage==4.1 14 | Sphinx==1.4.8 15 | numpydoc==0.6.0 16 | pyflakes<1.4.0,>=1.3.0 17 | pandas>=0.15.2 18 | nose==1.3.7 19 | click==6.6 20 | tables>=3.3.0 21 | sphinx_rtd_theme==0.1.9 22 | mock==2.0.0 23 | GPy==1.9.2 24 | matplotlib==2.0.0 25 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Acton documentation master file, created by 2 | sphinx-quickstart on Sat Jan 21 12:22:55 2017. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to Acton's documentation! 7 | ================================= 8 | 9 | Contents: 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | 14 | source/modules 15 | source/dev 16 | 17 | 18 | Indices and tables 19 | ================== 20 | 21 | * :ref:`genindex` 22 | * :ref:`modindex` 23 | * :ref:`search` 24 | 25 | -------------------------------------------------------------------------------- /examples/simulate_active_learning: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script simulates active learning using the component CLI. The same task 4 | # is accomplished in Python in simulate_active_learning.py. 5 | 6 | # Label some initial data points. 7 | echo "1 8 | 2 9 | 3 10 | 4 11 | 5 12 | 6 13 | 7 14 | 8 15 | 9 16 | 10 17 | " | acton-label \ 18 | --data tests/data/classification.txt \ 19 | --label col20 -v > labels.pb 20 | 21 | for (( epoch = 0; epoch < 10; epoch++ )); do 22 | echo Epoch $epoch 23 | acton-predict -v < labels.pb | acton-recommend -v | acton-label -v > labels_.pb 24 | mv labels_.pb labels.pb 25 | done 26 | 27 | acton-predict -v < labels.pb > predictions.pb 28 | -------------------------------------------------------------------------------- /docs/source/acton.proto.rst: -------------------------------------------------------------------------------- 1 | acton.proto package 2 | =================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | acton.proto.acton_pb2 module 8 | ---------------------------- 9 | 10 | .. automodule:: acton.proto.acton_pb2 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | acton.proto.io module 16 | --------------------- 17 | 18 | .. automodule:: acton.proto.io 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | acton.proto.wrappers module 24 | --------------------------- 25 | 26 | .. automodule:: acton.proto.wrappers 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | 32 | Module contents 33 | --------------- 34 | 35 | .. 
automodule:: acton.proto 36 | :members: 37 | :undoc-members: 38 | :show-inheritance: 39 | -------------------------------------------------------------------------------- /tests/test_kde_predictor.py: -------------------------------------------------------------------------------- 1 | """Tests for kde_predictor.""" 2 | 3 | import unittest 4 | 5 | import acton.kde_predictor 6 | import numpy 7 | import sklearn.utils.estimator_checks 8 | 9 | 10 | class TestKDEClassifier(unittest.TestCase): 11 | 12 | def test_sklearn_interface(self): 13 | """KDEClassifier implements the scikit-learn interface.""" 14 | sklearn.utils.estimator_checks.check_estimator( 15 | acton.kde_predictor.KDEClassifier) 16 | 17 | def test_softmax(self): 18 | """_softmax correctly evaluates a softmax on array input.""" 19 | for axis in range(2): 20 | for _ in range(100): 21 | array = numpy.random.random(size=(100, 100)) * 1000 - 500 22 | softmax = acton.kde_predictor.KDEClassifier._softmax( 23 | array, axis=axis) 24 | for i in softmax.sum(axis=axis): 25 | self.assertAlmostEqual(i, 1) 26 | -------------------------------------------------------------------------------- /examples/simulate_active_learning.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # This script simulates active learning using the component interface. The same 4 | # task is accomplished in bash in simulate_active_learning. 5 | 6 | from acton.acton import predict, recommend, label 7 | import acton.database 8 | from acton.proto.wrappers import Recommendations 9 | 10 | # Initial labels. 11 | recommendation_indices = list(range(10)) 12 | with acton.database.ASCIIReader( 13 | 'tests/data/classification.txt', 14 | feature_cols=[], label_col='col20') as db: 15 | recommendations = Recommendations.make( 16 | recommended_ids=recommendation_indices, 17 | labelled_ids=[], 18 | recommender='None', 19 | db=db) 20 | labels = label(recommendations) 21 | 22 | # Main loop. 
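# Each epoch chains the three component calls: predict() fits the named
# predictor to the currently labelled pool, recommend() chooses the next
# instance(s) from those predictions, and label() queries the labeller,
# growing the labelled pool for the next epoch.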
23 | for epoch in range(10): 24 | print('Epoch', epoch) 25 | labels = label( 26 | recommend(predict(labels, 'LogisticRegression'), 'RandomRecommender')) 27 | 28 | print('Labelled instances:', labels.ids) 29 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | # Config file for automatic testing at travis-ci.org 2 | 3 | language: python 4 | 5 | cache: 6 | - pip 7 | - directories: 8 | - $HOME/protobuf 9 | 10 | python: 11 | - "3.4" 12 | - "3.5" 13 | - "3.6" 14 | 15 | notifications: 16 | email: false 17 | 18 | before_install: 19 | - sudo apt-get update 20 | - sudo apt-get install libhdf5-serial-dev python-tables 21 | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; 22 | - bash miniconda.sh -b -p $HOME/miniconda 23 | - export PATH="$HOME/miniconda/bin:$PATH" 24 | - hash -r 25 | - conda config --set always_yes yes --set changeps1 no 26 | - conda update -q conda 27 | - conda info -a # for debugging 28 | 29 | install: 30 | # Install scientific Python 31 | - conda create -q -n testenv python=$TRAVIS_PYTHON_VERSION numpy llvmlite scipy pytables 32 | - source activate testenv 33 | # Install protobuf 34 | - ./install_protobuf 35 | # Install other requirements 36 | - pip install -r requirements.txt 37 | - python setup.py install 38 | 39 | script: 40 | - nosetests 41 | - flake8 acton --exclude *_pb2.py --max-line-length=80 42 | -------------------------------------------------------------------------------- /tests/test_recommenders.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | test_recommenders 5 | ---------------------------------- 6 | 7 | Tests for `recommenders` module. 8 | """ 9 | 10 | import numpy 11 | import unittest 12 | import unittest.mock 13 | 14 | import acton.recommenders 15 | 16 | 17 | class TestRandomRecommender(unittest.TestCase): 18 | 19 | def setUp(self): 20 | pass 21 | 22 | def tearDown(self): 23 | pass 24 | 25 | def test_recommend(self): 26 | """RandomRecommender recommends an instance.""" 27 | ids = set(range(1000)) 28 | rr = acton.recommenders.RandomRecommender(None) 29 | id_ = rr.recommend(ids, predictions=None) 30 | self.assertIn(id_[0], ids) 31 | 32 | 33 | class TestMarginRecommender(unittest.TestCase): 34 | 35 | def setUp(self): 36 | pass 37 | 38 | def tearDown(self): 39 | pass 40 | 41 | def test_recommend(self): 42 | """MarginRecommender recommends an instance.""" 43 | n = 10 44 | c = 3 45 | ids = list(range(n)) 46 | predictions = numpy.random.random(size=(n, 1, c)) 47 | db = unittest.mock.Mock() 48 | mr = acton.recommenders.MarginRecommender(db) 49 | id_ = mr.recommend(ids, predictions=predictions) 50 | self.assertIn(id_[0], ids) 51 | -------------------------------------------------------------------------------- /docs/design/mlai-synbio-notes.md: -------------------------------------------------------------------------------- 1 | # Alignment between MLAI and SynBio 2 | 3 | Refer to the cycles, summarised in mlai-synbio.pdf 4 | 5 | - DBTL cycle of SynBio FSP 6 | - Predict, Recommend, Label cycle in machine learning 7 | 8 | https://github.com/chengsoonong/acton/blob/master/docs/design/acton.pdf 9 | 10 | 11 | 12 | ## MLAI 13 | 14 | Two types of machine learning / artificial intelligence problems: 15 | 16 | 1. Prediction: given an example (DNA sequence) predict the label (gene expression) 17 | 2. 
Recommendation: Choose an example (DNA sequence) to label (Build and Test) 18 | 19 | 20 | ## Common vision 21 | 22 | | MLAI | SynBio | 23 | | ---- | ------ | 24 | | prediction | Learn | 25 | | recommend | Design | 26 | | - | Build | 27 | | label | Test | 28 | 29 | #### prediction = learn 30 | 31 | - standard MLAI methods (supervised learning) 32 | - research on data representation 33 | - DNA, RNA sequence 34 | - time series, spectrum 35 | - protein interaction graph 36 | 37 | #### recommender 38 | 39 | - choose where to measure (e.g. which RBS sequence to experiment) 40 | - need to define the goal of the experiment (e.g. maximise GFP gene expression) 41 | - practical constraints of Build stage can be taken into account 42 | - long term: understand causal effects 43 | 44 | #### labeller 45 | 46 | - Build + Test 47 | - how to combine different experimental results? 48 | - data and interface management 49 | - feed into learning 50 | -------------------------------------------------------------------------------- /tests/test_regression.py: -------------------------------------------------------------------------------- 1 | """Tests for regression functionality.""" 2 | 3 | import logging 4 | import os.path 5 | import tempfile 6 | import unittest 7 | 8 | import acton.database 9 | import acton.predictors 10 | import numpy 11 | 12 | 13 | class TestRegression(unittest.TestCase): 14 | """Acton supports regression.""" 15 | 16 | def setUp(self): 17 | self.tempdir = tempfile.TemporaryDirectory() 18 | self.db_path = os.path.join(self.tempdir.name, 'db.h5') 19 | 20 | def tearDown(self): 21 | self.tempdir.cleanup() 22 | 23 | def test_linear_regression(self): 24 | """LinearRegression predictor can find a linear fit.""" 25 | # Some sample data. 26 | numpy.random.seed(0) 27 | xs = numpy.linspace(0, 1, 100) 28 | ys = 2 * xs - 1 29 | noise = numpy.random.normal(size=xs.shape, scale=0.2) 30 | xs = xs.reshape((-1, 1)) 31 | ts = (ys + noise).reshape((1, -1, 1)) 32 | ids = list(range(100)) 33 | 34 | with acton.database.ManagedHDF5Database(self.db_path) as db: 35 | db.write_features(ids, xs) 36 | db.write_labels([0], ids, ts) 37 | lr = acton.predictors.PREDICTORS['LinearRegression'](db) 38 | lr.fit(ids) 39 | predictions, _variances = lr.predict(ids) 40 | logging.debug('Labels: {}'.format(ys)) 41 | logging.debug('Predictions: {}'.format(predictions)) 42 | self.assertTrue(numpy.allclose(ys, predictions.ravel(), atol=0.2)) 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2016, Cheng Soon Ong 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 
19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /tests/test_labellers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | test_labellers 5 | ---------------------------------- 6 | 7 | Tests for `labellers` module. 8 | """ 9 | 10 | import os.path 11 | import tempfile 12 | import unittest 13 | 14 | import acton.labellers 15 | import numpy 16 | 17 | 18 | class TestASCIITableLabeller(unittest.TestCase): 19 | 20 | def setUp(self): 21 | self.tempdir = tempfile.TemporaryDirectory() 22 | 23 | # Make an ASCII table. We'll just use part of the Norris 2006 catalogue 24 | # (Table 6, Norris et al. 2006). 25 | table = """\ 26 | name |ra |dec |is_agn 27 | ATCDFS J032637.29-285738.2|03 26 37.30|-28 57 38.3|1 28 | ATCDFS J032642.55-285715.6|03 26 42.56|-28 57 15.7|1 29 | ATCDFS J032629.13-285648.7|03 26 29.13|-28 56 48.7|0 30 | ATCDFS J033056.94-285637.2|03 30 56.95|-28 56 37.3|0 31 | ATCDFS J033019.98-285635.5|03 30 19.98|-28 56 35.5|0 32 | ATCDFS J033126.71-285630.3|03 31 26.72|-28 56 30.3|0 33 | """ 34 | self.path = os.path.join(self.tempdir.name, 'table.dat') 35 | with open(self.path, 'w') as f: 36 | f.write(table) 37 | 38 | def tearDown(self): 39 | self.tempdir.cleanup() 40 | 41 | def test_query(self): 42 | """ASCIITableLabeller can query from a table.""" 43 | labeller = acton.labellers.ASCIITableLabeller( 44 | self.path, 'name', 'is_agn') 45 | self.assertEqual( 46 | labeller.query(0), 47 | numpy.array([[1]])) 48 | self.assertEqual( 49 | labeller.query(4), 50 | numpy.array([[0]])) 51 | -------------------------------------------------------------------------------- /docs/source/acton.rst: -------------------------------------------------------------------------------- 1 | acton package 2 | ============= 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | 9 | acton.proto 10 | 11 | Submodules 12 | ---------- 13 | 14 | acton.acton module 15 | ------------------ 16 | 17 | .. automodule:: acton.acton 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | 22 | acton.cli module 23 | ---------------- 24 | 25 | .. automodule:: acton.cli 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | 30 | acton.database module 31 | --------------------- 32 | 33 | .. automodule:: acton.database 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | 38 | acton.kde_predictor module 39 | -------------------------- 40 | 41 | .. automodule:: acton.kde_predictor 42 | :members: 43 | :undoc-members: 44 | :show-inheritance: 45 | 46 | acton.labellers module 47 | ---------------------- 48 | 49 | .. 
automodule:: acton.labellers
50 |     :members:
51 |     :undoc-members:
52 |     :show-inheritance:
53 | 
54 | acton.plot module
55 | -----------------
56 | 
57 | .. automodule:: acton.plot
58 |     :members:
59 |     :undoc-members:
60 |     :show-inheritance:
61 | 
62 | acton.predictors module
63 | -----------------------
64 | 
65 | .. automodule:: acton.predictors
66 |     :members:
67 |     :undoc-members:
68 |     :show-inheritance:
69 | 
70 | acton.recommenders module
71 | -------------------------
72 | 
73 | .. automodule:: acton.recommenders
74 |     :members:
75 |     :undoc-members:
76 |     :show-inheritance:
77 | 
78 | 
79 | Module contents
80 | ---------------
81 | 
82 | .. automodule:: acton
83 |     :members:
84 |     :undoc-members:
85 |     :show-inheritance:
86 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | 
27 | # PyInstaller
28 | # Usually these files are written by a python script from a template
29 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
30 | *.manifest
31 | *.spec
32 | 
33 | # Installer logs
34 | pip-log.txt
35 | pip-delete-this-directory.txt
36 | 
37 | # Unit test / coverage reports
38 | htmlcov/
39 | .tox/
40 | .coverage
41 | .coverage.*
42 | .cache
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 | 
48 | # Translations
49 | *.mo
50 | *.pot
51 | 
52 | # Django stuff:
53 | *.log
54 | local_settings.py
55 | 
56 | # Flask stuff:
57 | instance/
58 | .webassets-cache
59 | 
60 | # Scrapy stuff:
61 | .scrapy
62 | 
63 | # Sphinx documentation
64 | docs/_build/
65 | 
66 | # PyBuilder
67 | target/
68 | 
69 | # IPython Notebook
70 | .ipynb_checkpoints
71 | 
72 | # pyenv
73 | .python-version
74 | 
75 | # celery beat schedule file
76 | celerybeat-schedule
77 | 
78 | # dotenv
79 | .env
80 | 
81 | # virtualenv
82 | venv/
83 | ENV/
84 | 
85 | # Spyder project settings
86 | .spyderproject
87 | 
88 | # Rope project settings
89 | .ropeproject
90 | 
91 | # PyCharm
92 | .idea
93 | 
94 | # Minutes, data
95 | minutes/
96 | data/
97 | !tests/data/
98 | 
99 | # Mac
100 | *.DS_Store
101 | 
102 | # LaTeX
103 | *.aux
104 | *.fdb_latexmk
105 | *.fls
106 | 
--------------------------------------------------------------------------------
/examples/simulate_thompson_sampling.py:
--------------------------------------------------------------------------------
1 | # To run from a source checkout, add the parent folder to sys.path:
2 | # sys.path.append("..")
3 | 
4 | from acton.database import LabelOnlyASCIIReader
5 | import acton.acton
6 | import numpy as np
7 | import logging
8 | 
9 | logging.basicConfig(level=logging.DEBUG)
10 | 
11 | _path = 'acton/tests/kg-data/nation/triples.txt'
12 | output_path = 'acton/acton/acton.proto'
13 | n_dim = 10
14 | 
15 | #data = io_ascii.read(_path)
16 | 
17 | #reader = LabelOnlyASCIIReader(_path, n_dim)
18 | # reader.__enter__()
19 | 
20 | 
21 | TS = 0.0
22 | RANDOM = 1.0
23 | N_EPOCHS = 100
24 | repeated_labelling = False
25 | 
26 | with LabelOnlyASCIIReader(_path, n_dim) as reader:
27 |     n_relations = reader.n_relations
28 |     n_entities = reader.n_entities
29 |     total_size = n_relations * n_entities * n_entities
30 |     ids = np.arange(total_size)
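    # Each ID indexes one (relation, entity, entity) candidate triple in the
    # flattened K x N x N label tensor, so the pool covers every possible
    # triple.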
31 | 
32 | # TS
33 | TS_train_error_list, TS_test_error_list, TS_gain = \
34 |     acton.acton.simulate_active_learning(ids, reader, {}, output_path,
35 |         n_epochs=N_EPOCHS,
36 |         recommender='ThompsonSamplingRecommender',
37 |         predictor='TensorPredictor',
38 |         labeller='LabelOnlyDatabaseLabeller',
39 |         diversity=TS,
40 |         repeated_labelling=repeated_labelling)
41 | # Random
42 | RD_train_error_list, RD_test_error_list, RD_gain = \
43 |     acton.acton.simulate_active_learning(ids, reader, {}, output_path,
44 |         n_epochs=N_EPOCHS,
45 |         recommender='ThompsonSamplingRecommender',
46 |         predictor='TensorPredictor',
47 |         labeller='LabelOnlyDatabaseLabeller',
48 |         diversity=RANDOM,
49 |         repeated_labelling=repeated_labelling)
50 | 
51 | acton.acton.plot(TS_train_error_list, TS_test_error_list, TS_gain,
52 |                  RD_train_error_list, RD_test_error_list, RD_gain)
53 | 
--------------------------------------------------------------------------------
/docs/design/NewActon.tex:
--------------------------------------------------------------------------------
1 | \documentclass[11pt,twoside]{article}
2 | 
3 | \usepackage[margin=2cm]{geometry}
4 | \usepackage{tikz}
5 | \usetikzlibrary{positioning,fit,shapes}
6 | 
7 | \newcommand{\name}[1]{ {\color{red}{\sffamily\bfseries{#1}}} }
8 | 
9 | \title{NewActon design}
10 | \author{Mengyan Zhang}
11 | \date{04 October 2018}
12 | 
13 | \begin{document}
14 | \maketitle
15 | 
16 | \section*{NewActon}
17 | 
18 | Recommenders, predictors, and labellers are stateful ``agents'': only they can access their own internal state, so other agents should ask them directly. Each agent stores its state in its own file, which means the whole system can be rebooted from any point with low coupling.
19 | 
20 | \begin{figure}[h]
21 |     \centering
22 |     \includegraphics[scale = 0.7]{docs/design/acton.pdf}
23 |     \caption{NewActon Design Choice}
24 |     \label{fig:my_label}
25 | \end{figure}
26 | 
27 | \subsection*{What's in the state file?}
28 | 
29 | Labeller: observed labels \newline
30 | Predictor: latest prediction \newline
31 | Recommender: recommendations
32 | 
33 | \subsection*{What should be passed?}
34 | 
35 | Labeller $\rightarrow$ Predictor: observed labels \newline
36 | Predictor $\rightarrow$ Recommender: prediction array (matrix) \newline
37 | Recommender $\rightarrow$ Labeller: recommendations \newline
38 | What is passed between agents is exactly the information recorded in the state files.
39 | 
40 | \subsection*{Graph Extension}
41 | 
42 | A graph model contains nodes and edges, where nodes represent entities (e.g. a person) and edges represent relations (e.g. friendship). We use a three-way tensor $\mathcal{X} \in \{0,1\}^{K \times N \times N}$ to represent the graph, where $K$ is the number of relations, $N$ is the number of entities, and the entry $x_{ikj}$ indicates whether the triple $(i,k,j)$ is valid.\newline
43 | To recommend as many valid triples for labelling as possible, we use a Thompson sampling model to predict the probability that a triple is valid and recommend the triple with the highest probability. \newline
44 | \newline
45 | Labeller: the same as before. \newline
46 | Recommender: Thompson sampling; recommends the triple $(i,k,j) = \mathop{\mathrm{argmax}}_{i,k,j} P(x_{ikj})$ \newline
47 | Predictor: update the posterior distribution via Bayesian inference; sample latent variables from the posterior distribution.
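
% Illustrative sketch only: the bilinear (RESCAL-style) score below is an
% assumption for exposition; the actual parameterisation is whatever
% TensorPredictor implements.
For example, with sampled entity vectors $e_i$ (of length $r$) and relation
matrices $R_k$ (of size $r \times r$) drawn from the current posterior, one
epoch scores each candidate triple as
$P(x_{ikj} = 1) = \sigma(e_i^\top R_k e_j)$, with $\sigma$ the logistic
function, and sends the maximising triple to the labeller.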
48 | 49 | 50 | \end{document} 51 | -------------------------------------------------------------------------------- /examples/multiclass_classification.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Using Acton to test uncertainty sampling on the iris libsvm dataset.""" 3 | 4 | import logging 5 | import os.path 6 | import tempfile 7 | 8 | import acton.acton 9 | import acton.plot 10 | import h5py 11 | import requests 12 | import sklearn.datasets 13 | import sklearn.preprocessing 14 | 15 | with tempfile.TemporaryDirectory() as tempdir: 16 | # Download the dataset. 17 | # We'll store the dataset in this file: 18 | raw_filename = os.path.join(tempdir, 'iris.dat') 19 | dataset_response = requests.get( 20 | 'https://www.csie.ntu.edu.tw/' 21 | '~cjlin/libsvmtools/datasets/multiclass/iris.scale') 22 | with open(raw_filename, 'w') as raw_file: 23 | raw_file.write(dataset_response.text) 24 | # Convert the dataset into a format we can use. It's currently libsvm. 25 | X, y = sklearn.datasets.load_svmlight_file(raw_filename) 26 | # Encode labels. 27 | y = sklearn.preprocessing.LabelEncoder().fit_transform(y) 28 | # We'll just save it directly into an HDF5 file: 29 | input_filename = os.path.join(tempdir, 'iris.h5') 30 | with h5py.File(input_filename, 'w') as input_file: 31 | input_file.create_dataset('features', data=X.toarray()) 32 | input_file.create_dataset('labels', data=y) 33 | 34 | # We'll save output to this file: 35 | output_base_filename = os.path.join(tempdir, 'iris_base.out') 36 | output_unct_filename = os.path.join(tempdir, 'iris_unct.out') 37 | 38 | # Run Acton. 39 | logging.root.setLevel(logging.DEBUG) 40 | acton.acton.main( 41 | data_path=input_filename, 42 | feature_cols=['features'], 43 | label_col='labels', 44 | output_path=output_base_filename, 45 | n_epochs=200, 46 | initial_count=10, 47 | recommender='RandomRecommender', 48 | predictor='LogisticRegression') 49 | acton.acton.main( 50 | data_path=input_filename, 51 | feature_cols=['features'], 52 | label_col='labels', 53 | output_path=output_unct_filename, 54 | n_epochs=200, 55 | initial_count=10, 56 | recommender='UncertaintyRecommender', 57 | predictor='LogisticRegression') 58 | 59 | # Plot the results. 
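# acton.plot.plot draws one accuracy-versus-number-of-labels curve per
# predictions file, so the passive (random) and uncertainty-sampling runs
# can be compared on the same axes.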
60 | with open(output_base_filename, 'rb') as predictions_base, \ 61 | open(output_unct_filename, 'rb') as predictions_unct: 62 | acton.plot.plot([predictions_base, predictions_unct]) 63 | -------------------------------------------------------------------------------- /tests/data/classification_str.txt: -------------------------------------------------------------------------------- 1 | col0 col1 col2 label 2 | 3 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 4 | 5 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 6 | 7 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 8 | 9 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 10 | 11 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 12 | 13 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 14 | 15 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 16 | 17 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 18 | 19 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 20 | 21 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 22 | 23 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 24 | 25 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 26 | 27 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 28 | 29 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 30 | 31 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 32 | 33 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 34 | 35 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 36 | 37 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 38 | 39 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 40 | 41 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 42 | 43 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 44 | 45 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 46 | 47 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 48 | 49 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 50 | 51 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 52 | 53 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 54 | 55 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 56 | 57 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 58 | 59 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 60 | 61 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 62 | 63 | -0.08169237160307448 0.11282227712014925 1.7861692242526752 abc 64 | 65 | -0.8442863879085201 1.3662242178205282 -0.8565624884616692 def 66 | 67 | -0.5021565359867708 1.3747365419311302 1.4364060824139213 abc 68 | -------------------------------------------------------------------------------- /docs/source/dev.rst: -------------------------------------------------------------------------------- 1 | Developer Documentation 2 | ======================= 3 | 4 | Contributing 5 | ------------ 6 | 7 | We accept pull requests on GitHub. Contributions must be PEP8 compliant and pass 8 | formatting and function tests in the test script ``/test``. 9 | 10 | Adding a New Predictor 11 | ---------------------- 12 | 13 | A predictor is a class that implements ``acton.predictors.Predictor``. Adding a 14 | new predictor amounts to implementing a subclass of ``Predictor`` and 15 | registering it in ``acton.predictors.PREDICTORS``. 
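For concreteness, a minimal hypothetical predictor might look like the sketch below; the class name and internals are invented for illustration, and only the interface (summarised just after the sketch) is prescribed:

.. code:: python

    import collections
    from typing import Iterable, Sequence

    import acton.database
    import acton.predictors
    import numpy


    class MajorityPredictor(acton.predictors.Predictor):
        """Always predicts the most common label seen during fitting."""

        def __init__(self, db: acton.database.Database, *args, **kwargs):
            # Store a reference to the database; features and labels are
            # fetched from it by ID.
            self._db = db

        def fit(self, ids: Iterable[int]):
            labels = self._db.read_labels([0], list(ids)).ravel()
            # Data-derived state ends in an underscore (see the note below).
            self.majority_ = collections.Counter(labels).most_common(1)[0][0]

        def predict(self, ids: Sequence[int]) -> numpy.ndarray:
            return numpy.full((len(ids), 1), self.majority_)

        def reference_predict(self, ids: Sequence[int]) -> numpy.ndarray:
            # A real predictor would use its best possible model here.
            return self.predict(ids)


    # Register the class so it is available by name, e.g. from the CLI.
    acton.predictors.PREDICTORS['MajorityPredictor'] = MajorityPredictor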
16 | 
17 | Predictors must implement:
18 | 
19 | - ``__init__(db: acton.database.Database, *args, **kwargs)``, which stores a reference to the database (and does any other initialisation).
20 | - ``fit(ids: Iterable[int])``, which takes an iterable of IDs and fits a model
21 |   to the associated features and labels.
22 | - ``predict(ids: Sequence[int]) -> numpy.ndarray``, which takes a sequence of
23 |   IDs and predicts the associated labels.
24 | - ``reference_predict(ids: Sequence[int]) -> numpy.ndarray``, which behaves the same as ``predict`` but uses the best possible model.
25 | 
26 | Predictors should store data-based values such as the model in attributes ending in an underscore, e.g. ``self.model_``.
27 | 
28 | Why Does Acton Use Predictor?
29 | #############################
30 | 
31 | Acton makes use of ``Predictor`` classes, which are often just wrappers for
32 | scikit-learn classes. This raises the question: Why not just use scikit-learn
33 | classes?
34 | 
35 | This design decision was made because Acton must support predictors that do not
36 | fit the scikit-learn API, and so using scikit-learn predictors directly would
37 | mean that there is no unified API for predictors. An example of where Acton
38 | diverges from scikit-learn is that scikit-learn does not support multiple
39 | labellers.
40 | 
41 | Adding a New Recommender
42 | ------------------------
43 | 
44 | A recommender is a class that implements ``acton.recommenders.Recommender``. Adding a new recommender amounts to implementing a subclass of ``Recommender`` and registering it in ``acton.recommenders.RECOMMENDERS``.
45 | 
46 | Recommenders must implement:
47 | 
48 | - ``__init__(db: acton.database.Database, *args, **kwargs)``, which stores a reference to the database (and does any other initialisation).
49 | - ``recommend(ids: Iterable[int], predictions: numpy.ndarray, n: int=1, diversity: float=0.5) -> Sequence[int]``, which recommends ``n`` IDs from the given IDs based on the associated predictions.
50 | 
--------------------------------------------------------------------------------
/examples/classification.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Using Acton to test uncertainty sampling on the Australian libsvm dataset."""
3 | 
4 | import logging
5 | import os.path
6 | import tempfile
7 | 
8 | import acton.acton
9 | import acton.plot
10 | import h5py
11 | import requests
12 | import sklearn.datasets
13 | import sklearn.preprocessing
14 | 
15 | with tempfile.TemporaryDirectory() as tempdir:
16 |     # Download the dataset.
17 |     # We'll store the dataset in this file:
18 |     raw_filename = os.path.join(tempdir, 'australian.dat')
19 |     dataset_response = requests.get(
20 |         'https://www.csie.ntu.edu.tw/'
21 |         '%7Ecjlin/libsvmtools/datasets/binary/australian')
22 |     with open(raw_filename, 'w') as raw_file:
23 |         raw_file.write(dataset_response.text)
24 |     # Convert the dataset into a format we can use. It's currently libsvm.
25 |     X, y = sklearn.datasets.load_svmlight_file(raw_filename)
26 |     # We also have -1/1 as labels, but we need 0/1, so we'll convert.
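    # LabelEncoder encodes classes in sorted order, so -1 -> 0 and 1 -> 1.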
27 | y = sklearn.preprocessing.LabelEncoder().fit_transform(y) 28 | # We'll just save it directly into an HDF5 file: 29 | input_filename = os.path.join(tempdir, 'australian.h5') 30 | with h5py.File(input_filename, 'w') as input_file: 31 | input_file.create_dataset('features', data=X.toarray()) 32 | input_file.create_dataset('labels', data=y) 33 | 34 | # We'll save output to this file: 35 | output_base_filename = os.path.join(tempdir, 'australian_base.out') 36 | output_unct_filename = os.path.join(tempdir, 'australian_unct.out') 37 | 38 | # Run Acton. 39 | logging.root.setLevel(logging.DEBUG) 40 | acton.acton.main( 41 | data_path=input_filename, 42 | feature_cols=['features'], 43 | label_col='labels', 44 | output_path=output_base_filename, 45 | n_epochs=500, 46 | initial_count=100, 47 | recommender='RandomRecommender', 48 | predictor='LogisticRegression') 49 | acton.acton.main( 50 | data_path=input_filename, 51 | feature_cols=['features'], 52 | label_col='labels', 53 | output_path=output_unct_filename, 54 | n_epochs=500, 55 | initial_count=100, 56 | recommender='UncertaintyRecommender', 57 | predictor='LogisticRegression') 58 | 59 | # Plot the results. 60 | with open(output_base_filename, 'rb') as predictions_base, \ 61 | open(output_unct_filename, 'rb') as predictions_unct: 62 | acton.plot.plot([predictions_base, predictions_unct]) 63 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | define BROWSER_PYSCRIPT 4 | import os, webbrowser, sys 5 | try: 6 | from urllib import pathname2url 7 | except: 8 | from urllib.request import pathname2url 9 | 10 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 11 | endef 12 | export BROWSER_PYSCRIPT 13 | 14 | define PRINT_HELP_PYSCRIPT 15 | import re, sys 16 | 17 | for line in sys.stdin: 18 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 19 | if match: 20 | target, help = match.groups() 21 | print("%-20s %s" % (target, help)) 22 | endef 23 | export PRINT_HELP_PYSCRIPT 24 | BROWSER := python3 -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | 32 | clean-build: ## remove build artifacts 33 | rm -fr build/ 34 | rm -fr dist/ 35 | rm -fr .eggs/ 36 | find . -name '*.egg-info' -exec rm -fr {} + 37 | find . -name '*.egg' -exec rm -f {} + 38 | 39 | clean-pyc: ## remove Python file artifacts 40 | find . -name '*.pyc' -exec rm -f {} + 41 | find . -name '*.pyo' -exec rm -f {} + 42 | find . -name '*~' -exec rm -f {} + 43 | find . 
-name '__pycache__' -exec rm -fr {} + 44 | 45 | clean-test: ## remove test and coverage artifacts 46 | rm -f .coverage 47 | rm -fr htmlcov/ 48 | 49 | lint: ## check style with flake8 50 | flake8 acton tests 51 | 52 | test: ## run tests quickly with the default Python 53 | 54 | python3 setup.py test 55 | 56 | coverage: ## check code coverage quickly with the default Python 57 | 58 | coverage run --source acton setup.py test 59 | 60 | coverage report -m 61 | coverage html 62 | $(BROWSER) htmlcov/index.html 63 | 64 | docs: ## generate Sphinx HTML documentation, including API docs 65 | rm -f docs/acton.rst 66 | rm -f docs/modules.rst 67 | sphinx-apidoc -o docs/ acton 68 | $(MAKE) -C docs clean 69 | $(MAKE) -C docs html 70 | $(BROWSER) docs/_build/html/index.html 71 | 72 | servedocs: docs ## compile the docs watching for changes 73 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D . 74 | 75 | release: clean ## package and upload a release 76 | python3 setup.py sdist upload 77 | python3 setup.py bdist_wheel upload 78 | 79 | dist: clean ## builds source and wheel package 80 | python3 setup.py sdist 81 | python3 setup.py bdist_wheel 82 | ls -l dist 83 | 84 | install: clean ## install the package to the active Python's site-packages 85 | python3 setup.py install 86 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from setuptools import setup 4 | 5 | from acton import __version__ 6 | 7 | with open('README.rst') as readme_file: 8 | readme = readme_file.read() 9 | 10 | # with open('HISTORY.rst') as history_file: 11 | # history = history_file.read() 12 | 13 | requirements = [ 14 | 'h5py>=2.6.0', 15 | 'protobuf>=3.1.0', 16 | 'numpy>=1.11.0', 17 | 'scipy>=0.17.0', 18 | 'scikit-learn>=0.18.1', 19 | 'typing>=3.5.2', 20 | 'astropy>=1.1.2', 21 | 'pip>=8.1.2', 22 | 'bumpversion==0.5.3', 23 | 'wheel==0.29.0', 24 | 'watchdog==0.8.3', 25 | 'flake8==3.2.0', 26 | 'coverage==4.1', 27 | 'Sphinx==1.4.8', 28 | 'numpydoc==0.6.0', 29 | 'pyflakes<1.4.0,>=1.3.0', 30 | 'pandas>=0.15.2', 31 | 'nose==1.3.7', 32 | 'click==6.6', 33 | 'tables>=3.3.0', 34 | 'sphinx_rtd_theme==0.1.9', 35 | 'mock==2.0.0', 36 | 'GPy', 37 | 'matplotlib==2.0.0', 38 | # 'paramz==0.7.4', 39 | ] 40 | 41 | test_requirements = [ 42 | 'flake8==3.2.0', 43 | 'pyflakes>=1.3.0', 44 | 'nose==1.3.7', 45 | ] 46 | 47 | setup( 48 | name='acton', 49 | version=__version__, 50 | description="A scientific research assistant", 51 | long_description=readme, # + '\n\n' + history, 52 | url='https://github.com/chengsoonong/acton', 53 | # Setup scripts don't support multiple authors, so this should be the main 54 | # author or the author that should be contacted regarding the module. 
55 | author='Cheng Soon Ong', 56 | author_email='chengsoon.ong@anu.edu.au', 57 | packages=[ 58 | 'acton', 59 | 'acton.proto', 60 | ], 61 | # package_dir={'acton': 62 | # 'acton'}, 63 | entry_points={ 64 | 'console_scripts': [ 65 | 'acton=acton.cli:main', 66 | 'acton-predict=acton.cli:predict', 67 | 'acton-recommend=acton.cli:recommend', 68 | 'acton-label=acton.cli:label', 69 | ] 70 | }, 71 | include_package_data=True, 72 | install_requires=requirements, 73 | license="BSD license", 74 | zip_safe=False, 75 | keywords='machine-learning active-learning classification regression', 76 | classifiers=[ 77 | 'Development Status :: 2 - Pre-Alpha', 78 | 'Intended Audience :: Science/Research', 79 | 'Topic :: Scientific/Engineering', 80 | 'License :: OSI Approved :: BSD License', 81 | 'Natural Language :: English', 82 | 'Programming Language :: Python :: 3', 83 | 'Programming Language :: Python :: 3.4', 84 | 'Programming Language :: Python :: 3.5', 85 | ], 86 | test_suite='tests', 87 | tests_require=test_requirements 88 | ) 89 | -------------------------------------------------------------------------------- /acton/plot.py: -------------------------------------------------------------------------------- 1 | """Script to plot a dump of predictions.""" 2 | 3 | import itertools 4 | import sys 5 | from typing import Iterable 6 | from typing.io import BinaryIO 7 | 8 | import acton.proto.io 9 | from acton.proto.acton_pb2 import Predictions 10 | import acton.proto.wrappers 11 | import click 12 | import matplotlib.pyplot as plt 13 | import sklearn.metrics 14 | 15 | 16 | def plot(predictions: Iterable[BinaryIO]): 17 | """Plots predictions from a file. 18 | 19 | Parameters 20 | ---------- 21 | predictions 22 | Files containing predictions. 23 | """ 24 | if len(predictions) < 1: 25 | raise ValueError('Must have at least 1 set of predictions.') 26 | 27 | metadata = [] 28 | predictions, predictions_ = itertools.tee(predictions) 29 | for proto_file in predictions_: 30 | metadata.append(acton.proto.io.read_metadata(proto_file)) 31 | proto_file.seek(0) 32 | 33 | for meta, proto_file in zip(metadata, predictions): 34 | # Read in the first protobuf to get the database file. 35 | protobuf = next(acton.proto.io.read_protos(proto_file, Predictions)) 36 | protobuf = acton.proto.wrappers.Predictions(protobuf) 37 | with protobuf.DB() as db: 38 | accuracies = [] 39 | for protobuf in acton.proto.io.read_protos( 40 | proto_file, Predictions): 41 | protobuf = acton.proto.wrappers.Predictions(protobuf) 42 | ids = protobuf.predicted_ids 43 | predictions_ = protobuf.predictions 44 | assert predictions_.shape[0] == 1 45 | predictions_ = predictions_[0] 46 | labels = db.read_labels([0], ids).ravel() 47 | predicted_labels = predictions_.argmax(axis=1).ravel() 48 | predicted_labels = [str(p).encode('ascii') # quick and 49 | for p in predicted_labels] # dirty hack 50 | print(labels, predicted_labels) 51 | accuracies.append(sklearn.metrics.accuracy_score( 52 | labels, predicted_labels)) 53 | 54 | plt.plot(accuracies, label=meta.decode('ascii', errors='replace')) 55 | 56 | plt.xlabel('Number of additional labels') 57 | plt.ylabel('Accuracy score') 58 | plt.legend() 59 | plt.show() 60 | 61 | 62 | @click.command() 63 | @click.argument('predictions', 64 | type=click.File('rb'), 65 | nargs=-1, 66 | required=True) 67 | def _plot(predictions: Iterable[BinaryIO]): 68 | """Plots predictions from a file. 69 | 70 | Parameters 71 | ---------- 72 | predictions 73 | Files containing predictions. 
74 | """ 75 | return plot(predictions) 76 | 77 | 78 | if __name__ == '__main__': 79 | sys.exit(_plot()) 80 | -------------------------------------------------------------------------------- /acton/proto/acton.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package acton; 4 | 5 | /** 6 | * Key/value pair. 7 | */ 8 | message KeyVal { 9 | string key = 1; 10 | string value = 2; 11 | } 12 | 13 | /** 14 | * A database storing features and labels. 15 | * 16 | * This message should be enough to open a new connection to the database. 17 | */ 18 | message Database { 19 | /** 20 | * A scikit-learn LabelEncoder. 21 | */ 22 | message LabelEncoder { 23 | // Maps label -> integer. 24 | message Encoding { 25 | string class_label = 1; 26 | int32 class_int = 2; 27 | } 28 | repeated Encoding encoding = 1; 29 | } 30 | 31 | // Path to database (usually a file). 32 | string path = 1; 33 | 34 | // Class of Python database wrapper. 35 | string class_name = 2; 36 | 37 | // Keyword arguments to pass to the wrapper database constructor. 38 | repeated KeyVal kwarg = 3; 39 | 40 | // Encodes labels as integers. 41 | LabelEncoder label_encoder = 4; 42 | } 43 | 44 | /** 45 | * A collection of labelled data points. 46 | */ 47 | message LabelPool { 48 | 49 | // IDs of labelled data points. 50 | repeated int64 id = 1; 51 | 52 | // Database that labels are stored in. 53 | Database db = 2; 54 | } 55 | 56 | /** 57 | * Predicted labels of data points. 58 | */ 59 | message Predictions { 60 | /** 61 | * Predictions for a single instance. 62 | */ 63 | message Prediction { 64 | // The ID of the instance that we are predicting. 65 | int64 id = 1; 66 | 67 | // Predictions are a T x D array, where 68 | // - T is the number of predictors, and 69 | // - D is the dimensionality of the prediction. 70 | repeated double prediction = 2; 71 | } 72 | 73 | // Predictions for instances. 74 | repeated Prediction prediction = 1; 75 | 76 | // IDs of instances used to train the predictor. 77 | repeated int64 labelled_id = 2; 78 | 79 | // By having the data type and shape of the predictions outside the 80 | // Prediction itself, we force all predictions to be the same shape, but 81 | // also save space. 82 | 83 | // Number of predictors. 84 | int32 n_predictors = 3; 85 | 86 | // Dimensionality of the predictions. 87 | int32 n_prediction_dimensions = 4; 88 | 89 | // Predictor used to generate these predictions. 90 | string predictor = 5; 91 | 92 | // Database that instances are stored in. 93 | Database db = 6; 94 | } 95 | 96 | 97 | /** 98 | * Recommended instances to label. 99 | */ 100 | message Recommendations { 101 | // IDs of recommendations. 102 | repeated int64 recommended_id = 1; 103 | 104 | // IDs of instances used to train the predictor used to recommend. 105 | repeated int64 labelled_id = 2; 106 | 107 | // Recommender used to generate these recommendations. 108 | string recommender = 3; 109 | 110 | // Database that instances are stored in. 111 | Database db = 4; 112 | } 113 | -------------------------------------------------------------------------------- /tests/test_proto_io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | test_proto_io 5 | ---------------------------------- 6 | 7 | Tests for `proto.io` module. 
8 | """ 9 | 10 | import os.path 11 | import tempfile 12 | from typing import List 13 | import unittest 14 | import unittest.mock 15 | 16 | import acton.proto.io 17 | import numpy 18 | 19 | 20 | class TestIOMany(unittest.TestCase): 21 | """Tests read_protos and write_protos.""" 22 | 23 | def setUp(self): 24 | # Make some (mock) protobufs. 25 | self.proto = unittest.mock.Mock(['SerializeToString']) 26 | self.tempdir = tempfile.TemporaryDirectory() 27 | self.path = os.path.join(self.tempdir.name, 'testiomany.proto') 28 | # And a mock class for deserialisation. 29 | 30 | class Proto: 31 | def ParseFromString(self, string): 32 | self.proto = string 33 | self.Proto = Proto 34 | # The idea here is to store the serialised protobuf, so we can check 35 | # that it has been read in correctly. 36 | 37 | def tearDown(self): 38 | self.tempdir.cleanup() 39 | 40 | @staticmethod 41 | def make_protobufs(n: int=10) -> List[bytes]: 42 | """Makes random-length "protobufs". 43 | 44 | Parameters 45 | ---------- 46 | n 47 | Number of protobufs. 48 | 49 | Yields 50 | ------ 51 | bytes 52 | A random protobuf. 53 | """ 54 | for _ in range(n): 55 | protobuf = bytes( 56 | numpy.random.randint(256) 57 | for _ in range(numpy.random.randint(100))) 58 | yield protobuf 59 | 60 | def test_write_read(self): 61 | """Protobufs written by write_protos can be read by read_protos.""" 62 | # We will assume that the write/read functions are using the 63 | # serialisation methods built-in to protobufs. We thus deal only with 64 | # serialised protobufs. 65 | serialised_protos = list(self.make_protobufs()) 66 | 67 | # Write the protobufs. 68 | writer = acton.proto.io.write_protos(self.path) 69 | next(writer) 70 | for protobuf in serialised_protos: 71 | self.proto.SerializeToString.return_value = protobuf 72 | writer.send(self.proto) 73 | writer.send(None) 74 | 75 | # Read the protobufs. 76 | read_protobufs = [i.proto 77 | for i in acton.proto.io.read_protos(self.path, self.Proto)] 78 | self.assertEqual(serialised_protos, read_protobufs) 79 | 80 | def test_write_read_file(self): 81 | """read_protos accepts opened binary files.""" 82 | # This function is identical to test_write_read but with files instead 83 | # of paths. 84 | 85 | serialised_protos = list(self.make_protobufs()) 86 | 87 | # Write the protobufs. 88 | writer = acton.proto.io.write_protos(self.path) 89 | next(writer) 90 | for protobuf in serialised_protos: 91 | self.proto.SerializeToString.return_value = protobuf 92 | writer.send(self.proto) 93 | writer.send(None) 94 | 95 | # Read the protobufs using a file object. 
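        # read_protos accepts an already-open binary file as well as a path,
        # so this exercises the file-object code path.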
96 | with open(self.path, 'rb') as proto_file: 97 | read_protobufs = [i.proto for i in acton.proto.io.read_protos( 98 | proto_file, self.Proto)] 99 | self.assertEqual(serialised_protos, read_protobufs) 100 | -------------------------------------------------------------------------------- /docs/protobuf_spec.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Protobuf Serialisation\n", 8 | "\n", 9 | "This notebook documents how Acton serialises protobufs.\n", 10 | "\n", 11 | "Protobufs can be serialised and deserialised individually using the built-in methods `SerializeToString` and `ParseFromString`:" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "collapsed": true 19 | }, 20 | "outputs": [], 21 | "source": [ 22 | "# Serialising.\n", 23 | "with open(path, 'wb') as proto_file:\n", 24 | " proto_file.write(proto.SerializeToString())" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": { 31 | "collapsed": true 32 | }, 33 | "outputs": [], 34 | "source": [ 35 | "# Deserialising. (from acton.proto.io)\n", 36 | "proto = Proto()\n", 37 | "with open(path, 'rb') as proto_file:\n", 38 | " proto.ParseFromString(proto_file.read())" 39 | ] 40 | }, 41 | { 42 | "cell_type": "markdown", 43 | "metadata": {}, 44 | "source": [ 45 | "To serialise multiple protobufs into one file, we serialise each to a string, write the length of this string to a file, then write the string to the file. The length is needed because protobufs are not self-delimiting. We use an unsigned long long with the `struct` library to store the length." 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": { 52 | "collapsed": true 53 | }, 54 | "outputs": [], 55 | "source": [ 56 | "for proto in protos:\n", 57 | " proto = proto.SerializeToString()\n", 58 | " length = struct.pack('`_ 6 | is a suburb in Canberra, where Australian National University is 7 | located. 8 | 9 | |PyPI| |Build Status| |Documentation Status| 10 | 11 | .. |PyPI| image:: https://img.shields.io/pypi/v/acton.svg 12 | :target: https://pypi.python.org/pypi/acton 13 | .. |Build Status| image:: https://travis-ci.org/chengsoonong/acton.svg?branch=master 14 | :target: https://travis-ci.org/chengsoonong/acton 15 | .. |Documentation Status| image:: http://readthedocs.org/projects/acton/badge/?version=latest 16 | :target: http://acton.readthedocs.io/en/latest/?badge=latest 17 | 18 | Dependencies 19 | ------------ 20 | 21 | Most dependencies will be installed by pip. You will need to manually install: 22 | 23 | - Python 3.4+ 24 | 25 | Setup 26 | ----- 27 | 28 | Install Acton using ``pip3``: 29 | 30 | .. code:: bash 31 | 32 | pip install git+https://github.com/chengsoonong/acton.git 33 | 34 | This provides access to a command-line tool ``acton`` as well as the 35 | ``acton`` Python library. 36 | 37 | Acton CLI 38 | --------- 39 | 40 | The command-line interface to Acton is available through the ``acton`` 41 | command. This takes a dataset of features and labels and simulates an 42 | active learning experiment on that dataset. 43 | 44 | Input 45 | +++++ 46 | 47 | Acton supports three formats of dataset: ASCII, pandas, and HDF5. ASCII 48 | tables can be any file read by ``astropy.io.ascii.read``, including many common 49 | plain-text table formats like CSV. 
pandas tables are supported if dumped to a 50 | file from ``DataFrame.to_hdf``. HDF5 tables are either an HDF5 file with datasets 51 | for each feature and a dataset for labels, or an HDF5 file with one 52 | multidimensional dataset for features and one dataset for labels. 53 | 54 | Output 55 | ++++++ 56 | 57 | Acton outputs a file containing predictions for each epoch of the simulation. 58 | These are encoded as specified in `this notebook 59 | `_. 60 | 61 | Quickstart 62 | ---------- 63 | 64 | You will need a dataset. Acton currently supports ASCII tables (anything that can be read by :code:`astropy.io.ascii.read`), HDF5 tables, and Pandas tables saved as HDF5. `Here's a simple classification dataset `_ that you can use. 65 | 66 | To run Acton to generate a passive learning curve with logistic regression: 67 | 68 | .. code:: bash 69 | 70 | acton --data classification.txt --label col20 --feature col10 --feature col11 -o passive.pb --recommender RandomRecommender --predictor LogisticRegression 71 | 72 | This command uses columns ``col10`` and ``col11`` as features, and ``col20`` as labels, a logistic regression predictor, and random recommendations. It outputs all predictions for test data points selected randomly from the input data to :code:`passive.pb`, which can then be used to construct a plot. To output an active learning curve using uncertainty sampling, change :code:`RandomRecommender` to :code:`UncertaintyRecommender`. 73 | 74 | To show the learning curve, use `acton.plot`: 75 | 76 | .. code:: bash 77 | 78 | python3 -m acton.plot passive.pb 79 | 80 | Look at the directory ``examples`` for more examples. 81 | 82 | 83 | Acknowledgements 84 | ---------------- 85 | 86 | Matthew Alger was funded in late 2016 by `CAASTRO `_. 87 | -------------------------------------------------------------------------------- /acton/kde_predictor.py: -------------------------------------------------------------------------------- 1 | """A predictor that uses KDE to classify instances.""" 2 | 3 | import numpy 4 | import sklearn.base 5 | import sklearn.neighbors 6 | import sklearn.utils.multiclass 7 | import sklearn.utils.validation 8 | 9 | 10 | class KDEClassifier(sklearn.base.BaseEstimator, sklearn.base.ClassifierMixin): 11 | """A classifier using kernel density estimation to classify instances.""" 12 | 13 | def __init__(self, bandwidth=1.0): 14 | """A classifier using kernel density estimation to classify instances. 15 | 16 | A kernel density estimate is fit to each class. These estimates are used 17 | to score instances and the highest score class is used as the label for 18 | each instance. 19 | 20 | bandwidth : float 21 | Bandwidth for the kernel density estimate. 22 | """ 23 | self.bandwidth = bandwidth 24 | 25 | def fit(self, X, y): 26 | """Fits kernel density models to the data. 27 | 28 | Parameters 29 | ---------- 30 | X : array_like, shape (n_samples, n_features) 31 | List of n_features-dimensional data points. Each row 32 | corresponds to a single data point. 33 | y : array-like, shape (n_samples,) 34 | Target vector relative to X. 35 | """ 36 | X, y = sklearn.utils.validation.check_X_y(X, y) 37 | 38 | self.classes_ = sklearn.utils.multiclass.unique_labels(y) 39 | 40 | self.kdes_ = [ 41 | sklearn.neighbors.KernelDensity(self.bandwidth).fit(X[y == label]) 42 | for label in self.classes_] 43 | 44 | return self 45 | 46 | def predict(self, X): 47 | """Predicts class labels. 
48 | 49 | Parameters 50 | ---------- 51 | X : array_like, shape (n_samples, n_features) 52 | List of n_features-dimensional data points. Each row 53 | corresponds to a single data point. 54 | """ 55 | sklearn.utils.validation.check_is_fitted(self, ['kdes_', 'classes_']) 56 | X = sklearn.utils.validation.check_array(X) 57 | 58 | scores = self.predict_proba(X) 59 | 60 | most_probable_indices = scores.argmax(axis=1) 61 | assert most_probable_indices.shape[0] == X.shape[0] 62 | 63 | return numpy.array([self.classes_[i] for i in most_probable_indices]) 64 | 65 | @staticmethod 66 | def _softmax(data, axis=0): 67 | """Computes the softmax of an array along an axis. 68 | 69 | Notes 70 | ----- 71 | Adapted from https://gist.github.com/stober/1946926. 72 | 73 | Parameters 74 | ---------- 75 | data : array_like 76 | Array of numbers. 77 | axis : int 78 | Axis to softmax along. 79 | """ 80 | e_x = numpy.exp( 81 | data - numpy.expand_dims(numpy.max(data, axis=axis), axis)) 82 | out = e_x / numpy.expand_dims(e_x.sum(axis=axis), axis) 83 | return out 84 | 85 | def predict_proba(self, X): 86 | """Predicts class probabilities. 87 | 88 | Class probabilities are normalised log densities of the kernel density 89 | estimates. 90 | 91 | Parameters 92 | ---------- 93 | X : array_like, shape (n_samples, n_features) 94 | List of n_features-dimensional data points. Each row 95 | corresponds to a single data point. 96 | """ 97 | sklearn.utils.validation.check_is_fitted(self, ['kdes_', 'classes_']) 98 | X = sklearn.utils.validation.check_array(X) 99 | 100 | scores = numpy.zeros((X.shape[0], len(self.classes_))) 101 | for label, kde in enumerate(self.kdes_): 102 | scores[:, label] = kde.score_samples(X) 103 | 104 | scores = self._softmax(scores, axis=1) 105 | 106 | assert scores.shape == (X.shape[0], len(self.classes_)) 107 | assert numpy.allclose(scores.sum(axis=1), numpy.ones((X.shape[0],))) 108 | 109 | return scores 110 | -------------------------------------------------------------------------------- /acton/labellers.py: -------------------------------------------------------------------------------- 1 | """Labeller classes.""" 2 | 3 | from abc import ABC, abstractmethod 4 | 5 | import acton.database 6 | import astropy.io.ascii 7 | import numpy 8 | 9 | 10 | class Labeller(ABC): 11 | """Base class for labellers. 12 | 13 | Attributes 14 | ---------- 15 | """ 16 | 17 | @abstractmethod 18 | def query(self, id_: int) -> numpy.ndarray: 19 | """Queries the labeller. 20 | 21 | Parameters 22 | ---------- 23 | id_ 24 | ID of instance to label. 25 | 26 | Returns 27 | ------- 28 | numpy.ndarray 29 | T x F label array. 30 | """ 31 | 32 | 33 | class ASCIITableLabeller(Labeller): 34 | """Labeller that obtains labels from an ASCII table. 35 | 36 | Attributes 37 | ---------- 38 | path : str 39 | Path to table. 40 | id_col : str 41 | Name of the column where IDs are stored. 42 | label_col : str 43 | Name of the column where binary labels are stored. 44 | _table : astropy.table.Table 45 | Table object. 46 | """ 47 | 48 | def __init__(self, path: str, id_col: str, label_col: str): 49 | """ 50 | path 51 | Path to table. 52 | id_col 53 | Name of the column where IDs are stored. 54 | label_col 55 | Name of the column where binary labels are stored. 
56 | """ 57 | self.path = path 58 | self.id_col = id_col 59 | self.label_col = label_col 60 | self._table = astropy.io.ascii.read(self.path) 61 | self._id_to_name = {} 62 | for id_, row in enumerate(self._table): 63 | name = row[self.id_col] 64 | self._id_to_name[id_] = name 65 | 66 | def query(self, id_: int) -> numpy.ndarray: 67 | """Queries the labeller. 68 | 69 | Parameters 70 | ---------- 71 | id_ 72 | ID of instance to label. 73 | 74 | Returns 75 | ------- 76 | numpy.ndarray 77 | 1 x 1 label array. 78 | """ 79 | for row in self._table: 80 | if row[self.id_col] == self._id_to_name[id_]: 81 | return row[self.label_col].reshape((1, 1)) 82 | raise KeyError('Unknown id: {}'.format(id_)) 83 | 84 | 85 | class DatabaseLabeller(Labeller): 86 | """Labeller that obtains labels from a Database. 87 | 88 | Attributes 89 | ---------- 90 | _db : acton.database.Database 91 | Database with labels. 92 | """ 93 | 94 | def __init__(self, db: acton.database.Database): 95 | """ 96 | db 97 | Database with labels to read from. 98 | """ 99 | self._db = db 100 | 101 | def query(self, id_: int) -> numpy.ndarray: 102 | """Queries the labeller. 103 | 104 | Parameters 105 | ---------- 106 | id_ 107 | ID of instance to label. 108 | 109 | Returns 110 | ------- 111 | numpy.ndarray 112 | 1 x 1 label array. 113 | """ 114 | return self._db.read_labels([0], [id_]).reshape((1, 1)) 115 | 116 | 117 | class GraphDatabaseLabeller(Labeller): 118 | """Labeller that obtains labels from a Database. 119 | 120 | Attributes 121 | ---------- 122 | _db : acton.database.Database 123 | Database with labels. 124 | """ 125 | 126 | def __init__(self, db: acton.database.Database): 127 | """ 128 | db 129 | Database with labels to read from. 130 | """ 131 | self._db = db 132 | 133 | def query(self, id_: tuple) -> numpy.ndarray: 134 | """Queries the labeller. 135 | 136 | Parameters 137 | ---------- 138 | id_ 139 | ID of instance to label. 140 | 141 | Returns 142 | ------- 143 | numpy.ndarray 144 | 1 x 1 label array. 145 | """ 146 | return self._db.read_labels([id_]).reshape((1, 1)) 147 | 148 | 149 | # For safe string-based access to labeller classes. 150 | LABELLERS = { 151 | 'ASCIITableLabeller': ASCIITableLabeller, 152 | 'DatabaseLabeller': DatabaseLabeller, 153 | 'GraphDatabaseLabeller': GraphDatabaseLabeller 154 | } 155 | -------------------------------------------------------------------------------- /tests/test_database.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | test_database 5 | ---------------------------------- 6 | 7 | Tests for `database` module. 8 | """ 9 | 10 | import os.path 11 | import tempfile 12 | import unittest 13 | 14 | from acton import database 15 | import numpy 16 | 17 | 18 | class TestManagedHDF5Database(unittest.TestCase): 19 | """Tests the ManagedHDF5Database class.""" 20 | 21 | def setUp(self): 22 | self.tempdir = tempfile.TemporaryDirectory() 23 | 24 | def tearDown(self): 25 | self.tempdir.cleanup() 26 | 27 | def temp_path(self, filename: str) -> str: 28 | """Makes a temporary path for a file. 29 | 30 | Parameters 31 | ---------- 32 | filename 33 | Filename of the file. 34 | 35 | Returns 36 | ------- 37 | str 38 | Full path to the temporary filename. 39 | """ 40 | return os.path.join(self.tempdir.name, filename) 41 | 42 | def test_io_features(self): 43 | """Features can be written to and read out from a ManagedHDF5Database. 44 | """ 45 | path = self.temp_path('test_io_features.h5') 46 | # Make some testing data. 
47 | n_instances = 20 48 | n_dimensions = 15 49 | ids = [i for i in range(n_instances)] 50 | numpy.random.shuffle(ids) 51 | features = numpy.random.random(size=(n_instances, n_dimensions)).astype( 52 | 'float32') 53 | 54 | # Store half the testing data in the database. 55 | with database.ManagedHDF5Database(path) as db: 56 | db.write_features(ids[:n_instances // 2], 57 | features[:n_instances // 2]) 58 | with database.ManagedHDF5Database(path) as db: 59 | self.assertTrue(numpy.allclose( 60 | features[:n_instances // 2], 61 | db.read_features(ids[:n_instances // 2]))) 62 | 63 | # Change some features. 64 | changed_ids_int = [i for i in range(0, n_instances, 2)] 65 | changed_ids = [ids[i] for i in changed_ids_int] 66 | new_features = features.copy() 67 | new_features[changed_ids_int] = numpy.random.random( 68 | size=(len(changed_ids), n_dimensions)) 69 | 70 | # Update the database and extend it by including all features. 71 | with database.ManagedHDF5Database(path) as db: 72 | db.write_features(ids, new_features) 73 | with database.ManagedHDF5Database(path) as db: 74 | self.assertTrue(numpy.allclose( 75 | new_features, 76 | db.read_features(ids))) 77 | 78 | def test_read_write_labels(self): 79 | """Labels can be written to and read from a ManagedHDF5Database.""" 80 | path = self.temp_path('test_read_write_labels.h5') 81 | # Make some testing data. 82 | n_instances = 5 83 | n_dimensions = 1 84 | n_labellers = 1 85 | ids = [i for i in range(n_instances)] 86 | labeller_ids = [i for i in range(n_labellers)] 87 | numpy.random.shuffle(ids) 88 | labels = numpy.random.random( 89 | size=(n_labellers, n_instances, n_dimensions)).astype('float32') 90 | # Store half the testing data in the database. 91 | with database.ManagedHDF5Database(path) as db: 92 | db.write_labels(labeller_ids, 93 | ids[:n_instances // 2], 94 | labels[:, :n_instances // 2]) 95 | with database.ManagedHDF5Database(path) as db: 96 | exp_labels = labels[:, :n_instances // 2] 97 | act_labels = db.read_labels(labeller_ids, 98 | ids[:n_instances // 2]) 99 | self.assertTrue(numpy.allclose(exp_labels, act_labels), 100 | msg='delta {}'.format(exp_labels - act_labels)) 101 | # Change some labels. 102 | changed_ids_int = [i for i in range(0, n_instances, 2)] 103 | changed_ids = [ids[i] for i in changed_ids_int] 104 | new_labels = labels.copy() 105 | new_labels[:, changed_ids_int] = numpy.random.random( 106 | size=(n_labellers, len(changed_ids), n_dimensions)) 107 | 108 | # Update the database and extend it by including all labels. 109 | with database.ManagedHDF5Database(path) as db: 110 | db.write_labels(labeller_ids, ids, new_labels) 111 | with database.ManagedHDF5Database(path) as db: 112 | self.assertTrue(numpy.allclose( 113 | new_labels, 114 | db.read_labels(labeller_ids, ids))) 115 | -------------------------------------------------------------------------------- /tests/test_proto_wrappers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | test_proto_wrappers 5 | ---------------------------------- 6 | 7 | Tests for `proto.wrappers` module. 
8 | """ 9 | 10 | import os.path 11 | import tempfile 12 | import unittest 13 | import unittest.mock 14 | 15 | import acton.database 16 | import acton.proto.wrappers 17 | import numpy 18 | 19 | 20 | class TestLabelPool(unittest.TestCase): 21 | """Tests the LabelPool wrapper.""" 22 | 23 | def setUp(self): 24 | self.tempdir = tempfile.TemporaryDirectory() 25 | self.path = os.path.join(self.tempdir.name, 'labelpool.proto') 26 | self.db_path = os.path.join(self.tempdir.name, 'db.txt') 27 | self.db_kwargs = { 28 | 'feature_cols': ['feature'], 29 | 'label_col': 'label', 30 | } 31 | with open(self.db_path, 'w') as f: 32 | f.write('id\tfeature\tlabel\n') 33 | f.write('0\t0.2\t0\n') 34 | f.write('1\t0.1\t0\n') 35 | f.write('2\t0.6\t1\n') 36 | 37 | def tearDown(self): 38 | self.tempdir.cleanup() 39 | 40 | def test_integration(self): 41 | """LabelPool.make returns a LabelPool with correct values.""" 42 | ids = [0, 2] 43 | with acton.database.ASCIIReader(self.db_path, **self.db_kwargs) as db: 44 | lp = acton.proto.wrappers.LabelPool.make(ids=ids, db=db) 45 | self.assertTrue(([b'0', b'1'] == lp.labels.ravel()).all()) 46 | self.assertEqual([0, 2], lp.ids) 47 | with lp.DB() as db: 48 | self.assertEqual([0, 1, 2], db.get_known_instance_ids()) 49 | self.assertEqual('ASCIIReader', lp.proto.db.class_name) 50 | 51 | 52 | class TestPredictions(unittest.TestCase): 53 | """Tests the Predictions wrapper.""" 54 | 55 | def setUp(self): 56 | self.tempdir = tempfile.TemporaryDirectory() 57 | self.path = os.path.join(self.tempdir.name, 'predictions.proto') 58 | self.db_path = os.path.join(self.tempdir.name, 'db.txt') 59 | self.db_kwargs = { 60 | 'feature_cols': ['feature'], 61 | 'label_col': 'label', 62 | } 63 | with open(self.db_path, 'w') as f: 64 | f.write('id\tfeature\tlabel\n') 65 | f.write('0\t0.2\t0\n') 66 | f.write('1\t0.1\t0\n') 67 | f.write('2\t0.6\t1\n') 68 | 69 | def tearDown(self): 70 | self.tempdir.cleanup() 71 | 72 | def test_integration(self): 73 | """Predictions.make returns a Predictions with correct values.""" 74 | predicted_ids = [0, 2] 75 | labelled_ids = [1, 2] 76 | predictions = numpy.array([0.1, 0.5, 0.5, 0.9]).reshape((2, 2, 1)) 77 | with acton.database.ASCIIReader(self.db_path, **self.db_kwargs) as db: 78 | preds = acton.proto.wrappers.Predictions.make( 79 | predicted_ids=predicted_ids, labelled_ids=labelled_ids, 80 | predictions=predictions, db=db) 81 | self.assertEqual([0, 2], preds.predicted_ids) 82 | self.assertEqual([1, 2], preds.labelled_ids) 83 | with preds.DB() as db: 84 | self.assertEqual([0, 1, 2], db.get_known_instance_ids()) 85 | self.assertTrue(numpy.allclose(predictions, preds.predictions)) 86 | 87 | 88 | class TestRecommendations(unittest.TestCase): 89 | """Tests the Recommendations wrapper.""" 90 | 91 | def setUp(self): 92 | self.tempdir = tempfile.TemporaryDirectory() 93 | self.path = os.path.join(self.tempdir.name, 'recommendations.proto') 94 | self.recommender = 'UncertaintyRecommender' 95 | self.db_path = os.path.join(self.tempdir.name, 'db.txt') 96 | self.db_kwargs = { 97 | 'feature_cols': ['feature'], 98 | 'label_col': 'label', 99 | } 100 | with open(self.db_path, 'w') as f: 101 | f.write('id\tfeature\tlabel\n') 102 | f.write('0\t0.2\t0\n') 103 | f.write('1\t0.1\t0\n') 104 | f.write('2\t0.6\t1\n') 105 | 106 | def tearDown(self): 107 | self.tempdir.cleanup() 108 | 109 | def test_integration(self): 110 | """Recommendations.make returns Recommendations with correct values.""" 111 | recommended_ids = [0, 2] 112 | labelled_ids = [1, 2] 113 | with 
acton.database.ASCIIReader(self.db_path, **self.db_kwargs) as db: 114 | recs = acton.proto.wrappers.Recommendations.make( 115 | recommended_ids=recommended_ids, labelled_ids=labelled_ids, 116 | recommender=self.recommender, db=db) 117 | self.assertEqual([0, 2], recs.recommendations) 118 | self.assertEqual([1, 2], recs.labelled_ids) 119 | with recs.DB() as db: 120 | self.assertEqual([0, 1, 2], db.get_known_instance_ids()) 121 | -------------------------------------------------------------------------------- /tests/test_predictors.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | """ 4 | test_predictors 5 | ---------------------------------- 6 | 7 | Tests for `predictors` module. 8 | """ 9 | 10 | import os.path 11 | import tempfile 12 | import unittest 13 | import unittest.mock 14 | 15 | import acton.database 16 | import acton.predictors 17 | import acton.proto.wrappers 18 | from acton.proto.acton_pb2 import LabelPool 19 | import numpy 20 | import sklearn.linear_model 21 | 22 | 23 | class TestIntegrationCommittee(unittest.TestCase): 24 | """Integration test for Committee.""" 25 | 26 | def setUp(self): 27 | # Make a protobuf. 28 | self.ids = LabelPool() 29 | 30 | self.ids.id.append(1) 31 | self.ids.id.append(3) 32 | 33 | self.ids.db.class_name = 'ManagedHDF5Database' 34 | 35 | self.tempdir = tempfile.TemporaryDirectory() 36 | self.path = os.path.join(self.tempdir.name, 'predin.proto') 37 | self.ids.db.path = os.path.join(self.tempdir.name, 'test.h5') 38 | with open(self.path, 'wb') as f: 39 | f.write(self.ids.SerializeToString()) 40 | 41 | self.n_instances = 2 42 | self.features = numpy.array([2, 5, 3, 7]).reshape((self.n_instances, 2)) 43 | with acton.database.ManagedHDF5Database( 44 | self.ids.db.path, 45 | feature_dtype='int32') as db: 46 | db.write_features(self.ids.id, 47 | self.features) 48 | labels = numpy.array([0, 1]).reshape((1, -1, 1)) 49 | db.write_labels([0], self.ids.id, labels) 50 | 51 | def tearDown(self): 52 | self.tempdir.cleanup() 53 | 54 | def testAll(self): 55 | """Committee can be used with LabelPool.""" 56 | pred_input = acton.proto.wrappers.LabelPool(self.ids) 57 | with pred_input.DB() as db: 58 | lrc = acton.predictors.Committee( 59 | acton.predictors.from_class( 60 | sklearn.linear_model.LogisticRegression), db, 61 | n_classifiers=10) 62 | ids = pred_input.ids 63 | lrc.fit(ids) 64 | probs, variances = lrc.predict(ids) 65 | self.assertEqual((self.n_instances, 10, 2), probs.shape) 66 | 67 | 68 | class TestSklearnWrapper(unittest.TestCase): 69 | """Integration test for scikit-learn wrapper functions.""" 70 | 71 | def setUp(self): 72 | # Make a protobuf. 
73 | self.ids = LabelPool() 74 | 75 | self.ids.id.append(1) 76 | self.ids.id.append(3) 77 | 78 | self.ids.db.class_name = 'ManagedHDF5Database' 79 | 80 | self.tempdir = tempfile.TemporaryDirectory() 81 | self.path = os.path.join(self.tempdir.name, 'predin.proto') 82 | self.ids.db.path = os.path.join(self.tempdir.name, 'test.h5') 83 | with open(self.path, 'wb') as f: 84 | f.write(self.ids.SerializeToString()) 85 | 86 | self.n_instances = 2 87 | self.features = numpy.array([2, 5, 3, 7]).reshape((self.n_instances, 2)) 88 | with acton.database.ManagedHDF5Database( 89 | self.ids.db.path, 90 | feature_dtype='int32') as db: 91 | db.write_features(self.ids.id, 92 | self.features) 93 | labels = numpy.array([0, 1]).reshape((1, -1, 1)) 94 | db.write_labels([0], self.ids.id, labels) 95 | 96 | def tearDown(self): 97 | self.tempdir.cleanup() 98 | 99 | def testFromInstance(self): 100 | """from_instance wraps a scikit-learn classifier.""" 101 | # The main point of this test is to check nothing crashes. 102 | classifier = sklearn.linear_model.LogisticRegression() 103 | pred_input = acton.proto.wrappers.LabelPool(self.ids) 104 | 105 | with pred_input.DB() as db: 106 | predictor = acton.predictors.from_instance(classifier, db) 107 | ids = pred_input.ids 108 | predictor.fit(ids) 109 | probs, variances = predictor.predict(ids) 110 | self.assertEqual((2, 1, 2), probs.shape) 111 | 112 | def testFromClass(self): 113 | """from_class wraps a scikit-learn classifier.""" 114 | # The main point of this test is to check nothing crashes. 115 | Classifier = sklearn.linear_model.LogisticRegression 116 | pred_input = acton.proto.wrappers.LabelPool(self.ids) 117 | 118 | with pred_input.DB() as db: 119 | Predictor = acton.predictors.from_class(Classifier) 120 | predictor = Predictor(db, C=50.0) 121 | ids = pred_input.ids 122 | predictor.fit(ids) 123 | probs, variances = predictor.predict(ids) 124 | self.assertEqual((2, 1, 2), probs.shape) 125 | 126 | 127 | class TestGPClassifier(unittest.TestCase): 128 | """Integration test for GPClassifier.""" 129 | 130 | def setUp(self): 131 | self.tempdir = tempfile.TemporaryDirectory() 132 | 133 | def tearDown(self): 134 | self.tempdir.cleanup() 135 | 136 | def test_str_labels(self): 137 | """GPClassifier handles string labels.""" 138 | self.db_path = os.path.join(self.tempdir.name, 'str.h5') 139 | self.n_instances = 2 140 | self.ids = [2, 4] 141 | self.features = numpy.array([2, 5, 3, 7]).reshape((self.n_instances, 2)) 142 | self.labels = numpy.array(['Class A', 'Class B']).reshape((1, -1, 1)) 143 | 144 | with acton.database.ManagedHDF5Database( 145 | self.db_path, label_dtype=' 'Protobuf': 15 | """Reads a protobuf from a .proto file. 16 | 17 | Parameters 18 | ---------- 19 | path 20 | Path to the .proto file. 21 | Proto: 22 | Protocol message class (from the generated protobuf module). 23 | 24 | Returns 25 | ------- 26 | GeneratedProtocolMessageType 27 | The parsed protobuf. 28 | """ 29 | proto = Proto() 30 | with open(path, 'rb') as proto_file: 31 | proto.ParseFromString(proto_file.read()) 32 | 33 | return proto 34 | 35 | 36 | def write_protos(path: str, metadata: bytes=b''): 37 | """Serialises many protobufs to a file. 38 | 39 | Parameters 40 | ---------- 41 | path 42 | Path to binary file. Will be overwritten. 43 | metadata 44 | Optional bytestring to prepend to the file. 45 | 46 | Notes 47 | ----- 48 | Coroutine. Accepts protobufs, or None to terminate and close file. 49 | """ 50 | with open(path, 'wb') as proto_file: 51 | # Write metadata. 
52 | proto_file.write(struct.pack(' bytes: 88 | """Reads metadata from a protobufs file. 89 | 90 | Notes 91 | ----- 92 | Internal use. For external API, use read_metadata. 93 | 94 | Parameters 95 | ---------- 96 | proto_file 97 | Binary file. 98 | 99 | Returns 100 | ------- 101 | bytes 102 | Metadata. 103 | """ 104 | metadata_length = proto_file.read(8) # Long long 105 | metadata_length, = struct.unpack(' bytes: 110 | """Reads metadata from a protobufs file. 111 | 112 | Parameters 113 | ---------- 114 | file 115 | Path to binary file, or file itself. 116 | 117 | Returns 118 | ------- 119 | bytes 120 | Metadata. 121 | """ 122 | try: 123 | return _read_metadata(file) 124 | except AttributeError: 125 | # Not a file-like object, so open the file. 126 | with open(file, 'rb') as proto_file: 127 | return _read_metadata(proto_file) 128 | 129 | 130 | def _read_protos( 131 | proto_file: BinaryIO, 132 | Proto: GeneratedProtocolMessageType 133 | ) -> 'GeneratedProtocolMessageType()': 134 | """Reads many protobufs from a file. 135 | 136 | Notes 137 | ----- 138 | Internal use. For external API, use read_protos. 139 | 140 | Parameters 141 | ---------- 142 | proto_file 143 | Binary file. 144 | Proto: 145 | Protocol message class (from the generated protobuf module). 146 | 147 | Yields 148 | ------- 149 | GeneratedProtocolMessageType 150 | A parsed protobuf. 151 | """ 152 | # This is essentially the inverse of the write_protos function. 153 | 154 | # Skip the metadata. 155 | metadata_length = proto_file.read(8) # Long long 156 | metadata_length, = struct.unpack(' 'GeneratedProtocolMessageType()': 172 | """Reads many protobufs from a file. 173 | 174 | Parameters 175 | ---------- 176 | file 177 | Path to binary file, or file itself. 178 | Proto: 179 | Protocol message class (from the generated protobuf module). 180 | 181 | Yields 182 | ------- 183 | GeneratedProtocolMessageType 184 | A parsed protobuf. 185 | """ 186 | try: 187 | yield from _read_protos(file, Proto) 188 | except AttributeError: 189 | # Not a file-like object, so open the file. 190 | with open(file, 'rb') as proto_file: 191 | yield from _read_protos(proto_file, Proto) 192 | 193 | 194 | def get_ndarray(data: list, shape: tuple, dtype: str) -> numpy.ndarray: 195 | """Converts a list of values into an array. 196 | 197 | Parameters 198 | ---------- 199 | data 200 | Raw array data. 201 | shape: 202 | Shape of the resulting array. 203 | dtype: 204 | Data type of the resulting array. 205 | 206 | Returns 207 | ------- 208 | numpy.ndarray 209 | Array with the given data, shape, and dtype. 210 | """ 211 | return numpy.array(data, dtype=dtype).reshape(tuple(shape)) 212 | -------------------------------------------------------------------------------- /docs/design/api_design.tex: -------------------------------------------------------------------------------- 1 | \documentclass[11pt,twoside]{article} 2 | 3 | \usepackage[margin=2cm]{geometry} 4 | \usepackage{tikz} 5 | \usetikzlibrary{positioning,fit,shapes} 6 | 7 | \newcommand{\name}[1]{ {\color{red}{\sffamily\bfseries{#1}}} } 8 | 9 | \title{Active Learning API design} 10 | \author{Cheng Soon Ong} 11 | \date{17 October 2017} 12 | 13 | \begin{document} 14 | \maketitle 15 | 16 | \section{Design choices} 17 | 18 | Things to pass around 19 | \begin{itemize} 20 | \item \name{A}nnotations produced by the labeller, also known as labels 21 | \item interesting \name{E}xamples, presented to a human researcher. 
We assume that examples
22 | can be referred to by name, for example using a primary key or ra/dec coordinates.
23 | \item One or more \name{S}cores corresponding to an example.
24 | \end{itemize}
25 | 
26 | \noindent How do we keep track of
27 | \begin{itemize}
28 | \item The actual \name{F}eature vectors corresponding to the inputs to the predictor.
29 | \item A measure of the prediction \name{P}erformance, for example accuracy or $R^2$.
30 | \end{itemize}
31 | 
32 | 
33 | \subsection{Omega design ($\Omega$)}
34 | 
35 | The recommender encompasses the predictor (Figure~\ref{fig:omega-design}).
36 | 
37 | \begin{figure}[ht]
38 | \centering
39 | \begin{tikzpicture}[>=latex]
40 | 
41 | %
42 | % Styles for states, and state edges
43 | %
44 | \tikzstyle{objectName} = [align=center, font={\sffamily\bfseries}, text=black!60]
45 | \tikzstyle{bbox} = [draw, very thick, rectangle, rounded corners=2pt, thin,
46 | minimum height=3em, minimum width=7em, node distance=8em]
47 | \tikzstyle{object} = [bbox, objectName, fill=blue!20]
48 | \tikzstyle{edgePortion} = [black,thick,bend right=10];
49 | \tikzstyle{dataFlow} = [edgePortion,->];
50 | \tikzstyle{dataLabel} = [pos=0.5, text centered, text=red, font={\sffamily\bfseries\small}];
51 | 
52 | %
53 | % Position States
54 | %
55 | \node[object, name=labeller] {Labeller};
56 | \node[object, name=predictor, right = of labeller ] {Predictor};
57 | \node[objectName, name=recommender, above of=predictor, color=black] {Recommender};
58 | \node[bbox, name=wrapper, ellipse, fill=blue!20, fill opacity=0.2, fit={(recommender) (predictor)}] {};
59 | 
60 | %
61 | % Connect States via edges
62 | %
63 | \draw (labeller)
64 | edge[dataFlow] node[dataLabel, below]{A}
65 | (wrapper);
66 | \draw (wrapper)
67 | edge[dataFlow] node[dataLabel, above]{E}
68 | (labeller);
69 | \end{tikzpicture}
70 | \caption{$\Omega$ design}
71 | \label{fig:omega-design}
72 | \end{figure}
73 | 
74 | \newpage
75 | \subsection{{\bf Y} design}
76 | 
77 | All parties contribute to a central database (Figure~\ref{fig:y-design}).
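As a concrete sketch of the data flow (purely illustrative: the method names
below are hypothetical and are not Acton's actual \texttt{acton.database} API),
one round of the {\bf Y} design could look like this:

\begin{verbatim}
# One round of the Y design: every component reads from and writes to
# the shared database; no component ever calls another directly.
def y_round(db, recommender, predictor, labeller):
    scores = db.read_scores()                  # S: current scores
    examples = recommender.recommend(scores)   # E: examples to label
    db.write_examples(examples)
    annotations = labeller.label(db.read_examples())  # A: new labels
    db.write_annotations(annotations)
    db.write_scores(predictor.fit_predict(db.read_annotations()))
\end{verbatim}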
78 | 
79 | \begin{figure}[ht]
80 | \centering
81 | \begin{tikzpicture}[>=latex]
82 | 
83 | %
84 | % Styles for states, and state edges
85 | %
86 | \tikzstyle{object} = [draw, very thick, rectangle, rounded corners=2pt, thin, align=center,
87 | fill=blue!20, minimum height=3em, minimum width=7em, node distance=8em,
88 | font={\sffamily\bfseries}, text=black!60]
89 | \tikzstyle{edgePortion} = [black,thick,bend right=20];
90 | \tikzstyle{dataFlow} = [edgePortion,->];
91 | \tikzstyle{dataLabel} = [pos=0.5, text centered, text=red, font={\sffamily\bfseries\small}];
92 | 
93 | %
94 | % Position States
95 | %
96 | \node[object, name=labeller] {Labeller};
97 | \node[object, name=database, regular polygon, regular polygon sides=7, above = of labeller] {Database};
98 | \node[object, name=recommender, ellipse, above left = of database, text=black] {Recommender};
99 | \node[object, name=predictor, above right = of database] {Predictor};
100 | 
101 | %
102 | % Connect States via edges
103 | %
104 | \draw (labeller)
105 | edge[dataFlow] node[dataLabel, right]{A}
106 | (database);
107 | \draw (database)
108 | edge[dataFlow] node[dataLabel, left]{E}
109 | (labeller);
110 | 
111 | \draw (database)
112 | edge[dataFlow] node[dataLabel, right]{???}
113 | (predictor);
114 | \draw (predictor)
115 | edge[dataFlow] node[dataLabel, left]{S}
116 | (database);
117 | 
118 | \draw (database)
119 | edge[dataFlow] node[dataLabel, right]{S}
120 | (recommender);
121 | \draw (recommender)
122 | edge[dataFlow] node[dataLabel, left]{E}
123 | (database);
124 | \end{tikzpicture}
125 | \caption{{\bf Y} design}
126 | \label{fig:y-design}
127 | \end{figure}
128 | 
129 | \subsection{Delta design ($\Delta$)}
130 | 
131 | The Labeller, Recommender and Predictor are equal partners (Figure~\ref{fig:delta-design}).
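For contrast, a matching (equally hypothetical) sketch of one round of the
$\Delta$ design, in which the same A/E/S messages are exchanged pairwise
rather than through a shared database:

\begin{verbatim}
# One round of the Delta design: components message each other directly.
def delta_round(labeller, recommender, predictor):
    annotations = labeller.annotations()        # A, sent to both peers
    scores = predictor.update(annotations)      # S, sent to both peers
    examples = recommender.recommend(annotations, scores)  # E
    labeller.request(examples)  # ask the labeller to label E next
\end{verbatim}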
132 | 
133 | \begin{figure}[ht]
134 | \centering
135 | \begin{tikzpicture}[>=latex]
136 | 
137 | %
138 | % Styles for states, and state edges
139 | %
140 | \tikzstyle{object} = [draw, very thick, rectangle, rounded corners=2pt, thin, align=center,
141 | fill=blue!20, minimum height=3em, minimum width=7em, node distance=8em,
142 | font={\sffamily\bfseries}, text=black!60]
143 | \tikzstyle{edgePortion} = [black,thick,bend right=10];
144 | \tikzstyle{dataFlow} = [edgePortion,->];
145 | \tikzstyle{dataLabel} = [pos=0.5, text centered, text=red, font={\sffamily\bfseries\small}];
146 | 
147 | %
148 | % Position States
149 | %
150 | \node[object, name=labeller] {Labeller};
151 | \node[object, name=recommender, ellipse, below left = of labeller, text=black] {Recommender};
152 | \node[object, name=predictor, below right = of labeller] {Predictor};
153 | 
154 | %
155 | % Connect States via edges
156 | %
157 | \draw (labeller)
158 | edge[dataFlow] node[dataLabel, left]{A}
159 | (recommender);
160 | \draw (recommender)
161 | edge[dataFlow] node[dataLabel, right]{E}
162 | (labeller);
163 | 
164 | \draw (labeller)
165 | edge[dataFlow] node[dataLabel, left]{A}
166 | (predictor);
167 | \draw (predictor)
168 | edge[dataFlow] node[dataLabel, right]{S}
169 | (labeller);
170 | 
171 | \draw (predictor)
172 | edge[dataFlow] node[dataLabel, above]{S}
173 | (recommender);
174 | \draw (recommender)
175 | edge[dataFlow] node[dataLabel, below]{???}
176 | (predictor);
177 | \end{tikzpicture}
178 | \caption{$\Delta$ design}
179 | \label{fig:delta-design}
180 | \end{figure}
181 | 
182 | 
183 | 
184 | \end{document}
185 | 
-------------------------------------------------------------------------------- /docs/PRESCAL Updating Design.ipynb: --------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "markdown",
5 |    "metadata": {},
6 |    "source": [
7 |     "## PRESCAL Updating Design\n",
8 |     "\n",
9 |     "Original work: https://github.com/dongwookim-ml/almc \n",
10 |     "Paper: https://arxiv.org/pdf/1608.05921.pdf\n",
11 |     "\n",
12 |     "To speed up the algorithm, we re-design the updating steps.\n",
13 |     "\n",
14 |     "We consider:\n",
15 |     "\n",
16 |     "1. We might not need to update posteriors for all entities and relations.\n",
17 |     "\n",
18 |     "2. For a sequential updating algorithm (Thompson sampling), it doesn't make sense to use all observed labels in each iteration, i.e.\n",
19 |     "\n",
20 |     "$$P_t = l({x}_t) P_{t-1}$$\n",
21 |     "where $x_t$ is the label observed in the $t^{th}$ iteration.\n",
22 |     "\n",
23 |     "Based on the above considerations, we come up with the following design ideas:\n",
24 |     "\n",
25 |     "### Design 1\n",
26 |     "\n",
27 |     "Assuming we observe $x_{ikj}$ in the $t^{th}$ iteration, we only update the posteriors of $e_i, e_j, r_k$ using the new label $x_{ikj}$.\n",
28 |     "\n",
29 |     "$\\textbf{Prior}$:\n",
30 |     "\n",
31 |     "$$P(\\mathbf{e_i}|\\sigma_e) = \\mathcal{N}(\\mathbf{e_i}| \\mathbf{u_e}, {\\sigma_e}^2 I_D)$$\n",
32 |     "$$P(\\mathbf{R_k}|\\sigma_r) = \\mathcal{MN}(\\mathbf{R_k}| \\mathbf{u_r}, {\\sigma_r} I_D, {\\sigma_r} I_D)$$\n",
33 |     "or equivalently,\n",
34 |     "$$P(\\mathbf{r_k}|\\sigma_r) = \\mathcal{N}(\\mathbf{r_k}| \\mathbf{u_r}, {\\sigma_r}^2 I_{D^2})$$\n",
35 |     "where $r_k = vec(R_k) \\in \\mathcal{R}^{D^2,1}$\n",
36 |     "\n",
37 |     "$\\textbf{Likelihood}$:\n",
38 |     "\n",
39 |     "$$p(x_{ikj}|\\mathbf{e_i, e_j}, R_k) = \\mathcal{N}(x_{ikj}| \\mathbf{e_i}^T R_k \\mathbf{e_j}, \\sigma_x^2)$$\n",
40 |     "using the identity $\\mathbf{e_i}^T R_k \\mathbf{e_j} = r_k^T \\mathbf{e_i} \\otimes \\mathbf{e_j}$,\n",
41 |     "$$p(x_{ikj}|\\mathbf{e_i, e_j, r_k}) = \\mathcal{N}(x_{ikj}| \\mathbf{r_k}^T \\mathbf{e_i} \\otimes \\mathbf{e_j}, \\sigma_x^2)$$\n",
42 |     "\n",
43 |     "$\\textbf{Entity Posterior}$:\n",
44 |     "\n",
45 |     "$$P(\\mathbf{e_i}|x_{ikj}, \\mathbf{e_j}, R_k, \\sigma_e) = \\mathcal{N}(\\mathbf{e_i}| m_{eN}, s_{eN}) \\propto P(\\mathbf{e_i}|\\sigma_e)P(x_{ikj}|\\mathbf{e_i, e_j}, R_k) = \\mathcal{N}(\\mathbf{e_i}| \\mathbf{u_e}, {\\sigma_e}^2 I_D) \\mathcal{N}(x_{ikj}| \\mathbf{e_i}^T R_k \\mathbf{e_j}, \\sigma_x^2)$$\n",
46 |     "\n",
47 |     "We know that, for some normalising constant $z$, $z \\mathcal{N}(\\mathbf{x|c, C}) = \\mathcal{N}(\\mathbf{x|a, A})\\mathcal{N}(\\mathbf{x|b, B})$, \n",
48 |     "\n",
49 |     "\\begin{equation}\n",
50 |     "\\mathbf{C = (A^{-1} + B^{-1})^{-1}}\n",
51 |     "\\end{equation}\n",
52 |     "$$\\mathbf{c = C(A^{-1}a + B^{-1}b)}$$"
53 |    ]
54 |   },
55 |   {
56 |    "cell_type": "markdown",
57 |    "metadata": {},
58 |    "source": [
59 |     "So the goal is to transform $\\mathcal{N}(x_{ikj}| r_k^T \\mathbf{e_i} \\otimes \\mathbf{e_j}, \\sigma_x^2)$ into $\\mathcal{N}(\\mathbf{e_i}| M x_{ikj}, \\sigma_x^2 MM^T)$"
60 |    ]
61 |   },
62 |   {
63 |    "cell_type": "markdown",
64 |    "metadata": {},
65 |    "source": [
66 |     "Assume $R_k \\mathbf{e_j}$ has full column rank,\n",
67 |     "$$x_{ikj} = \\mathbf{e_i^T}R_k\\mathbf{e_j} \\Leftrightarrow \\mathbf{e_i} = (R_k \\mathbf{e_j})^{-T}x_{ikj}$$"
68 |    ]
69 |   },
70 |   {
71 |    "cell_type": "markdown",
72 |    "metadata": {},
73 |    "source": [
74 |     "$$\\mathcal{N}(\\mathbf{e_i}| M x_{ikj}, \\sigma_x^2 MM^T) = \\mathcal{N}(\\mathbf{e_i}| (R_k \\mathbf{e_j})^{-T}x_{ikj}, \\sigma_x^2 ((R_k \\mathbf{e_j})(R_k \\mathbf{e_j})^T)^{-1})$$"
75 |    ]
76 |   },
77 |   {
78 |    "cell_type": "markdown",
79 |    "metadata": {},
80 |    "source": [
81 |     "$$s_{eN} = (\\sigma_e^{-2} I_D + \\sigma_x^{-2} (R_k \\mathbf{e_j})(R_k \\mathbf{e_j})^T)^{-1}$$"
82 |    ]
83 |   },
84 |   {
85 |    "cell_type": "markdown",
86 |    "metadata": {},
87 |    "source": [
88 |     "$$m_{eN} = s_{eN} (\\sigma_e^{-2} \\mathbf{u_e} + \\sigma_x^{-2} (R_k \\mathbf{e_j}) x_{ikj} )$$"
89 |    ]
90 |   },
91 |   {
92 |    "cell_type": "markdown",
93 |    "metadata": {},
94 |    "source": [
95 |     "Similarly, assume $\\mathbf{e_i}^T R_k$ has full column rank; for $P(e_j|x_{ikj}, \\mathbf{e_i}, R_k, \\sigma_e)$ we have \n",
96 |     "\n",
97 |     "$$s_{eN} = (\\sigma_e^{-2} I_D + \\sigma_x^{-2} (\\mathbf{e_i}^T R_k)^T(\\mathbf{e_i}^T R_k))^{-1}$$\n",
98 |     "\n",
99 |     "$$m_{eN} = s_{eN} (\\sigma_e^{-2} \\mathbf{u_e} + \\sigma_x^{-2} (\\mathbf{e_i}^T R_k)^T x_{ikj} )$$"
100 |    ]
101 |   },
102 |   {
103 |    "cell_type": "markdown",
104 |    "metadata": {},
105 |    "source": [
106 |     "$\\textbf{Relation Posterior}:$\n",
107 |     "\n",
108 |     "$$P(\\mathbf{r_k}|x_{ikj}, \\mathbf{e_i, e_j}, \\sigma_r) = \\mathcal{N}(\\mathbf{r_k}|m_{rN}, s_{rN}) \\propto P(\\mathbf{r_k|\\sigma_r})P(x_{ikj}|\\mathbf{e_i, e_j, r_k}) = \\mathcal{N}(\\mathbf{r_k}| \\mathbf{u_r}, {\\sigma_r}^2 I_{D^2}) \\mathcal{N}(x_{ikj}| \\mathbf{r_k}^T \\mathbf{e_i} \\otimes \\mathbf{e_j}, \\sigma_x^2)$$"
109 |    ]
110 |   },
111 |   {
112 |    "cell_type": "markdown",
113 |    "metadata": {},
114 |    "source": [
115 |     "Similarly, assume $\\mathbf{e_i} \\otimes \\mathbf{e_j}$ has full column rank,\n",
116 |     "\n",
117 |     "$$ x_{ikj} = \\mathbf{r_k}^T \\mathbf{e_i} \\otimes \\mathbf{e_j} \\Leftrightarrow \\mathbf{r_k} = (\\mathbf{e_i} \\otimes \\mathbf{e_j}) ^{-T} x_{ikj}$$ "
118 |    ]
119 |   },
120 |   {
121 |    "cell_type": "markdown",
122 |    "metadata": {},
123 |    "source": [
124 |     "$$\\mathcal{N}(\\mathbf{r_k}|M x_{ikj}, \\sigma_x^2 MM^T) = \\mathcal{N}(\\mathbf{r_k}| (\\mathbf{e_i} \\otimes \\mathbf{e_j}) ^{-T} x_{ikj}, \\sigma_x^{2} ((\\mathbf{e_i} \\otimes \\mathbf{e_j}) (\\mathbf{e_i} \\otimes \\mathbf{e_j}) ^T)^{-1} )$$"
125 |    ]
126 |   },
127 |   {
128 |    "cell_type": "markdown",
129 |    "metadata": {},
130 |    "source": [
131 |     "$$s_{rN} = (\\sigma_r^{-2}I_{D^2} + \\sigma_x^{-2} (\\mathbf{e_i} \\otimes \\mathbf{e_j}) (\\mathbf{e_i} \\otimes \\mathbf{e_j}) ^T)^{-1}$$ "
132 |    ]
133 |   },
134 |   {
135 |    "cell_type": "markdown",
136 |    "metadata": {},
137 |    "source": [
138 |     "$$m_{rN} = s_{rN}(\\sigma_r^{-2} \\mathbf{u_r} + \\sigma_x^{-2} (\\mathbf{e_i} \\otimes \\mathbf{e_j}) x_{ikj})$$"
139 |    ]
140 |   },
141 |   {
142 |    "cell_type": "markdown",
143 |    "metadata": {},
144 |    "source": [
145 |     "### Design 2\n",
146 |     "\n"
147 |    ]
148 |   },
149 |   {
150 |    "cell_type": "code",
151 |    "execution_count": null,
152 |    "metadata": {},
153 |    "outputs": [],
154 |    "source": []
155 |   }
156 |  ],
157 |  "metadata": {
158 |   "kernelspec": {
159 |    "display_name": "Python 3",
160 |    "language": "python",
161 |    "name": "python3"
162 |   },
163 |   "language_info": {
164 |    "codemirror_mode": {
165 |     "name": "ipython",
166 |     "version": 3
167 |    },
168 |    "file_extension": ".py",
169 |    "mimetype": "text/x-python",
170 |    "name": "python",
171 |    "nbconvert_exporter": "python",
172 |    "pygments_lexer": "ipython3",
173 |    "version": "3.6.5"
174 |   }
175 |  },
176 |  "nbformat": 4,
177 |  "nbformat_minor": 2
178 | }
179 | 
-------------------------------------------------------------------------------- /docs/design/survey.rst: --------------------------------------------------------------------------------
1 | ===========================================================================
2 | Notes on Active Learning, Bandits, Choice and Design of Experiments (ABCDE)
3 | ===========================================================================
4 | 
5 | There are four ideas which are often used for eliciting human
6 | responses using machine learning predictors. At a high level they are
7 | similar in spirit, but they have different foundations which lead to
8 | different formulations. Three of the ideas are active learning, bandits and
9 | experimental design. The fourth, with literature from a different
10 | field, is social choice theory, which looks at how individual preferences are aggregated.
11 | 
12 | Overview of ABCDE
13 | =================
14 | 
15 | Active Learning
16 | ---------------
17 | 
18 | Active learning considers the setting where the agent interacts with
19 | its environment to procure a training set, rather than passively
20 | receiving i.i.d. samples from some underlying distribution.
21 | 
22 | It is often assumed that the environment is infinite (e.g. $R^d$) and
23 | the agent has to choose a location, $x$, to query. The oracle then returns
24 | the label $y$. It is often assumed that there is no noise in the label,
25 | and hence there is no benefit in querying the same point $x$ again. In
26 | many practical applications, the environment is considered to be
27 | finite (but large). This is called pool-based active learning.
28 | 
29 | The active learning algorithm is often compared to the passive
30 | learning algorithm.
31 | 
32 | Bandits
33 | -------
34 | 
35 | A bandit problem is a sequential allocation problem defined by a set
36 | of actions. The agent chooses an action at each time step, and the
37 | environment returns a reward. The aim of the agent is to maximise reward.
38 | 
39 | In basic settings, the set of actions is considered to be
40 | finite. There are three fundamental formalisations of the bandit
41 | problem, depending on the assumed nature of the reward process:
42 | stochastic, adversarial and Markovian. In all three settings the
43 | reward is uncertain, and hence the agent may have to play a particular
44 | action repeatedly.
45 | 
46 | The agent is compared to a static agent which has played the best
47 | action. This difference in reward is called regret.
48 | 
49 | Experimental Design
50 | -------------------
51 | 
52 | In contrast to active learning, experimental design considers the problem of regression,
53 | i.e. where the label $y\in R$ is a real number.
54 | 
55 | The problem to be solved in experimental design is to choose a set of
56 | trials (say of size N) to gather enough information about the object
57 | of interest. The goal is to maximise the information obtained about
58 | the parameters of the model (of the object).
59 | 
60 | It is often assumed that the observations at the N trials are
61 | independent. When N is finite this is called exact design, otherwise
62 | it is called approximate or continuous design. The environment is
63 | assumed to be infinite (e.g. $R^d$) and the observations are scalar real variables.
64 | 
65 | 
66 | ==============
67 | Unsorted notes
68 | ==============
69 | 
70 | * Thompson sampling
71 | * Upper Confidence Bound
72 | 
73 | Notes on UCB for binary rewards
74 | -------------------------------
75 | 
76 | In the special case when the rewards of the arms are {0,1}, we can get a much tighter analysis. See `pymaBandits <http://mloss.org/software/view/415/>`_. This is also implemented in this repository under ``python/digbeta``.
77 | 
78 | 
79 | Notes on UCB for graphs
80 | -----------------------
81 | 
82 | *Spectral Bandits for Smooth Graph Functions
83 | Michal Valko, Remi Munos, Branislav Kveton, Tomas Kocak
84 | ICML 2014*
85 | 
86 | Studies a bandit problem where the arms are the nodes of a graph and the expected payoff of pulling an arm is a smooth function on this graph.
87 | 
88 | Assume that the graph is known, and its edges represent the similarities of the nodes. At time $t$, choose a node and observe its payoff. Based on the payoff, update the model.
89 | 
90 | Assume that the number of nodes $N$ is large; we are interested in the regime $t < N$.
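This choose-an-arm, observe-a-payoff, update loop is the template shared by the
UCB-style methods in these notes. As a purely illustrative sketch (standard
UCB1 over independent arms, not the spectral variant above, and not code from
this repository; the ``pull`` function is a hypothetical stand-in for the
environment):

.. code:: python

    import numpy as np

    def ucb1(pull, n_arms, horizon):
        """Run the generic bandit loop with the UCB1 index.

        ``pull`` is assumed to map an arm index to an observed payoff.
        """
        counts = np.zeros(n_arms)  # times each arm has been played
        sums = np.zeros(n_arms)    # total payoff observed per arm
        for t in range(1, horizon + 1):
            if t <= n_arms:
                arm = t - 1  # play every arm once to initialise
            else:
                index = sums / counts + np.sqrt(2 * np.log(t) / counts)
                arm = int(np.argmax(index))
            counts[arm] += 1
            sums[arm] += pull(arm)  # observe the uncertain reward
        return sums.sum()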
91 | 92 | 93 | 94 | 95 | Related Literature 96 | ================== 97 | 98 | This is an unsorted list of references. 99 | 100 | * Prediction, Learning, and Games, 101 | Nicolo Cesa-Bianchi, Gabor Lugosi 102 | Cambridge University Press, 2006 103 | 104 | * Active Learning Literature Survey 105 | Burr Settles 106 | Computer Sciences Technical Report 1648 107 | University of Wisconsin–Madison, 2010 108 | 109 | * Regret Analysis of Stochastic and Nonstochastic Multi-armed Bandit Problems 110 | Sebastien Bubeck, Nicolo Cesa-Bianchi 111 | Foundations and Trends in Machine Learning, Vol 5, No 1, 2012, pp. 1-122 112 | 113 | * Spectral Bandits for Smooth Graph Functions 114 | Michal Valko, Remi Munos, Branislav Kveton, Tomas Kocak 115 | ICML 2014 116 | 117 | * Spectral Thompson Sampling 118 | Tomas Kocak, Michal Valko, Remi Munos, Shipra Agrawal 119 | AAAI 2014 120 | 121 | * An Analysis of Active Learning Strategies for Sequence Labeling Tasks 122 | Burr Settles, Mark Craven 123 | EMNLP 2008 124 | 125 | * Margin-based active learning for structured predictions 126 | Kevin Small, Dan Roth 127 | International Journal of Machine Learning and Cybernetics, 2010, 1:3-25 128 | 129 | * Emilie Kaufmann, Nathaniel Korda and Remi Munos 130 | Thompson Sampling: An Asymptotically Optimal Finite Time Analysis, ALT 2012 131 | 132 | * Thompson Sampling for 1-Dimensional Exponential Family Bandits 133 | Nathaniel Korda, Emilie Kaufmann, Remi Munos 134 | NIPS 2013 135 | 136 | * On Bayesian Upper Confidence Bounds for Bandit Problems 137 | Emilie Kaufmann, Olivier Cappe, Aurelien Garivier 138 | AISTATS 2012 139 | 140 | * Building Bridges: Viewing Active Learning from the Multi-Armed Bandit Lens 141 | Ravi Ganti, Alexander G. Gray 142 | UAI 2013 143 | 144 | * From Theories to Queries: Active Learning in Practice 145 | Burr Settles 146 | JMLR W&CP, NIPS 2011 Workshop on Active Learning and Experimental Design 147 | 148 | * Contextual Gaussian Process Bandit Optimization. 149 | Andreas Krause, Cheng Soon Ong 150 | NIPS 2011 151 | 152 | * Contextual Bandit for Active Learning: Active Thompson Sampling. 153 | Djallel Bouneffouf, Romain Laroche, Tanguy Urvoy, Raphael Feraud, Robin Allesiardo. 154 | NIPS 2014 155 | 156 | * Towards Anytime Active Learning: Interrupting Experts to Reduce Annotation Costs 157 | Maria Ramirez-Loaiza, Aron Culotta, Mustafa Bilgic 158 | SIGKDD 2013 159 | 160 | * Actively Learning Ontology Matching via User Interaction 161 | Feng Shi, Juanzi Li, Jie Tang, Guotong Xie, Hanyu Li 162 | ISWC 2009 163 | 164 | * A Novel Method for Measuring Semantic Similarity for XML Schema Matching 165 | Buhwan Jeong, Daewon Lee, Hyunbo Cho, Jaewook Lee 166 | Expert Systems with Applications 2008 167 | 168 | * Tamr Product White Paper 169 | http://www.tamr.com/tamr-technical-overview/ 170 | 171 | * Design of Experiments in Nonlinear Models 172 | Luc Pronzato, Andrej Pazman 173 | Springer 2013 174 | 175 | * Optimisation in space of measures and optimal design 176 | Ilya Molchanov and Sergei Zuyev 177 | ESAIM: Probability and Statistics, Vol. 8, pp. 12-24, 2004 178 | 179 | * Active Learning for logistic regression: an evaluation 180 | Andrew I. Schein and Lyle H. 
Ungar
181 | Machine Learning, 2007, 68: 235-265
182 | 
183 | * Learning to Optimize Via Information-Directed Sampling
184 | Daniel Russo and Benjamin Van Roy
185 | 
186 | * The KL-UCB Algorithm for Bounded Stochastic Bandits and Beyond
187 | Aurelien Garivier and Olivier Cappe
188 | COLT 2011
189 | 
190 | * A Finite-Time Analysis of Multi-armed Bandits Problems with Kullback-Leibler Divergences
191 | Odalric-Ambrym Maillard, Remi Munos, Gilles Stoltz
192 | COLT 2011
193 | 
194 | * Kullback-Leibler Upper Confidence Bounds for Optimal Sequential Allocation
195 | Olivier Cappe, Aurelien Garivier, Odalric-Ambrym Maillard, Remi Munos, Gilles Stoltz
196 | Annals of Statistics, 2013
197 | 
198 | * Xiaojin Zhu, Zoubin Ghahramani, John Lafferty,
199 | Semi-Supervised Learning Using Gaussian Fields and Harmonic Functions
200 | ICML 2003
201 | 
202 | * Efficient and Parsimonious Agnostic Active Learning
203 | Tzu-Kuo Huang, Alekh Agarwal, Daniel J. Hsu, John Langford, Robert E. Schapire
204 | NIPS 2015
205 | 
206 | * NEXT: A System for Real-World Development, Evaluation, and Application of Active Learning
207 | Kevin Jamieson, Lalit Jain, Chris Fernandez, Nick Glattard, Robert Nowak
208 | NIPS 2015
209 | 
210 | * Baram, Y., El-Yaniv, R., and Luz, K. (2004).
211 | Online choice of active learning algorithms. Journal of Machine Learning Research, 5:255–291.
212 | 
213 | * Hsu, W.-N. and Lin, H.-T. (2015). Active learning by learning. In AAAI 15.
214 | 
-------------------------------------------------------------------------------- /docs/make.bat: --------------------------------------------------------------------------------
1 | @ECHO OFF
2 | 
3 | REM Command file for Sphinx documentation
4 | 
5 | if "%SPHINXBUILD%" == "" (
6 | 	set SPHINXBUILD=sphinx-build
7 | )
8 | set BUILDDIR=_build
9 | set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
10 | set I18NSPHINXOPTS=%SPHINXOPTS% .
11 | if NOT "%PAPER%" == "" (
12 | 	set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
13 | 	set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
14 | )
15 | 
16 | if "%1" == "" goto help
17 | 
18 | if "%1" == "help" (
19 | 	:help
20 | 	echo.Please use `make ^<target^>` where ^<target^> is one of
21 | 	echo.  html       to make standalone HTML files
22 | 	echo.  dirhtml    to make HTML files named index.html in directories
23 | 	echo.  singlehtml to make a single large HTML file
24 | 	echo.  pickle     to make pickle files
25 | 	echo.  json       to make JSON files
26 | 	echo.  htmlhelp   to make HTML files and a HTML help project
27 | 	echo.  qthelp     to make HTML files and a qthelp project
28 | 	echo.  devhelp    to make HTML files and a Devhelp project
29 | 	echo.  epub       to make an epub
30 | 	echo.  epub3      to make an epub3
31 | 	echo.  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter
32 | 	echo.  text       to make text files
33 | 	echo.  man        to make manual pages
34 | 	echo.  texinfo    to make Texinfo files
35 | 	echo.  gettext    to make PO message catalogs
36 | 	echo.  changes    to make an overview over all changed/added/deprecated items
37 | 	echo.  xml        to make Docutils-native XML files
38 | 	echo.  pseudoxml  to make pseudoxml-XML files for display purposes
39 | 	echo.  linkcheck  to check all external links for integrity
40 | 	echo.  doctest    to run all doctests embedded in the documentation if enabled
41 | 	echo.  coverage   to run coverage check of the documentation if enabled
42 | 	echo.
dummy to check syntax errors of document sources 43 | goto end 44 | ) 45 | 46 | if "%1" == "clean" ( 47 | for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i 48 | del /q /s %BUILDDIR%\* 49 | goto end 50 | ) 51 | 52 | 53 | REM Check if sphinx-build is available and fallback to Python version if any 54 | %SPHINXBUILD% 1>NUL 2>NUL 55 | if errorlevel 9009 goto sphinx_python 56 | goto sphinx_ok 57 | 58 | :sphinx_python 59 | 60 | set SPHINXBUILD=python -m sphinx.__init__ 61 | %SPHINXBUILD% 2> nul 62 | if errorlevel 9009 ( 63 | echo. 64 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 65 | echo.installed, then set the SPHINXBUILD environment variable to point 66 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 67 | echo.may add the Sphinx directory to PATH. 68 | echo. 69 | echo.If you don't have Sphinx installed, grab it from 70 | echo.http://sphinx-doc.org/ 71 | exit /b 1 72 | ) 73 | 74 | :sphinx_ok 75 | 76 | 77 | if "%1" == "html" ( 78 | %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html 79 | if errorlevel 1 exit /b 1 80 | echo. 81 | echo.Build finished. The HTML pages are in %BUILDDIR%/html. 82 | goto end 83 | ) 84 | 85 | if "%1" == "dirhtml" ( 86 | %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml 87 | if errorlevel 1 exit /b 1 88 | echo. 89 | echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. 90 | goto end 91 | ) 92 | 93 | if "%1" == "singlehtml" ( 94 | %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml 95 | if errorlevel 1 exit /b 1 96 | echo. 97 | echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. 98 | goto end 99 | ) 100 | 101 | if "%1" == "pickle" ( 102 | %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle 103 | if errorlevel 1 exit /b 1 104 | echo. 105 | echo.Build finished; now you can process the pickle files. 106 | goto end 107 | ) 108 | 109 | if "%1" == "json" ( 110 | %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json 111 | if errorlevel 1 exit /b 1 112 | echo. 113 | echo.Build finished; now you can process the JSON files. 114 | goto end 115 | ) 116 | 117 | if "%1" == "htmlhelp" ( 118 | %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp 119 | if errorlevel 1 exit /b 1 120 | echo. 121 | echo.Build finished; now you can run HTML Help Workshop with the ^ 122 | .hhp project file in %BUILDDIR%/htmlhelp. 123 | goto end 124 | ) 125 | 126 | if "%1" == "qthelp" ( 127 | %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp 128 | if errorlevel 1 exit /b 1 129 | echo. 130 | echo.Build finished; now you can run "qcollectiongenerator" with the ^ 131 | .qhcp project file in %BUILDDIR%/qthelp, like this: 132 | echo.^> qcollectiongenerator %BUILDDIR%\qthelp\Acton.qhcp 133 | echo.To view the help file: 134 | echo.^> assistant -collectionFile %BUILDDIR%\qthelp\Acton.ghc 135 | goto end 136 | ) 137 | 138 | if "%1" == "devhelp" ( 139 | %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp 140 | if errorlevel 1 exit /b 1 141 | echo. 142 | echo.Build finished. 143 | goto end 144 | ) 145 | 146 | if "%1" == "epub" ( 147 | %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub 148 | if errorlevel 1 exit /b 1 149 | echo. 150 | echo.Build finished. The epub file is in %BUILDDIR%/epub. 151 | goto end 152 | ) 153 | 154 | if "%1" == "epub3" ( 155 | %SPHINXBUILD% -b epub3 %ALLSPHINXOPTS% %BUILDDIR%/epub3 156 | if errorlevel 1 exit /b 1 157 | echo. 158 | echo.Build finished. The epub3 file is in %BUILDDIR%/epub3. 
159 | goto end 160 | ) 161 | 162 | if "%1" == "latex" ( 163 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 164 | if errorlevel 1 exit /b 1 165 | echo. 166 | echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. 167 | goto end 168 | ) 169 | 170 | if "%1" == "latexpdf" ( 171 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 172 | cd %BUILDDIR%/latex 173 | make all-pdf 174 | cd %~dp0 175 | echo. 176 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 177 | goto end 178 | ) 179 | 180 | if "%1" == "latexpdfja" ( 181 | %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex 182 | cd %BUILDDIR%/latex 183 | make all-pdf-ja 184 | cd %~dp0 185 | echo. 186 | echo.Build finished; the PDF files are in %BUILDDIR%/latex. 187 | goto end 188 | ) 189 | 190 | if "%1" == "text" ( 191 | %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text 192 | if errorlevel 1 exit /b 1 193 | echo. 194 | echo.Build finished. The text files are in %BUILDDIR%/text. 195 | goto end 196 | ) 197 | 198 | if "%1" == "man" ( 199 | %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man 200 | if errorlevel 1 exit /b 1 201 | echo. 202 | echo.Build finished. The manual pages are in %BUILDDIR%/man. 203 | goto end 204 | ) 205 | 206 | if "%1" == "texinfo" ( 207 | %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo 208 | if errorlevel 1 exit /b 1 209 | echo. 210 | echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. 211 | goto end 212 | ) 213 | 214 | if "%1" == "gettext" ( 215 | %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale 216 | if errorlevel 1 exit /b 1 217 | echo. 218 | echo.Build finished. The message catalogs are in %BUILDDIR%/locale. 219 | goto end 220 | ) 221 | 222 | if "%1" == "changes" ( 223 | %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes 224 | if errorlevel 1 exit /b 1 225 | echo. 226 | echo.The overview file is in %BUILDDIR%/changes. 227 | goto end 228 | ) 229 | 230 | if "%1" == "linkcheck" ( 231 | %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck 232 | if errorlevel 1 exit /b 1 233 | echo. 234 | echo.Link check complete; look for any errors in the above output ^ 235 | or in %BUILDDIR%/linkcheck/output.txt. 236 | goto end 237 | ) 238 | 239 | if "%1" == "doctest" ( 240 | %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest 241 | if errorlevel 1 exit /b 1 242 | echo. 243 | echo.Testing of doctests in the sources finished, look at the ^ 244 | results in %BUILDDIR%/doctest/output.txt. 245 | goto end 246 | ) 247 | 248 | if "%1" == "coverage" ( 249 | %SPHINXBUILD% -b coverage %ALLSPHINXOPTS% %BUILDDIR%/coverage 250 | if errorlevel 1 exit /b 1 251 | echo. 252 | echo.Testing of coverage in the sources finished, look at the ^ 253 | results in %BUILDDIR%/coverage/python.txt. 254 | goto end 255 | ) 256 | 257 | if "%1" == "xml" ( 258 | %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml 259 | if errorlevel 1 exit /b 1 260 | echo. 261 | echo.Build finished. The XML files are in %BUILDDIR%/xml. 262 | goto end 263 | ) 264 | 265 | if "%1" == "pseudoxml" ( 266 | %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml 267 | if errorlevel 1 exit /b 1 268 | echo. 269 | echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. 270 | goto end 271 | ) 272 | 273 | if "%1" == "dummy" ( 274 | %SPHINXBUILD% -b dummy %ALLSPHINXOPTS% %BUILDDIR%/dummy 275 | if errorlevel 1 exit /b 1 276 | echo. 277 | echo.Build finished. Dummy builder generates no files. 
278 | goto end
279 | )
280 | 
281 | :end
282 | 
-------------------------------------------------------------------------------- /docs/Makefile: --------------------------------------------------------------------------------
1 | # Makefile for Sphinx documentation
2 | #
3 | 
4 | # You can set these variables from the command line.
5 | SPHINXOPTS    =
6 | SPHINXBUILD   = sphinx-build
7 | PAPER         =
8 | BUILDDIR      = _build
9 | 
10 | # User-friendly check for sphinx-build
11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don\'t have Sphinx installed, grab it from http://sphinx-doc.org/)
13 | endif
14 | 
15 | # Internal variables.
16 | PAPEROPT_a4     = -D latex_paper_size=a4
17 | PAPEROPT_letter = -D latex_paper_size=letter
18 | ALLSPHINXOPTS   = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
19 | # the i18n builder cannot share the environment and doctrees with the others
20 | I18NSPHINXOPTS  = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
21 | 
22 | .PHONY: help
23 | help:
24 | 	@echo "Please use \`make <target>' where <target> is one of"
25 | 	@echo "  html       to make standalone HTML files"
26 | 	@echo "  dirhtml    to make HTML files named index.html in directories"
27 | 	@echo "  singlehtml to make a single large HTML file"
28 | 	@echo "  pickle     to make pickle files"
29 | 	@echo "  json       to make JSON files"
30 | 	@echo "  htmlhelp   to make HTML files and a HTML help project"
31 | 	@echo "  qthelp     to make HTML files and a qthelp project"
32 | 	@echo "  applehelp  to make an Apple Help Book"
33 | 	@echo "  devhelp    to make HTML files and a Devhelp project"
34 | 	@echo "  epub       to make an epub"
35 | 	@echo "  epub3      to make an epub3"
36 | 	@echo "  latex      to make LaTeX files, you can set PAPER=a4 or PAPER=letter"
37 | 	@echo "  latexpdf   to make LaTeX files and run them through pdflatex"
38 | 	@echo "  latexpdfja to make LaTeX files and run them through platex/dvipdfmx"
39 | 	@echo "  text       to make text files"
40 | 	@echo "  man        to make manual pages"
41 | 	@echo "  texinfo    to make Texinfo files"
42 | 	@echo "  info       to make Texinfo files and run them through makeinfo"
43 | 	@echo "  gettext    to make PO message catalogs"
44 | 	@echo "  changes    to make an overview of all changed/added/deprecated items"
45 | 	@echo "  xml        to make Docutils-native XML files"
46 | 	@echo "  pseudoxml  to make pseudoxml-XML files for display purposes"
47 | 	@echo "  linkcheck  to check all external links for integrity"
48 | 	@echo "  doctest    to run all doctests embedded in the documentation (if enabled)"
49 | 	@echo "  coverage   to run coverage check of the documentation (if enabled)"
50 | 	@echo "  dummy      to check syntax errors of document sources"
51 | 
52 | .PHONY: clean
53 | clean:
54 | 	rm -rf $(BUILDDIR)/*
55 | 
56 | .PHONY: html
57 | html:
58 | 	$(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html
59 | 	@echo
60 | 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
61 | 
62 | .PHONY: dirhtml
63 | dirhtml:
64 | 	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
65 | 	@echo
66 | 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml."
67 | 
68 | .PHONY: singlehtml
69 | singlehtml:
70 | 	$(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml
71 | 	@echo
72 | 	@echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml."
73 | 74 | .PHONY: pickle 75 | pickle: 76 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 77 | @echo 78 | @echo "Build finished; now you can process the pickle files." 79 | 80 | .PHONY: json 81 | json: 82 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 83 | @echo 84 | @echo "Build finished; now you can process the JSON files." 85 | 86 | .PHONY: htmlhelp 87 | htmlhelp: 88 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 89 | @echo 90 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 91 | ".hhp project file in $(BUILDDIR)/htmlhelp." 92 | 93 | .PHONY: qthelp 94 | qthelp: 95 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 96 | @echo 97 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 98 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 99 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/Acton.qhcp" 100 | @echo "To view the help file:" 101 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/Acton.qhc" 102 | 103 | .PHONY: applehelp 104 | applehelp: 105 | $(SPHINXBUILD) -b applehelp $(ALLSPHINXOPTS) $(BUILDDIR)/applehelp 106 | @echo 107 | @echo "Build finished. The help book is in $(BUILDDIR)/applehelp." 108 | @echo "N.B. You won't be able to view it unless you put it in" \ 109 | "~/Library/Documentation/Help or install it in your application" \ 110 | "bundle." 111 | 112 | .PHONY: devhelp 113 | devhelp: 114 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 115 | @echo 116 | @echo "Build finished." 117 | @echo "To view the help file:" 118 | @echo "# mkdir -p $$HOME/.local/share/devhelp/Acton" 119 | @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/Acton" 120 | @echo "# devhelp" 121 | 122 | .PHONY: epub 123 | epub: 124 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 125 | @echo 126 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 127 | 128 | .PHONY: epub3 129 | epub3: 130 | $(SPHINXBUILD) -b epub3 $(ALLSPHINXOPTS) $(BUILDDIR)/epub3 131 | @echo 132 | @echo "Build finished. The epub3 file is in $(BUILDDIR)/epub3." 133 | 134 | .PHONY: latex 135 | latex: 136 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 137 | @echo 138 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 139 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 140 | "(use \`make latexpdf' here to do that automatically)." 141 | 142 | .PHONY: latexpdf 143 | latexpdf: 144 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 145 | @echo "Running LaTeX files through pdflatex..." 146 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 147 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 148 | 149 | .PHONY: latexpdfja 150 | latexpdfja: 151 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 152 | @echo "Running LaTeX files through platex and dvipdfmx..." 153 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 154 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 155 | 156 | .PHONY: text 157 | text: 158 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 159 | @echo 160 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 161 | 162 | .PHONY: man 163 | man: 164 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 165 | @echo 166 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 167 | 168 | .PHONY: texinfo 169 | texinfo: 170 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 171 | @echo 172 | @echo "Build finished. 
The Texinfo files are in $(BUILDDIR)/texinfo."
173 | 	@echo "Run \`make' in that directory to run these through makeinfo" \
174 | 	      "(use \`make info' here to do that automatically)."
175 | 
176 | .PHONY: info
177 | info:
178 | 	$(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo
179 | 	@echo "Running Texinfo files through makeinfo..."
180 | 	make -C $(BUILDDIR)/texinfo info
181 | 	@echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo."
182 | 
183 | .PHONY: gettext
184 | gettext:
185 | 	$(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale
186 | 	@echo
187 | 	@echo "Build finished. The message catalogs are in $(BUILDDIR)/locale."
188 | 
189 | .PHONY: changes
190 | changes:
191 | 	$(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes
192 | 	@echo
193 | 	@echo "The overview file is in $(BUILDDIR)/changes."
194 | 
195 | .PHONY: linkcheck
196 | linkcheck:
197 | 	$(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck
198 | 	@echo
199 | 	@echo "Link check complete; look for any errors in the above output " \
200 | 	      "or in $(BUILDDIR)/linkcheck/output.txt."
201 | 
202 | .PHONY: doctest
203 | doctest:
204 | 	$(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest
205 | 	@echo "Testing of doctests in the sources finished, look at the " \
206 | 	      "results in $(BUILDDIR)/doctest/output.txt."
207 | 
208 | .PHONY: coverage
209 | coverage:
210 | 	$(SPHINXBUILD) -b coverage $(ALLSPHINXOPTS) $(BUILDDIR)/coverage
211 | 	@echo "Testing of coverage in the sources finished, look at the " \
212 | 	      "results in $(BUILDDIR)/coverage/python.txt."
213 | 
214 | .PHONY: xml
215 | xml:
216 | 	$(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml
217 | 	@echo
218 | 	@echo "Build finished. The XML files are in $(BUILDDIR)/xml."
219 | 
220 | .PHONY: pseudoxml
221 | pseudoxml:
222 | 	$(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml
223 | 	@echo
224 | 	@echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml."
225 | 
226 | .PHONY: dummy
227 | dummy:
228 | 	$(SPHINXBUILD) -b dummy $(ALLSPHINXOPTS) $(BUILDDIR)/dummy
229 | 	@echo
230 | 	@echo "Build finished. Dummy builder generates no files."
231 | 
--------------------------------------------------------------------------------
/acton/cli.py:
--------------------------------------------------------------------------------
1 | """Command-line interface for Acton."""
2 | 
3 | import logging
4 | import struct
5 | import sys
6 | from typing import BinaryIO, Iterable, List
7 | 
8 | import acton.acton
9 | import acton.predictors
10 | import acton.proto.wrappers
11 | import acton.recommenders
12 | import click
13 | 
14 | 
15 | def read_bytes_from_buffer(n: int, buffer: BinaryIO) -> bytes:
16 |     """Reads n bytes from a buffer, blocking until all bytes are received.
17 | 
18 |     Parameters
19 |     ----------
20 |     n
21 |         How many bytes to read.
22 |     buffer
23 |         Which buffer to read from.
24 | 
25 |     Returns
26 |     -------
27 |     bytes
28 |         Exactly n bytes.
29 |     """
30 |     b = b''
31 |     while len(b) < n:
32 |         b += buffer.read(n - len(b))
33 |     assert len(b) == n
34 |     return b
35 | 
36 | 
37 | def read_binary() -> bytes:
38 |     """Reads binary data from stdin.
39 | 
40 |     Notes
41 |     -----
42 |     The first eight bytes are expected to be the length of the input data as an
43 |     unsigned long long.
44 | 
45 |     Returns
46 |     -------
47 |     bytes
48 |         Binary data.
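
    Examples
    --------
    A sender on the other side of the pipe would frame its payload to
    match (a sketch, assuming the little-endian unsigned long long
    format used below; ``payload`` stands in for a serialised protobuf)::

        import struct
        import sys

        payload = b'example'
        # Eight-byte length prefix, followed by the payload itself.
        sys.stdout.buffer.write(struct.pack('<Q', len(payload)))
        sys.stdout.buffer.write(payload)
        sys.stdout.buffer.flush()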
49 |     """
50 |     logging.debug('Reading 8 bytes from stdin.')
51 |     length = read_bytes_from_buffer(8, sys.stdin.buffer)
52 |     length, = struct.unpack('<Q', length)
238 | def lines_from_stdin() -> Iterable[str]:
239 |     """Yields lines from stdin."""
240 |     for line in sys.stdin:
241 |         line = line.strip()
242 |         logging.debug('Read line {} from stdin.'.format(repr(line)))
243 |         if line:
244 |             yield line
245 | 
246 | 
247 | @click.command()
248 | @click.option('--data',
249 |               type=click.Path(exists=True, dir_okay=False),
250 |               help='Path to labels file',
251 |               required=False)
252 | @click.option('-l', '--label',
253 |               type=str,
254 |               help='Column name of labels',
255 |               required=False)
256 | @click.option('-f', '--feature',
257 |               type=str,
258 |               multiple=True,
259 |               help='Column names of features')
260 | @click.option('--labeller-accuracy',
261 |               type=float,
262 |               help='Accuracy of simulated labellers',
263 |               default=1.0)
264 | @click.option('--pandas-key',
265 |               type=str,
266 |               default='',
267 |               help='Key for pandas HDF5')
268 | @click.option('-v', '--verbose',
269 |               is_flag=True,
270 |               help='Verbose output')
271 | def label(
272 |         data: str,
273 |         feature: List[str],
274 |         label: str,
275 |         labeller_accuracy: float,
276 |         verbose: bool,
277 |         pandas_key: str,
278 | ):
279 |     # Logging setup.
280 |     logging.warning('Not implemented: labeller_accuracy')
281 |     logging.captureWarnings(True)
282 |     if verbose:
283 |         logging.root.setLevel(logging.DEBUG)
284 | 
285 |     # If any arguments are specified, expect all arguments.
286 |     if data or label or pandas_key:
287 |         if not data or not label:
288 |             raise ValueError('--data, --label, or --pandas-key specified, but '
289 |                              'missing --data or --label.')
290 | 
291 |         # Handle database arguments.
292 |         data_path = data
293 |         feature_cols = feature
294 |         label_col = label
295 | 
296 |         # Read IDs from stdin.
297 |         ids_to_label = [int(i) for i in lines_from_stdin()]
298 | 
299 |         # There wasn't a recommendations protobuf given, so we have no existing
300 |         # labelled instances.
301 |         labelled_ids = []
302 | 
303 |         # Construct the recommendations protobuf.
304 |         DB, db_kwargs = acton.acton.get_DB(data_path, pandas_key=pandas_key)
305 |         db_kwargs['label_col'] = label_col
306 |         db_kwargs['feature_cols'] = feature_cols
307 |         with DB(data_path, **db_kwargs) as db:
308 |             recs = acton.proto.wrappers.Recommendations.make(
309 |                 recommended_ids=ids_to_label,
310 |                 labelled_ids=labelled_ids,
311 |                 recommender='None',
312 |                 db=db)
313 |     else:
314 |         # Read a recommendations protobuf from stdin.
315 |         recs = read_binary()
316 |         recs = acton.proto.wrappers.Recommendations.deserialise(recs)
317 | 
318 |     proto = acton.acton.label(recs)
319 |     write_binary(proto.proto.SerializeToString())
320 | 
321 | 
322 | if __name__ == '__main__':
323 |     sys.exit(main())
324 | 
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Acton documentation build configuration file, created by
5 | # sphinx-quickstart on Sat Jan 21 12:22:55 2017.
6 | #
7 | # This file is execfile()d with the current directory set to its
8 | # containing dir.
9 | #
10 | # Note that not all possible configuration values are present in this
11 | # autogenerated file.
12 | #
13 | # All configuration values have a default; values that are commented out
14 | # serve to show the default.
15 | 16 | import sys 17 | import os 18 | 19 | # If extensions (or modules to document with autodoc) are in another directory, 20 | # add these directories to sys.path here. If the directory is relative to the 21 | # documentation root, use os.path.abspath to make it absolute, like shown here. 22 | sys.path.insert(0, os.path.abspath('..')) 23 | sys.path.insert(0, os.path.abspath('.')) 24 | 25 | # -- General configuration ------------------------------------------------ 26 | 27 | # If your documentation needs a minimal Sphinx version, state it here. 28 | needs_sphinx = '1.3' 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | 'sphinx.ext.autodoc', 35 | 'sphinx.ext.viewcode', 36 | 'sphinx.ext.napoleon', 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ['_templates'] 41 | 42 | # The suffix(es) of source filenames. 43 | # You can specify multiple suffix as a list of string: 44 | # source_suffix = ['.rst', '.md'] 45 | source_suffix = '.rst' 46 | 47 | # The encoding of source files. 48 | #source_encoding = 'utf-8-sig' 49 | 50 | # The master toctree document. 51 | master_doc = 'index' 52 | 53 | # General information about the project. 54 | project = 'Acton' 55 | copyright = '2017, Matthew Alger & Cheng Soon Ong' 56 | author = 'Matthew Alger & Cheng Soon Ong' 57 | 58 | # The version info for the project you're documenting, acts as replacement for 59 | # |version| and |release|, also used in various other places throughout the 60 | # built documents. 61 | # 62 | # The short X.Y version. 63 | version = '0.3.3' 64 | # The full version, including alpha/beta/rc tags. 65 | release = version 66 | 67 | # The language for content autogenerated by Sphinx. Refer to documentation 68 | # for a list of supported languages. 69 | # 70 | # This is also used if you do content translation via gettext catalogs. 71 | # Usually you set "language" from the command line for these cases. 72 | language = None 73 | 74 | # There are two options for replacing |today|: either, you set today to some 75 | # non-false value, then it is used: 76 | #today = '' 77 | # Else, today_fmt is used as the format for a strftime call. 78 | #today_fmt = '%B %d, %Y' 79 | 80 | # List of patterns, relative to source directory, that match files and 81 | # directories to ignore when looking for source files. 82 | # This patterns also effect to html_static_path and html_extra_path 83 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 84 | 85 | # The reST default role (used for this markup: `text`) to use for all 86 | # documents. 87 | #default_role = None 88 | 89 | # If true, '()' will be appended to :func: etc. cross-reference text. 90 | #add_function_parentheses = True 91 | 92 | # If true, the current module name will be prepended to all description 93 | # unit titles (such as .. function::). 94 | #add_module_names = True 95 | 96 | # If true, sectionauthor and moduleauthor directives will be shown in the 97 | # output. They are ignored by default. 98 | #show_authors = False 99 | 100 | # The name of the Pygments (syntax highlighting) style to use. 101 | pygments_style = 'sphinx' 102 | 103 | # A list of ignored prefixes for module index sorting. 104 | #modindex_common_prefix = [] 105 | 106 | # If true, keep warnings as "system message" paragraphs in the built documents. 
107 | #keep_warnings = False
108 | 
109 | # If true, `todo` and `todoList` produce output, else they produce nothing.
110 | todo_include_todos = False
111 | 
112 | 
113 | # -- Options for HTML output ----------------------------------------------
114 | 
115 | # The theme to use for HTML and HTML Help pages. See the documentation for
116 | # a list of builtin themes.
117 | on_rtd = os.environ.get('READTHEDOCS', None) == 'True'
118 | 
119 | if not on_rtd:  # only import and set the theme if we're building docs locally
120 |     import sphinx_rtd_theme
121 |     html_theme = 'sphinx_rtd_theme'
122 |     html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
123 | 
124 | # Theme options are theme-specific and customize the look and feel of a theme
125 | # further. For a list of options available for each theme, see the
126 | # documentation.
127 | #html_theme_options = {}
128 | 
129 | # Add any paths that contain custom themes here, relative to this directory.
130 | #html_theme_path = []
131 | 
132 | # The name for this set of Sphinx documents.
133 | # "<project> v<release> documentation" by default.
134 | #html_title = 'Acton v0.3.1'
135 | 
136 | # A shorter title for the navigation bar. Default is the same as html_title.
137 | #html_short_title = None
138 | 
139 | # The name of an image file (relative to this directory) to place at the top
140 | # of the sidebar.
141 | #html_logo = None
142 | 
143 | # The name of an image file (relative to this directory) to use as a favicon of
144 | # the docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
145 | # pixels large.
146 | #html_favicon = None
147 | 
148 | # Add any paths that contain custom static files (such as style sheets) here,
149 | # relative to this directory. They are copied after the builtin static files,
150 | # so a file named "default.css" will overwrite the builtin "default.css".
151 | html_static_path = ['_static']
152 | 
153 | # Add any extra paths that contain custom files (such as robots.txt or
154 | # .htaccess) here, relative to this directory. These files are copied
155 | # directly to the root of the documentation.
156 | #html_extra_path = []
157 | 
158 | # If not None, a 'Last updated on:' timestamp is inserted at every page
159 | # bottom, using the given strftime format.
160 | # The empty string is equivalent to '%b %d, %Y'.
161 | #html_last_updated_fmt = None
162 | 
163 | # If true, SmartyPants will be used to convert quotes and dashes to
164 | # typographically correct entities.
165 | #html_use_smartypants = True
166 | 
167 | # Custom sidebar templates, maps document names to template names.
168 | #html_sidebars = {}
169 | 
170 | # Additional templates that should be rendered to pages, maps page names to
171 | # template names.
172 | #html_additional_pages = {}
173 | 
174 | # If false, no module index is generated.
175 | #html_domain_indices = True
176 | 
177 | # If false, no index is generated.
178 | #html_use_index = True
179 | 
180 | # If true, the index is split into individual pages for each letter.
181 | #html_split_index = False
182 | 
183 | # If true, links to the reST sources are added to the pages.
184 | #html_show_sourcelink = True
185 | 
186 | # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
187 | #html_show_sphinx = True
188 | 
189 | # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
190 | #html_show_copyright = True
191 | 
192 | # If true, an OpenSearch description file will be output, and all pages will
193 | # contain a <link> tag referring to it.
The value of this option must be the
194 | # base URL from which the finished HTML is served.
195 | #html_use_opensearch = ''
196 | 
197 | # This is the file name suffix for HTML files (e.g. ".xhtml").
198 | #html_file_suffix = None
199 | 
200 | # Language to be used for generating the HTML full-text search index.
201 | # Sphinx supports the following languages:
202 | #   'da', 'de', 'en', 'es', 'fi', 'fr', 'hu', 'it', 'ja'
203 | #   'nl', 'no', 'pt', 'ro', 'ru', 'sv', 'tr', 'zh'
204 | #html_search_language = 'en'
205 | 
206 | # A dictionary with options for the search language support, empty by default.
207 | # 'ja' uses this config value.
208 | # 'zh' user can custom change `jieba` dictionary path.
209 | #html_search_options = {'type': 'default'}
210 | 
211 | # The name of a javascript file (relative to the configuration directory) that
212 | # implements a search results scorer. If empty, the default will be used.
213 | #html_search_scorer = 'scorer.js'
214 | 
215 | # Output file base name for HTML help builder.
216 | htmlhelp_basename = 'Actondoc'
217 | 
218 | # -- Options for LaTeX output ---------------------------------------------
219 | 
220 | latex_elements = {
221 |     # The paper size ('letterpaper' or 'a4paper').
222 |     # 'papersize': 'letterpaper',
223 | 
224 |     # The font size ('10pt', '11pt' or '12pt').
225 |     # 'pointsize': '10pt',
226 | 
227 |     # Additional stuff for the LaTeX preamble.
228 |     # 'preamble': '',
229 | 
230 |     # Latex figure (float) alignment
231 |     # 'figure_align': 'htbp',
232 | }
233 | 
234 | # Grouping the document tree into LaTeX files. List of tuples
235 | # (source start file, target name, title,
236 | #  author, documentclass [howto, manual, or own class]).
237 | latex_documents = [
238 |     (master_doc, 'Acton.tex', 'Acton Documentation',
239 |      'Matthew Alger \\& Cheng Soon Ong', 'manual'),
240 | ]
241 | 
242 | # The name of an image file (relative to this directory) to place at the top of
243 | # the title page.
244 | #latex_logo = None
245 | 
246 | # For "manual" documents, if this is true, then toplevel headings are parts,
247 | # not chapters.
248 | #latex_use_parts = False
249 | 
250 | # If true, show page references after internal links.
251 | #latex_show_pagerefs = False
252 | 
253 | # If true, show URL addresses after external links.
254 | #latex_show_urls = False
255 | 
256 | # Documents to append as an appendix to all manuals.
257 | #latex_appendices = []
258 | 
259 | # If false, no module index is generated.
260 | #latex_domain_indices = True
261 | 
262 | 
263 | # -- Options for manual page output ---------------------------------------
264 | 
265 | # One entry per manual page. List of tuples
266 | # (source start file, name, description, authors, manual section).
267 | man_pages = [
268 |     (master_doc, 'acton', 'Acton Documentation',
269 |      [author], 1)
270 | ]
271 | 
272 | # If true, show URL addresses after external links.
273 | #man_show_urls = False
274 | 
275 | 
276 | # -- Options for Texinfo output -------------------------------------------
277 | 
278 | # Grouping the document tree into Texinfo files. List of tuples
279 | # (source start file, target name, title, author,
280 | #  dir menu entry, description, category)
281 | texinfo_documents = [
282 |     (master_doc, 'Acton', 'Acton Documentation',
283 |      author, 'Acton', 'One line description of project.',
284 |      'Miscellaneous'),
285 | ]
286 | 
287 | # Documents to append as an appendix to all manuals.
288 | #texinfo_appendices = []
289 | 
290 | # If false, no module index is generated.
291 | #texinfo_domain_indices = True 292 | 293 | # How to display URL addresses: 'footnote', 'no', or 'inline'. 294 | #texinfo_show_urls = 'footnote' 295 | 296 | # If true, do not generate a @detailmenu in the "Top" node's menu. 297 | #texinfo_no_detailmenu = False 298 | 299 | from mock import MagicMock 300 | 301 | # These blank classes are needed to mock sklearn's base classes without hitting 302 | # metaclass errors. 303 | 304 | 305 | class BaseEstimator: 306 | def __init__(self, *args, **kwargs): 307 | pass 308 | 309 | def __getattr__(self, name): 310 | return MagicMock() 311 | 312 | 313 | class ClassifierMixin: 314 | def __init__(self, *args, **kwargs): 315 | pass 316 | 317 | def __getattr__(self, name): 318 | return MagicMock() 319 | 320 | 321 | skbm = MagicMock() 322 | skbm.BaseEstimator = BaseEstimator 323 | skbm.ClassifierMixin = ClassifierMixin 324 | 325 | # This is a hack to get the autodocs building with the typing module and C 326 | # extensions. If you have C extensions (notably protobuf) then ReadTheDocs can't 327 | # build docs, so they need to be mocked out. However, mocks break the typing 328 | # module, because typing checks isinstance(arg, type) in generics like Optional 329 | # and Union. Finally, mocking individual protobuf modules is difficult and 330 | # fragile. Thus this hack: A subclass of MagicMock that has class attributes 331 | # that are also subclasses of MagicMock (and so on). 332 | 333 | 334 | class MockedClassAttributes(type): 335 | def __getattr__(cls, key): 336 | return get_mock_type(key) 337 | 338 | 339 | def get_mock_type(name): 340 | newtype = MockedClassAttributes(name, (MagicMock,), {}) 341 | return newtype 342 | 343 | # This hooks the mock into the protobuf library. 344 | 345 | 346 | def GeneratedProtocolMessageType(name, *args, **kwargs): 347 | return get_mock_type(name) 348 | 349 | 350 | gp = MagicMock() 351 | gp.reflection = gpr = MagicMock() 352 | gpr.GeneratedProtocolMessageType = GeneratedProtocolMessageType 353 | 354 | 355 | class Mock(MagicMock): 356 | @classmethod 357 | def __getattr__(cls, name): 358 | if name == 'base': 359 | return skbm 360 | 361 | return MagicMock() 362 | 363 | 364 | MOCK_MODULES = [ 365 | 'astropy', 366 | 'astropy.io', 367 | 'astropy.io.ascii', 368 | 'astropy.io.fits', 369 | 'astropy.table', 370 | 'click', 371 | 'GPy', 372 | 'google', 373 | 'google.protobuf.json_format', 374 | 'h5py', 375 | 'matplotlib', 376 | 'matplotlib.pyplot', 377 | 'numpy', 378 | 'pandas', 379 | 'protobuf', 380 | 'scipy', 381 | 'scipy.stats', 382 | 'sklearn', 383 | 'sklearn.model_selection', 384 | 'sklearn.datasets', 385 | 'sklearn.linear_model', 386 | 'sklearn.metrics', 387 | 'sklearn.neighbors', 388 | 'sklearn.preprocessing', 389 | 'sklearn.utils', 390 | 'sklearn.utils.estimator_checks', 391 | 'sklearn.utils.multiclass', 392 | 'sklearn.utils.validation', 393 | 'tables', 394 | ] 395 | sys.modules.update((mod_name, Mock()) for mod_name in MOCK_MODULES) 396 | sys.modules.update([('sklearn.base', skbm)]) 397 | sys.modules.update([('google.protobuf.reflection', gpr)]) 398 | sys.modules.update([('google.protobuf', gp)]) 399 | -------------------------------------------------------------------------------- /acton/recommenders.py: -------------------------------------------------------------------------------- 1 | """Recommender classes.""" 2 | 3 | from abc import ABC, abstractmethod 4 | import logging 5 | from typing import Sequence 6 | import warnings 7 | 8 | import acton.database 9 | import numpy 10 | import scipy.stats 11 | 12 | _E_ALPHA = 1. 
13 | _E_BETA = 1.
14 | _R_ALPHA = 1.
15 | _R_BETA = 1.
16 | _P_SAMPLE_GAP = 5
17 | _P_SAMPLE = False
18 | _PARALLEL = False
19 | _MAX_THREAD = 4
20 | _POS_VAL = 1
21 | _MC_MOVE = 1
22 | _SGLD = False
23 | _NMINI = 1
24 | _GIBBS_INIT = True
25 | _SAMPLE_ALL = True
26 | 
27 | _VAR_E = 1.
28 | _VAR_R = 1.
29 | _VAR_X = 0.01
30 | 
31 | _DEST = ''
32 | _LOG = ''
33 | 
34 | a = 0.001
35 | b = 0.01
36 | tau = -0.55
37 | 
38 | MIN_VAL = numpy.iinfo(numpy.int32).min
39 | 
40 | 
41 | def choose_mmr(features: numpy.ndarray, scores: numpy.ndarray, n: int,
42 |                l: float=0.5) -> Sequence[int]:
43 |     """Chooses n scores using maximal marginal relevance.
44 | 
45 |     Notes
46 |     -----
47 |     Scores are chosen from highest to lowest. If there are fewer scores to choose
48 |     from than requested, all scores will be returned in order of preference.
49 | 
50 |     Parameters
51 |     ----------
52 |     scores
53 |         1D array of scores.
54 |     n
55 |         Number of scores to choose.
56 |     l
57 |         Lambda parameter for MMR. l = 1 gives a relevance-ranked list and l = 0
58 |         gives a maximal diversity ranking.
59 | 
60 |     Returns
61 |     -------
62 |     Sequence[int]
63 |         List of indices of scores chosen.
64 |     """
65 |     if n < 0:
66 |         raise ValueError('n must be a non-negative integer.')
67 | 
68 |     if n == 0:
69 |         return []
70 | 
71 |     selections = [scores.argmax()]
72 |     selections_set = set(selections)
73 | 
74 |     logging.debug('Running MMR.')
75 |     dists = []
76 |     dists_matrix = None
77 |     while len(selections) < n:
78 |         if n >= 10 and len(selections) % (n // 10) == 0:  # Guard n < 10 against division by zero.
79 |             logging.debug('MMR epoch {}/{}.'.format(len(selections), n))
80 |         # Compute distances for last selection.
81 |         last = features[selections[-1]:selections[-1] + 1]
82 |         last_dists = numpy.linalg.norm(features - last, axis=1)
83 |         dists.append(last_dists)
84 |         dists_matrix = numpy.array(dists)
85 | 
86 |         next_best = None
87 |         next_best_margin = float('-inf')
88 | 
89 |         for i in range(len(scores)):
90 |             if i in selections_set:
91 |                 continue
92 | 
93 |             margin = l * (scores[i] - (1 - l) * dists_matrix[:, i].max())
94 |             if margin > next_best_margin:
95 |                 next_best_margin = margin
96 |                 next_best = i
97 | 
98 |         if next_best is None:
99 |             break
100 | 
101 |         selections.append(next_best)
102 |         selections_set.add(next_best)
103 | 
104 |     return selections
105 | 
106 | 
107 | def choose_boltzmann(features: numpy.ndarray, scores: numpy.ndarray, n: int,
108 |                      temperature: float=1.0) -> Sequence[int]:
109 |     """Chooses n scores using a Boltzmann distribution.
110 | 
111 |     Notes
112 |     -----
113 |     Scores are chosen from highest to lowest. If there are fewer scores to choose
114 |     from than requested, all scores will be returned in order of preference.
115 | 
116 |     Parameters
117 |     ----------
118 |     scores
119 |         1D array of scores.
120 |     n
121 |         Number of scores to choose.
122 |     temperature
123 |         Temperature parameter for sampling. Higher temperatures give more
124 |         diversity.
125 | 
126 |     Returns
127 |     -------
128 |     Sequence[int]
129 |         List of indices of scores chosen.
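
    Examples
    --------
    An illustrative draw (a sketch only; the selection is stochastic, so
    the chosen indices vary from run to run)::

        >>> import numpy
        >>> features = numpy.eye(4)
        >>> scores = numpy.array([0.1, 0.9, 0.5, 0.2])
        >>> choose_boltzmann(features, scores, 2)  # doctest: +SKIP
        [1, 2]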
130 | """ 131 | if n < 0: 132 | raise ValueError('n must be a non-negative integer.') 133 | 134 | if n == 0: 135 | return [] 136 | 137 | boltzmann_scores = numpy.exp(scores / temperature) 138 | boltzmann_scores /= boltzmann_scores.sum() 139 | not_chosen = list(range(len(boltzmann_scores))) 140 | chosen = [] 141 | while len(chosen) < n and not_chosen: 142 | scores_ = boltzmann_scores[not_chosen] 143 | r = numpy.random.uniform(high=scores_.sum()) 144 | total = 0 145 | upto = 0 146 | while True: 147 | score = scores_[upto] 148 | total += score 149 | if total > r: 150 | break 151 | 152 | upto += 1 153 | chosen.append(not_chosen[upto]) 154 | not_chosen.pop(upto) 155 | 156 | return chosen 157 | 158 | 159 | class Recommender(ABC): 160 | """Base class for recommenders. 161 | 162 | Attributes 163 | ---------- 164 | """ 165 | 166 | @abstractmethod 167 | def recommend(self, ids: Sequence[int], 168 | predictions: numpy.ndarray, 169 | n: int=1, diversity: float=0.5) -> Sequence[int]: 170 | """Recommends an instance to label. 171 | 172 | Parameters 173 | ---------- 174 | ids 175 | Sequence of IDs in the unlabelled data pool. 176 | predictions 177 | N x T x C array of predictions. 178 | n 179 | Number of recommendations to make. 180 | diversity 181 | Recommendation diversity in [0, 1]. 182 | 183 | Returns 184 | ------- 185 | Sequence[int] 186 | IDs of the instances to label. 187 | """ 188 | 189 | 190 | class RandomRecommender(Recommender): 191 | """Recommends instances at random.""" 192 | 193 | def __init__(self, db: acton.database.Database): 194 | """ 195 | Parameters 196 | ---------- 197 | db 198 | Features database. 199 | """ 200 | self._db = db 201 | 202 | def recommend(self, ids: Sequence[int], 203 | predictions: numpy.ndarray, 204 | n: int=1, diversity: float=0.5) -> Sequence[int]: 205 | """Recommends an instance to label. 206 | 207 | Parameters 208 | ---------- 209 | ids 210 | Sequence of IDs in the unlabelled data pool. 211 | predictions 212 | N x T x C array of predictions. 213 | n 214 | Number of recommendations to make. 215 | diversity 216 | Recommendation diversity in [0, 1]. 217 | 218 | Returns 219 | ------- 220 | Sequence[int] 221 | IDs of the instances to label. 222 | """ 223 | return numpy.random.choice(list(ids), size=n) 224 | 225 | 226 | class QBCRecommender(Recommender): 227 | """Recommends instances by committee disagreement.""" 228 | 229 | def __init__(self, db: acton.database.Database): 230 | """ 231 | Parameters 232 | ---------- 233 | db 234 | Features database. 235 | """ 236 | self._db = db 237 | 238 | def recommend(self, ids: Sequence[int], 239 | predictions: numpy.ndarray, 240 | n: int=1, diversity: float=0.5) -> Sequence[int]: 241 | """Recommends an instance to label. 242 | 243 | Notes 244 | ----- 245 | Assumes predictions are probabilities of positive binary label. 246 | 247 | Parameters 248 | ---------- 249 | ids 250 | Sequence of IDs in the unlabelled data pool. 251 | predictions 252 | N x T x C array of predictions. The ith row must correspond with the 253 | ith ID in the sequence. 254 | n 255 | Number of recommendations to make. 256 | diversity 257 | Recommendation diversity in [0, 1]. 258 | 259 | Returns 260 | ------- 261 | Sequence[int] 262 | IDs of the instances to label. 263 | """ 264 | assert predictions.shape[1] > 2, "QBC must have > 2 predictors." 
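        # Committee disagreement, computed below: each of the T predictors
        # casts a vote (its argmax class); instances are scored by how many
        # votes disagree with the plurality (modal) vote, normalised by the
        # largest agreement count, then sampled with Boltzmann weighting.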
265 |         assert len(ids) == predictions.shape[0]
266 |         assert 0 <= diversity <= 1
267 |         labels = predictions.argmax(axis=2)
268 |         plurality_labels, plurality_counts = scipy.stats.mode(labels, axis=1)
269 |         assert plurality_labels.shape == (predictions.shape[0], 1), \
270 |             'plurality_labels has shape {}; expected {}'.format(
271 |                 plurality_labels.shape, (predictions.shape[0], 1))
272 |         agree_with_plurality = labels == plurality_labels
273 |         assert labels.shape == agree_with_plurality.shape
274 |         n_agree = agree_with_plurality.sum(axis=1)
275 |         p_agree = n_agree / n_agree.max()  # Agreement is now between 0 and 1.
276 |         disagreement = 1 - p_agree
277 |         indices = choose_boltzmann(self._db.read_features(ids), disagreement, n,
278 |                                    temperature=diversity * 2)
279 |         return [ids[i] for i in indices]
280 | 
281 | 
282 | class UncertaintyRecommender(Recommender):
283 |     """Recommends instances by confidence-based uncertainty sampling."""
284 | 
285 |     def __init__(self, db: acton.database.Database):
286 |         """
287 |         Parameters
288 |         ----------
289 |         db
290 |             Features database.
291 |         """
292 |         self._db = db
293 | 
294 |     def recommend(self, ids: Sequence[int],
295 |                   predictions: numpy.ndarray,
296 |                   n: int=1, diversity: float=0.5) -> Sequence[int]:
297 |         """Recommends an instance to label.
298 | 
299 |         Notes
300 |         -----
301 |         Assumes predictions are probabilities of positive binary label.
302 | 
303 |         Parameters
304 |         ----------
305 |         ids
306 |             Sequence of IDs in the unlabelled data pool.
307 |         predictions
308 |             N x 1 x C array of predictions. The ith row must correspond with the
309 |             ith ID in the sequence.
310 |         n
311 |             Number of recommendations to make.
312 |         diversity
313 |             Recommendation diversity in [0, 1].
314 | 
315 |         Returns
316 |         -------
317 |         Sequence[int]
318 |             IDs of the instances to label.
319 |         """
320 |         if predictions.shape[1] != 1:
321 |             raise ValueError('Uncertainty sampling must have one predictor')
322 | 
323 |         assert len(ids) == predictions.shape[0]
324 | 
325 |         # x* = argmax (1 - p(y^ | x)) where y^ = argmax p(y | x) (Settles 2009).
326 |         proximities = 1 - predictions.max(axis=2).ravel()
327 |         assert proximities.shape == (len(ids),)
328 | 
329 |         indices = choose_boltzmann(self._db.read_features(ids), proximities, n,
330 |                                    temperature=diversity * 2)
331 |         return [ids[i] for i in indices]
332 | 
333 | 
334 | class EntropyRecommender(Recommender):
335 |     """Recommends instances by entropy-based uncertainty sampling."""
336 | 
337 |     def __init__(self, db: acton.database.Database):
338 |         """
339 |         Parameters
340 |         ----------
341 |         db
342 |             Features database.
343 |         """
344 |         self._db = db
345 | 
346 |     def recommend(self, ids: Sequence[int],
347 |                   predictions: numpy.ndarray,
348 |                   n: int=1, diversity: float=0.5) -> Sequence[int]:
349 |         """Recommends an instance to label.
350 | 
351 |         Parameters
352 |         ----------
353 |         ids
354 |             Sequence of IDs in the unlabelled data pool.
355 |         predictions
356 |             N x 1 x C array of predictions. The ith row must correspond with the
357 |             ith ID in the sequence.
358 |         n
359 |             Number of recommendations to make.
360 |         diversity
361 |             Recommendation diversity in [0, 1].
362 | 
363 |         Returns
364 |         -------
365 |         Sequence[int]
366 |             IDs of the instances to label.
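
        Examples
        --------
        The entropy scores computed internally, sketched by hand (Boltzmann
        sampling is then applied over these scores)::

            import numpy
            predictions = numpy.array([[[0.5, 0.5]], [[0.9, 0.1]]])
            entropies = (-predictions * numpy.log(predictions)).sum(axis=2).ravel()
            # entropies ~= [0.693, 0.325]: the uniform instance is maximally
            # uncertain and would be preferred.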
367 |         """
368 |         if predictions.shape[1] != 1:
369 |             raise ValueError('Entropy sampling must have one predictor')
370 | 
371 |         assert len(ids) == predictions.shape[0]
372 | 
373 |         with warnings.catch_warnings():
374 |             warnings.filterwarnings(action='ignore', category=RuntimeWarning)
375 |             proximities = -predictions * numpy.log(predictions)
376 | 
377 |         proximities = proximities.sum(axis=2).ravel()  # Entropy: sum -p log p over the class axis.
378 |         proximities[numpy.isnan(proximities)] = float('-inf')
379 | 
380 |         assert proximities.shape == (len(ids),)
381 | 
382 |         indices = choose_boltzmann(self._db.read_features(ids), proximities, n,
383 |                                    temperature=diversity * 2)
384 |         return [ids[i] for i in indices]
385 | 
386 | 
387 | class MarginRecommender(Recommender):
388 |     """Recommends instances by margin-based uncertainty sampling."""
389 | 
390 |     def __init__(self, db: acton.database.Database):
391 |         """
392 |         Parameters
393 |         ----------
394 |         db
395 |             Features database.
396 |         """
397 |         self._db = db
398 | 
399 |     def recommend(self, ids: Sequence[int],
400 |                   predictions: numpy.ndarray,
401 |                   n: int=1, diversity: float=0.5) -> Sequence[int]:
402 |         """Recommends an instance to label.
403 | 
404 |         Notes
405 |         -----
406 |         Assumes predictions are probabilities of positive binary label.
407 | 
408 |         Parameters
409 |         ----------
410 |         ids
411 |             Sequence of IDs in the unlabelled data pool.
412 |         predictions
413 |             N x 1 x C array of predictions. The ith row must correspond with the
414 |             ith ID in the sequence.
415 |         n
416 |             Number of recommendations to make.
417 |         diversity
418 |             Recommendation diversity in [0, 1].
419 | 
420 |         Returns
421 |         -------
422 |         Sequence[int]
423 |             IDs of the instances to label.
424 |         """
425 |         if predictions.shape[1] != 1:
426 |             raise ValueError('Margin sampling must have one predictor')
427 | 
428 |         assert len(ids) == predictions.shape[0]
429 | 
430 |         # x* = argmin p(y1^ | x) - p(y2^ | x) where yn^ = argmax p(yn | x)
431 |         # (Settles 2009).
432 |         partitioned = numpy.partition(predictions, -2, axis=2)
433 |         most_likely = partitioned[:, 0, -1]
434 |         second_most_likely = partitioned[:, 0, -2]
435 |         assert most_likely.shape == (len(ids),)
436 |         scores = 1 - (most_likely - second_most_likely)
437 | 
438 |         indices = choose_boltzmann(self._db.read_features(ids), scores, n,
439 |                                    temperature=diversity * 2)
440 |         return [ids[i] for i in indices]
441 | 
442 | 
443 | class ThompsonSamplingRecommender(Recommender):
444 |     """Recommends instances by Thompson sampling.
445 |     Input:
446 |         K x N x N array of predictions.
447 |     Output:
448 |         IDs of the instances to label.
449 | 
450 |     Only supports one recommendation at a time.
451 | 
452 |     Attributes
453 |     ----------
454 |     db
455 |         Features database.
456 | 
457 |     """
458 | 
459 |     def __init__(self, db: acton.database.Database):
460 |         """
461 |         Parameters
462 |         ----------
463 |         db
464 |             Features database.
465 |         """
466 |         self._db = db
467 | 
468 |     def recommend(self, ids: Sequence[tuple],
469 |                   predictions: numpy.ndarray,
470 |                   n: int=1, diversity: float=0.0,
471 |                   repeated_labelling: bool = True) -> Sequence[int]:
472 |         """Recommends an instance to label.
473 | 
474 |         Notes
475 |         -----
476 |         Predictions are reconstructed entries equal to e_i R_k e_j^T.
477 | 
478 |         Parameters
479 |         ----------
480 |         ids
481 |             Sequence of IDs in the unlabelled data pool.
482 |         predictions
483 |             K x N x N array of predictions.
484 |         n
485 |             Number of recommendations to make.
486 |         diversity
487 |             Selects the recommendation method:
488 |             0.0 selects the argmax of the sampled reconstruction (Thompson sampling);
489 |             any other value falls back to uniform random sampling.
490 |         repeated_labelling
491 |             Whether an instance may be labelled more than once.
492 | 
493 |         Returns
494 |         -------
495 |         Sequence[int]
496 |             IDs of the instances to label.
497 |         """
498 | 
499 |         n_relations, n_entities, _ = predictions.shape
500 | 
501 |         MIN_VAL = numpy.iinfo(numpy.int32).min
502 | 
503 |         # mask tensor: 0 represents unlabelled, 1 represents labelled
504 | 
505 |         if repeated_labelling:
506 |             # Allow repeated labelling: leave every entry unmasked.
507 |             mask = numpy.zeros_like(predictions)
508 |         else:
509 |             mask = numpy.ones_like(predictions)
510 |             for _tuple in ids:
511 |                 r_k, e_i, e_j = _tuple
512 |                 mask[r_k, e_i, e_j] = 0
513 | 
514 |         if diversity == 0.0:
515 |             predictions[mask == 1] = MIN_VAL
516 |             return [numpy.unravel_index(predictions.argmax(),
517 |                                         predictions.shape)]
518 |         else:
519 |             correct = False
520 |             while not correct:
521 |                 sample = (numpy.random.randint(n_relations),
522 |                           numpy.random.randint(n_entities),
523 |                           numpy.random.randint(n_entities))
524 |                 if mask[sample] == 0:
525 |                     correct = True
526 |             return [sample]
527 | 
528 | 
529 | # For safe string-based access to recommender classes.
530 | RECOMMENDERS = {
531 |     'RandomRecommender': RandomRecommender,
532 |     'QBCRecommender': QBCRecommender,
533 |     'UncertaintyRecommender': UncertaintyRecommender,
534 |     'EntropyRecommender': EntropyRecommender,
535 |     'MarginRecommender': MarginRecommender,
536 |     'ThompsonSamplingRecommender': ThompsonSamplingRecommender,
537 |     'None': RandomRecommender,
538 | }
539 | 
--------------------------------------------------------------------------------
/acton/proto/wrappers.py:
--------------------------------------------------------------------------------
1 | """Classes that wrap protobufs."""
2 | 
3 | import json
4 | from typing import Union, List, Iterable
5 | 
6 | import acton.database
7 | import acton.proto.acton_pb2 as acton_pb
8 | import acton.proto.io
9 | import google.protobuf.json_format as json_format
10 | import numpy
11 | import sklearn.preprocessing
12 | from sklearn.preprocessing import LabelEncoder as SKLabelEncoder
13 | 
14 | 
15 | def validate_db(db: acton_pb.Database):
16 |     """Validates a Database proto.
17 | 
18 |     Parameters
19 |     ----------
20 |     db
21 |         Database to validate.
22 | 
23 |     Raises
24 |     ------
25 |     ValueError
26 |     """
27 |     if db.class_name not in acton.database.DATABASES:
28 |         raise ValueError('Invalid database class name: {}'.format(
29 |             db.class_name))
30 | 
31 |     if not db.path:
32 |         raise ValueError('Must specify db.path.')
33 | 
34 | 
35 | def deserialise_encoder(
36 |         encoder: acton_pb.Database.LabelEncoder
37 | ) -> sklearn.preprocessing.LabelEncoder:
38 |     """Deserialises a LabelEncoder protobuf.
39 | 
40 |     Parameters
41 |     ----------
42 |     encoder
43 |         LabelEncoder protobuf.
44 | 
45 |     Returns
46 |     -------
47 |     sklearn.preprocessing.LabelEncoder
48 |         Deserialised LabelEncoder (empty if no encodings were specified).
49 |     """
50 |     encodings = []
51 |     for encoding in encoder.encoding:
52 |         encodings.append((encoding.class_int, encoding.class_label))
53 |     encodings.sort()
54 |     encodings = numpy.array([c[1] for c in encodings])
55 | 
56 |     encoder = SKLabelEncoder()
57 |     encoder.classes_ = encodings
58 |     return encoder
59 | 
60 | 
61 | class LabelPool(object):
62 |     """Wrapper for the LabelPool protobuf.
63 | 
64 |     Attributes
65 |     ----------
66 |     proto : acton_pb.LabelPool
67 |         Protobuf representing the label pool.
68 |     db_kwargs : dict
69 |         Key-value pairs of keyword arguments for the database constructor.
70 |     label_encoder : sklearn.preprocessing.LabelEncoder
71 |         Encodes labels as integers. May be None.
72 |     """
73 | 
74 |     def __init__(self, proto: Union[str, acton_pb.LabelPool]):
75 |         """
76 |         Parameters
77 |         ----------
78 |         proto
79 |             Path to .proto file, or raw protobuf itself.
80 |         """
81 |         try:
82 |             self.proto = acton.proto.io.read_proto(proto, acton_pb.LabelPool)
83 |         except TypeError:
84 |             if isinstance(proto, acton_pb.LabelPool):
85 |                 self.proto = proto
86 |             else:
87 |                 raise TypeError('proto should be str or LabelPool protobuf.')
88 |         self._validate_proto()
89 |         self.db_kwargs = {kwa.key: json.loads(kwa.value)
90 |                           for kwa in self.proto.db.kwarg}
91 |         if len(self.proto.db.label_encoder.encoding) > 0:
92 |             self.label_encoder = deserialise_encoder(
93 |                 self.proto.db.label_encoder)
94 |             self.db_kwargs['label_encoder'] = self.label_encoder
95 |         else:
96 |             self.label_encoder = None
97 |         self._set_default()
98 | 
99 |     @classmethod
100 |     def deserialise(cls, proto: bytes, json: bool=False) -> 'LabelPool':
101 |         """Deserialises a protobuf into a LabelPool.
102 | 
103 |         Parameters
104 |         ----------
105 |         proto
106 |             Serialised protobuf.
107 |         json
108 |             Whether the serialised protobuf is in JSON format.
109 | 
110 |         Returns
111 |         -------
112 |         LabelPool
113 |         """
114 |         if not json:
115 |             lp = acton_pb.LabelPool()
116 |             lp.ParseFromString(proto)
117 |             return cls(lp)
118 | 
119 |         return cls(json_format.Parse(proto, acton_pb.LabelPool()))
120 | 
121 |     @property
122 |     def DB(self) -> acton.database.Database:
123 |         """Gets a database context manager for the specified database.
124 | 
125 |         Returns
126 |         -------
127 |         type
128 |             Database context manager.
129 |         """
130 |         if hasattr(self, '_DB'):
131 |             return self._DB
132 | 
133 |         self._DB = lambda: acton.database.DATABASES[self.proto.db.class_name](
134 |             self.proto.db.path, **self.db_kwargs)
135 | 
136 |         return self._DB
137 | 
138 |     @property
139 |     def ids(self) -> List[int]:
140 |         """Gets a list of IDs.
141 | 
142 |         Returns
143 |         -------
144 |         List[int]
145 |             List of known IDs.
146 |         """
147 |         if hasattr(self, '_ids'):
148 |             return self._ids
149 | 
150 |         self._ids = list(self.proto.id)
151 |         return self._ids
152 | 
153 |     @property
154 |     def labels(self) -> numpy.ndarray:
155 |         """Gets labels array specified in input.
156 | 
157 |         Notes
158 |         -----
159 |         The returned array is cached by this object so future calls will not
160 |         need to recompile the array.
161 | 
162 |         Returns
163 |         -------
164 |         numpy.ndarray
165 |             T x N x F NumPy array of labels.
166 |         """
167 |         if hasattr(self, '_labels'):
168 |             return self._labels
169 | 
170 |         ids = self.ids
171 |         with self.DB() as db:
172 |             self._labels = db.read_labels([0], ids)  # Cache, as documented above.
173 |         return self._labels
174 |     def _validate_proto(self):
175 |         """Checks that the protobuf is valid and enforces constraints.
176 | 
177 |         Raises
178 |         ------
179 |         ValueError
180 |         """
181 |         validate_db(self.proto.db)
182 | 
183 |     def _set_default(self):
184 |         """Adds default parameters to the protobuf."""
185 | 
186 |     @classmethod
187 |     def make(
188 |             cls: type,
189 |             ids: Iterable[int],
190 |             db: acton.database.Database) -> 'LabelPool':
191 |         """Constructs a LabelPool.
192 | 
193 |         Parameters
194 |         ----------
195 |         ids
196 |             Iterable of instance IDs.
197 |         db
198 |             Database.
199 | 
200 |         Returns
201 |         -------
202 |         LabelPool
203 |         """
204 |         proto = acton_pb.LabelPool()
205 | 
206 |         # Store the IDs.
207 |         for id_ in ids:
208 |             proto.id.append(id_)
209 | 
210 |         # Store the database.
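        # (CopyFrom below replaces this message's db field with a deep copy
        # of the Database proto returned by db.to_proto().)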
211 | proto.db.CopyFrom(db.to_proto()) 212 | 213 | return cls(proto) 214 | 215 | 216 | class Predictions(object): 217 | """Wrapper for the Predictions protobuf. 218 | 219 | Attributes 220 | ---------- 221 | proto : acton_pb.Predictions 222 | Protobuf representing predictions. 223 | db_kwargs : dict 224 | Dictionary of database keyword arguments. 225 | label_encoder : sklearn.preprocessing.LabelEncoder 226 | Encodes labels as integers. May be None. 227 | """ 228 | 229 | def __init__(self, proto: Union[str, acton_pb.Predictions]): 230 | """ 231 | Parameters 232 | ---------- 233 | proto 234 | Path to .proto file, or raw protobuf itself. 235 | """ 236 | try: 237 | self.proto = acton.proto.io.read_proto( 238 | proto, acton_pb.Predictions) 239 | except TypeError: 240 | if isinstance(proto, acton_pb.Predictions): 241 | self.proto = proto 242 | else: 243 | raise TypeError('proto should be str or Predictions protobuf.') 244 | self._validate_proto() 245 | self.db_kwargs = {kwa.key: json.loads(kwa.value) 246 | for kwa in self.proto.db.kwarg} 247 | if len(self.proto.db.label_encoder.encoding) > 0: 248 | self.label_encoder = deserialise_encoder( 249 | self.proto.db.label_encoder) 250 | self.db_kwargs['label_encoder'] = self.label_encoder 251 | else: 252 | self.label_encoder = None 253 | self._set_default() 254 | 255 | @property 256 | def DB(self) -> acton.database.Database: 257 | """Gets a database context manager for the specified database. 258 | 259 | Returns 260 | ------- 261 | type 262 | Database context manager. 263 | """ 264 | if hasattr(self, '_DB'): 265 | return self._DB 266 | 267 | self._DB = lambda: acton.database.DATABASES[self.proto.db.class_name]( 268 | self.proto.db.path, **self.db_kwargs) 269 | 270 | return self._DB 271 | 272 | @property 273 | def predicted_ids(self) -> List[int]: 274 | """Gets a list of IDs corresponding to predictions. 275 | 276 | Returns 277 | ------- 278 | List[int] 279 | List of IDs corresponding to predictions. 280 | """ 281 | if hasattr(self, '_predicted_ids'): 282 | return self._predicted_ids 283 | 284 | self._predicted_ids = [prediction.id 285 | for prediction in self.proto.prediction] 286 | return self._predicted_ids 287 | 288 | @property 289 | def labelled_ids(self) -> List[int]: 290 | """Gets a list of IDs the predictor knew the label for. 291 | 292 | Returns 293 | ------- 294 | List[int] 295 | List of IDs the predictor knew the label for. 296 | """ 297 | if hasattr(self, '_labelled_ids'): 298 | return self._labelled_ids 299 | 300 | self._labelled_ids = list(self.proto.labelled_id) 301 | return self._labelled_ids 302 | 303 | @property 304 | def predictions(self) -> numpy.ndarray: 305 | """Gets predictions array specified in input. 306 | 307 | Notes 308 | ----- 309 | The returned array is cached by this object so future calls will not 310 | need to recompile the array. 311 | 312 | Returns 313 | ------- 314 | numpy.ndarray 315 | T x N x D NumPy array of predictions. 316 | """ 317 | if hasattr(self, '_predictions'): 318 | return self._predictions 319 | 320 | self._predictions = [] 321 | for prediction in self.proto.prediction: 322 | data = prediction.prediction 323 | shape = (self.proto.n_predictors, 324 | self.proto.n_prediction_dimensions) 325 | self._predictions.append( 326 | acton.proto.io.get_ndarray(data, shape, float)) 327 | self._predictions = numpy.array(self._predictions).transpose((1, 0, 2)) 328 | return self._predictions 329 | 330 | def _validate_proto(self): 331 | """Checks that the protobuf is valid and enforces constraints. 
332 | 333 | Raises 334 | ------ 335 | ValueError 336 | """ 337 | if self.proto.n_predictors < 1: 338 | raise ValueError('Number of predictors must be > 0.') 339 | 340 | if self.proto.n_prediction_dimensions < 1: 341 | raise ValueError('Prediction dimension must be > 0.') 342 | 343 | validate_db(self.proto.db) 344 | 345 | def _set_default(self): 346 | """Adds default parameters to the protobuf.""" 347 | 348 | @classmethod 349 | def make( 350 | cls: type, 351 | predicted_ids: Iterable[int], 352 | labelled_ids: Iterable[int], 353 | predictions: numpy.ndarray, 354 | db: acton.database.Database, 355 | predictor: str='') -> 'Predictions': 356 | """Converts NumPy predictions to a Predictions object. 357 | 358 | Parameters 359 | ---------- 360 | predicted_ids 361 | Iterable of instance IDs corresponding to predictions. 362 | labelled_ids 363 | Iterable of instance IDs used to train the predictor. 364 | predictions 365 | T x N x D array of corresponding predictions. 366 | predictor 367 | Name of predictor used to generate predictions. 368 | db 369 | Database. 370 | 371 | Returns 372 | ------- 373 | Predictions 374 | """ 375 | proto = acton_pb.Predictions() 376 | 377 | # Store single data first. 378 | n_predictors, n_instances, n_prediction_dimensions = predictions.shape 379 | proto.n_predictors = n_predictors 380 | proto.n_prediction_dimensions = n_prediction_dimensions 381 | proto.predictor = predictor 382 | 383 | # Store the database. 384 | proto.db.CopyFrom(db.to_proto()) 385 | 386 | # Store the predictions array. We can do this by looping over the 387 | # instances. 388 | for id_, prediction in zip( 389 | predicted_ids, predictions.transpose((1, 0, 2))): 390 | prediction_ = proto.prediction.add() 391 | prediction_.id = int(id_) # numpy.int64 -> int 392 | prediction_.prediction.extend(prediction.ravel()) 393 | 394 | # Store the labelled IDs. 395 | for id_ in labelled_ids: 396 | # int() here takes numpy.int64 to int, for protobuf compatibility. 397 | proto.labelled_id.append(int(id_)) 398 | 399 | return cls(proto) 400 | 401 | @classmethod 402 | def deserialise(cls, proto: bytes, json: bool=False) -> 'Predictions': 403 | """Deserialises a protobuf into Predictions. 404 | 405 | Parameters 406 | ---------- 407 | proto 408 | Serialised protobuf. 409 | json 410 | Whether the serialised protobuf is in JSON format. 411 | 412 | Returns 413 | ------- 414 | Predictions 415 | """ 416 | if not json: 417 | predictions = acton_pb.Predictions() 418 | predictions.ParseFromString(proto) 419 | return cls(predictions) 420 | 421 | return cls(json_format.Parse(proto, acton_pb.Predictions())) 422 | 423 | 424 | class Recommendations(object): 425 | """Wrapper for the Recommendations protobuf. 426 | 427 | Attributes 428 | ---------- 429 | proto : acton_pb.Recommendations 430 | Protobuf representing recommendations. 431 | db_kwargs : dict 432 | Key-value pairs of keyword arguments for the database constructor. 433 | label_encoder : sklearn.preprocessing.LabelEncoder 434 | Encodes labels as integers. May be None. 435 | """ 436 | 437 | def __init__(self, proto: Union[str, acton_pb.Recommendations]): 438 | """ 439 | Parameters 440 | ---------- 441 | proto 442 | Path to .proto file, or raw protobuf itself. 
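
        Examples
        --------
        Both construction routes are equivalent (a sketch; the path is
        hypothetical)::

            recs = Recommendations('recommendations.pb')        # from a file
            recs = Recommendations(acton_pb.Recommendations())  # from a proto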
443 | """ 444 | try: 445 | self.proto = acton.proto.io.read_proto( 446 | proto, acton_pb.Recommendations) 447 | except TypeError: 448 | if isinstance(proto, acton_pb.Recommendations): 449 | self.proto = proto 450 | else: 451 | raise TypeError( 452 | 'proto should be str or Recommendations protobuf.') 453 | self._validate_proto() 454 | self.db_kwargs = {kwa.key: json.loads(kwa.value) 455 | for kwa in self.proto.db.kwarg} 456 | if len(self.proto.db.label_encoder.encoding) > 0: 457 | self.label_encoder = deserialise_encoder( 458 | self.proto.db.label_encoder) 459 | self.db_kwargs['label_encoder'] = self.label_encoder 460 | else: 461 | self.label_encoder = None 462 | self._set_default() 463 | 464 | @classmethod 465 | def deserialise(cls, proto: bytes, json: bool=False) -> 'Recommendations': 466 | """Deserialises a protobuf into Recommendations. 467 | 468 | Parameters 469 | ---------- 470 | proto 471 | Serialised protobuf. 472 | json 473 | Whether the serialised protobuf is in JSON format. 474 | 475 | Returns 476 | ------- 477 | Recommendations 478 | """ 479 | if not json: 480 | recommendations = acton_pb.Recommendations() 481 | recommendations.ParseFromString(proto) 482 | return cls(recommendations) 483 | 484 | return cls(json_format.Parse(proto, acton_pb.Recommendations())) 485 | 486 | @property 487 | def DB(self) -> acton.database.Database: 488 | """Gets a database context manager for the specified database. 489 | 490 | Returns 491 | ------- 492 | type 493 | Database context manager. 494 | """ 495 | if hasattr(self, '_DB'): 496 | return self._DB 497 | 498 | self._DB = lambda: acton.database.DATABASES[self.proto.db.class_name]( 499 | self.proto.db.path, **self.db_kwargs) 500 | 501 | return self._DB 502 | 503 | @property 504 | def recommendations(self) -> List[int]: 505 | """Gets a list of recommended IDs. 506 | 507 | Returns 508 | ------- 509 | List[int] 510 | List of recommended IDs. 511 | """ 512 | if hasattr(self, '_recommendations'): 513 | return self._recommendations 514 | 515 | self._recommendations = list(self.proto.recommended_id) 516 | return self._recommendations 517 | 518 | @property 519 | def labelled_ids(self) -> List[int]: 520 | """Gets a list of labelled IDs. 521 | 522 | Returns 523 | ------- 524 | List[int] 525 | List of labelled IDs. 526 | """ 527 | if hasattr(self, '_labelled_ids'): 528 | return self._labelled_ids 529 | 530 | self._labelled_ids = list(self.proto.labelled_id) 531 | return self._labelled_ids 532 | 533 | def _validate_proto(self): 534 | """Checks that the protobuf is valid and enforces constraints. 535 | 536 | Raises 537 | ------ 538 | ValueError 539 | """ 540 | validate_db(self.proto.db) 541 | 542 | def _set_default(self): 543 | """Adds default parameters to the protobuf.""" 544 | 545 | @classmethod 546 | def make( 547 | cls: type, 548 | recommended_ids: Iterable[int], 549 | labelled_ids: Iterable[int], 550 | recommender: str, 551 | db: acton.database.Database) -> 'Recommendations': 552 | """Constructs a Recommendations. 553 | 554 | Parameters 555 | ---------- 556 | recommended_ids 557 | Iterable of recommended instance IDs. 558 | labelled_ids 559 | Iterable of labelled instance IDs used to make recommendations. 560 | recommender 561 | Name of the recommender used to make recommendations. 562 | db 563 | Database. 564 | 565 | Returns 566 | ------- 567 | Recommendations 568 | """ 569 | proto = acton_pb.Recommendations() 570 | 571 | # Store single data first. 572 | proto.recommender = recommender 573 | 574 | # Store the IDs. 
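        # (Repeated scalar fields on a protobuf behave like Python lists,
        # so each ID can simply be appended in turn.)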
575 | for id_ in recommended_ids: 576 | proto.recommended_id.append(id_) 577 | for id_ in labelled_ids: 578 | proto.labelled_id.append(id_) 579 | 580 | # Store the database. 581 | proto.db.CopyFrom(db.to_proto()) 582 | 583 | return cls(proto) 584 | -------------------------------------------------------------------------------- /acton/proto/acton_pb2.py: -------------------------------------------------------------------------------- 1 | # Generated by the protocol buffer compiler. DO NOT EDIT! 2 | # source: acton.proto 3 | 4 | import sys 5 | _b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1')) 6 | from google.protobuf import descriptor as _descriptor 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | DESCRIPTOR = _descriptor.FileDescriptor( 16 | name='acton.proto', 17 | package='acton', 18 | syntax='proto3', 19 | serialized_pb=_b( 20 | '\n\x0b\x61\x63ton.proto\x12\x05\x61\x63ton\"$\n\x06KeyVal\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\xfc\x01\n\x08\x44\x61tabase\x12\x0c\n\x04path\x18\x01 \x01(\t\x12\x12\n\nclass_name\x18\x02 \x01(\t\x12\x1c\n\x05kwarg\x18\x03 \x03(\x0b\x32\r.acton.KeyVal\x12\x33\n\rlabel_encoder\x18\x04 \x01(\x0b\x32\x1c.acton.Database.LabelEncoder\x1a{\n\x0cLabelEncoder\x12\x37\n\x08\x65ncoding\x18\x01 \x03(\x0b\x32%.acton.Database.LabelEncoder.Encoding\x1a\x32\n\x08\x45ncoding\x12\x13\n\x0b\x63lass_label\x18\x01 \x01(\t\x12\x11\n\tclass_int\x18\x02 \x01(\x05\"4\n\tLabelPool\x12\n\n\x02id\x18\x01 \x03(\x03\x12\x1b\n\x02\x64\x62\x18\x02 \x01(\x0b\x32\x0f.acton.Database\"\xea\x01\n\x0bPredictions\x12\x31\n\nprediction\x18\x01 \x03(\x0b\x32\x1d.acton.Predictions.Prediction\x12\x13\n\x0blabelled_id\x18\x02 \x03(\x03\x12\x14\n\x0cn_predictors\x18\x03 \x01(\x05\x12\x1f\n\x17n_prediction_dimensions\x18\x04 \x01(\x05\x12\x11\n\tpredictor\x18\x05 \x01(\t\x12\x1b\n\x02\x64\x62\x18\x06 \x01(\x0b\x32\x0f.acton.Database\x1a,\n\nPrediction\x12\n\n\x02id\x18\x01 \x01(\x03\x12\x12\n\nprediction\x18\x02 \x03(\x01\"p\n\x0fRecommendations\x12\x16\n\x0erecommended_id\x18\x01 \x03(\x03\x12\x13\n\x0blabelled_id\x18\x02 \x03(\x03\x12\x13\n\x0brecommender\x18\x03 \x01(\t\x12\x1b\n\x02\x64\x62\x18\x04 \x01(\x0b\x32\x0f.acton.Databaseb\x06proto3') 21 | ) 22 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 23 | 24 | 25 | _KEYVAL = _descriptor.Descriptor( 26 | name='KeyVal', 27 | full_name='acton.KeyVal', 28 | filename=None, 29 | file=DESCRIPTOR, 30 | containing_type=None, 31 | fields=[ 32 | _descriptor.FieldDescriptor( 33 | name='key', full_name='acton.KeyVal.key', index=0, 34 | number=1, type=9, cpp_type=9, label=1, 35 | has_default_value=False, default_value=_b("").decode('utf-8'), 36 | message_type=None, enum_type=None, containing_type=None, 37 | is_extension=False, extension_scope=None, 38 | options=None), 39 | _descriptor.FieldDescriptor( 40 | name='value', full_name='acton.KeyVal.value', index=1, 41 | number=2, type=9, cpp_type=9, label=1, 42 | has_default_value=False, default_value=_b("").decode('utf-8'), 43 | message_type=None, enum_type=None, containing_type=None, 44 | is_extension=False, extension_scope=None, 45 | options=None), 46 | ], 47 | extensions=[ 48 | ], 49 | nested_types=[], 50 | enum_types=[ 51 | ], 52 | options=None, 53 | is_extendable=False, 54 | syntax='proto3', 55 | extension_ranges=[], 56 | 
oneofs=[ 57 | ], 58 | serialized_start=22, 59 | serialized_end=58, 60 | ) 61 | 62 | 63 | _DATABASE_LABELENCODER_ENCODING = _descriptor.Descriptor( 64 | name='Encoding', 65 | full_name='acton.Database.LabelEncoder.Encoding', 66 | filename=None, 67 | file=DESCRIPTOR, 68 | containing_type=None, 69 | fields=[ 70 | _descriptor.FieldDescriptor( 71 | name='class_label', full_name='acton.Database.LabelEncoder.Encoding.class_label', index=0, 72 | number=1, type=9, cpp_type=9, label=1, 73 | has_default_value=False, default_value=_b("").decode('utf-8'), 74 | message_type=None, enum_type=None, containing_type=None, 75 | is_extension=False, extension_scope=None, 76 | options=None), 77 | _descriptor.FieldDescriptor( 78 | name='class_int', full_name='acton.Database.LabelEncoder.Encoding.class_int', index=1, 79 | number=2, type=5, cpp_type=1, label=1, 80 | has_default_value=False, default_value=0, 81 | message_type=None, enum_type=None, containing_type=None, 82 | is_extension=False, extension_scope=None, 83 | options=None), 84 | ], 85 | extensions=[ 86 | ], 87 | nested_types=[], 88 | enum_types=[ 89 | ], 90 | options=None, 91 | is_extendable=False, 92 | syntax='proto3', 93 | extension_ranges=[], 94 | oneofs=[ 95 | ], 96 | serialized_start=263, 97 | serialized_end=313, 98 | ) 99 | 100 | _DATABASE_LABELENCODER = _descriptor.Descriptor( 101 | name='LabelEncoder', 102 | full_name='acton.Database.LabelEncoder', 103 | filename=None, 104 | file=DESCRIPTOR, 105 | containing_type=None, 106 | fields=[ 107 | _descriptor.FieldDescriptor( 108 | name='encoding', full_name='acton.Database.LabelEncoder.encoding', index=0, 109 | number=1, type=11, cpp_type=10, label=3, 110 | has_default_value=False, default_value=[], 111 | message_type=None, enum_type=None, containing_type=None, 112 | is_extension=False, extension_scope=None, 113 | options=None), 114 | ], 115 | extensions=[ 116 | ], 117 | nested_types=[_DATABASE_LABELENCODER_ENCODING, ], 118 | enum_types=[ 119 | ], 120 | options=None, 121 | is_extendable=False, 122 | syntax='proto3', 123 | extension_ranges=[], 124 | oneofs=[ 125 | ], 126 | serialized_start=190, 127 | serialized_end=313, 128 | ) 129 | 130 | _DATABASE = _descriptor.Descriptor( 131 | name='Database', 132 | full_name='acton.Database', 133 | filename=None, 134 | file=DESCRIPTOR, 135 | containing_type=None, 136 | fields=[ 137 | _descriptor.FieldDescriptor( 138 | name='path', full_name='acton.Database.path', index=0, 139 | number=1, type=9, cpp_type=9, label=1, 140 | has_default_value=False, default_value=_b("").decode('utf-8'), 141 | message_type=None, enum_type=None, containing_type=None, 142 | is_extension=False, extension_scope=None, 143 | options=None), 144 | _descriptor.FieldDescriptor( 145 | name='class_name', full_name='acton.Database.class_name', index=1, 146 | number=2, type=9, cpp_type=9, label=1, 147 | has_default_value=False, default_value=_b("").decode('utf-8'), 148 | message_type=None, enum_type=None, containing_type=None, 149 | is_extension=False, extension_scope=None, 150 | options=None), 151 | _descriptor.FieldDescriptor( 152 | name='kwarg', full_name='acton.Database.kwarg', index=2, 153 | number=3, type=11, cpp_type=10, label=3, 154 | has_default_value=False, default_value=[], 155 | message_type=None, enum_type=None, containing_type=None, 156 | is_extension=False, extension_scope=None, 157 | options=None), 158 | _descriptor.FieldDescriptor( 159 | name='label_encoder', full_name='acton.Database.label_encoder', index=3, 160 | number=4, type=11, cpp_type=10, label=1, 161 | 
has_default_value=False, default_value=None, 162 | message_type=None, enum_type=None, containing_type=None, 163 | is_extension=False, extension_scope=None, 164 | options=None), 165 | ], 166 | extensions=[ 167 | ], 168 | nested_types=[_DATABASE_LABELENCODER, ], 169 | enum_types=[ 170 | ], 171 | options=None, 172 | is_extendable=False, 173 | syntax='proto3', 174 | extension_ranges=[], 175 | oneofs=[ 176 | ], 177 | serialized_start=61, 178 | serialized_end=313, 179 | ) 180 | 181 | 182 | _LABELPOOL = _descriptor.Descriptor( 183 | name='LabelPool', 184 | full_name='acton.LabelPool', 185 | filename=None, 186 | file=DESCRIPTOR, 187 | containing_type=None, 188 | fields=[ 189 | _descriptor.FieldDescriptor( 190 | name='id', full_name='acton.LabelPool.id', index=0, 191 | number=1, type=3, cpp_type=2, label=3, 192 | has_default_value=False, default_value=[], 193 | message_type=None, enum_type=None, containing_type=None, 194 | is_extension=False, extension_scope=None, 195 | options=None), 196 | _descriptor.FieldDescriptor( 197 | name='db', full_name='acton.LabelPool.db', index=1, 198 | number=2, type=11, cpp_type=10, label=1, 199 | has_default_value=False, default_value=None, 200 | message_type=None, enum_type=None, containing_type=None, 201 | is_extension=False, extension_scope=None, 202 | options=None), 203 | ], 204 | extensions=[ 205 | ], 206 | nested_types=[], 207 | enum_types=[ 208 | ], 209 | options=None, 210 | is_extendable=False, 211 | syntax='proto3', 212 | extension_ranges=[], 213 | oneofs=[ 214 | ], 215 | serialized_start=315, 216 | serialized_end=367, 217 | ) 218 | 219 | 220 | _PREDICTIONS_PREDICTION = _descriptor.Descriptor( 221 | name='Prediction', 222 | full_name='acton.Predictions.Prediction', 223 | filename=None, 224 | file=DESCRIPTOR, 225 | containing_type=None, 226 | fields=[ 227 | _descriptor.FieldDescriptor( 228 | name='id', full_name='acton.Predictions.Prediction.id', index=0, 229 | number=1, type=3, cpp_type=2, label=1, 230 | has_default_value=False, default_value=0, 231 | message_type=None, enum_type=None, containing_type=None, 232 | is_extension=False, extension_scope=None, 233 | options=None), 234 | _descriptor.FieldDescriptor( 235 | name='prediction', full_name='acton.Predictions.Prediction.prediction', index=1, 236 | number=2, type=1, cpp_type=5, label=3, 237 | has_default_value=False, default_value=[], 238 | message_type=None, enum_type=None, containing_type=None, 239 | is_extension=False, extension_scope=None, 240 | options=None), 241 | ], 242 | extensions=[ 243 | ], 244 | nested_types=[], 245 | enum_types=[ 246 | ], 247 | options=None, 248 | is_extendable=False, 249 | syntax='proto3', 250 | extension_ranges=[], 251 | oneofs=[ 252 | ], 253 | serialized_start=560, 254 | serialized_end=604, 255 | ) 256 | 257 | _PREDICTIONS = _descriptor.Descriptor( 258 | name='Predictions', 259 | full_name='acton.Predictions', 260 | filename=None, 261 | file=DESCRIPTOR, 262 | containing_type=None, 263 | fields=[ 264 | _descriptor.FieldDescriptor( 265 | name='prediction', full_name='acton.Predictions.prediction', index=0, 266 | number=1, type=11, cpp_type=10, label=3, 267 | has_default_value=False, default_value=[], 268 | message_type=None, enum_type=None, containing_type=None, 269 | is_extension=False, extension_scope=None, 270 | options=None), 271 | _descriptor.FieldDescriptor( 272 | name='labelled_id', full_name='acton.Predictions.labelled_id', index=1, 273 | number=2, type=3, cpp_type=2, label=3, 274 | has_default_value=False, default_value=[], 275 | message_type=None, enum_type=None, 
containing_type=None, 276 | is_extension=False, extension_scope=None, 277 | options=None), 278 | _descriptor.FieldDescriptor( 279 | name='n_predictors', full_name='acton.Predictions.n_predictors', index=2, 280 | number=3, type=5, cpp_type=1, label=1, 281 | has_default_value=False, default_value=0, 282 | message_type=None, enum_type=None, containing_type=None, 283 | is_extension=False, extension_scope=None, 284 | options=None), 285 | _descriptor.FieldDescriptor( 286 | name='n_prediction_dimensions', full_name='acton.Predictions.n_prediction_dimensions', index=3, 287 | number=4, type=5, cpp_type=1, label=1, 288 | has_default_value=False, default_value=0, 289 | message_type=None, enum_type=None, containing_type=None, 290 | is_extension=False, extension_scope=None, 291 | options=None), 292 | _descriptor.FieldDescriptor( 293 | name='predictor', full_name='acton.Predictions.predictor', index=4, 294 | number=5, type=9, cpp_type=9, label=1, 295 | has_default_value=False, default_value=_b("").decode('utf-8'), 296 | message_type=None, enum_type=None, containing_type=None, 297 | is_extension=False, extension_scope=None, 298 | options=None), 299 | _descriptor.FieldDescriptor( 300 | name='db', full_name='acton.Predictions.db', index=5, 301 | number=6, type=11, cpp_type=10, label=1, 302 | has_default_value=False, default_value=None, 303 | message_type=None, enum_type=None, containing_type=None, 304 | is_extension=False, extension_scope=None, 305 | options=None), 306 | ], 307 | extensions=[ 308 | ], 309 | nested_types=[_PREDICTIONS_PREDICTION, ], 310 | enum_types=[ 311 | ], 312 | options=None, 313 | is_extendable=False, 314 | syntax='proto3', 315 | extension_ranges=[], 316 | oneofs=[ 317 | ], 318 | serialized_start=370, 319 | serialized_end=604, 320 | ) 321 | 322 | 323 | _RECOMMENDATIONS = _descriptor.Descriptor( 324 | name='Recommendations', 325 | full_name='acton.Recommendations', 326 | filename=None, 327 | file=DESCRIPTOR, 328 | containing_type=None, 329 | fields=[ 330 | _descriptor.FieldDescriptor( 331 | name='recommended_id', full_name='acton.Recommendations.recommended_id', index=0, 332 | number=1, type=3, cpp_type=2, label=3, 333 | has_default_value=False, default_value=[], 334 | message_type=None, enum_type=None, containing_type=None, 335 | is_extension=False, extension_scope=None, 336 | options=None), 337 | _descriptor.FieldDescriptor( 338 | name='labelled_id', full_name='acton.Recommendations.labelled_id', index=1, 339 | number=2, type=3, cpp_type=2, label=3, 340 | has_default_value=False, default_value=[], 341 | message_type=None, enum_type=None, containing_type=None, 342 | is_extension=False, extension_scope=None, 343 | options=None), 344 | _descriptor.FieldDescriptor( 345 | name='recommender', full_name='acton.Recommendations.recommender', index=2, 346 | number=3, type=9, cpp_type=9, label=1, 347 | has_default_value=False, default_value=_b("").decode('utf-8'), 348 | message_type=None, enum_type=None, containing_type=None, 349 | is_extension=False, extension_scope=None, 350 | options=None), 351 | _descriptor.FieldDescriptor( 352 | name='db', full_name='acton.Recommendations.db', index=3, 353 | number=4, type=11, cpp_type=10, label=1, 354 | has_default_value=False, default_value=None, 355 | message_type=None, enum_type=None, containing_type=None, 356 | is_extension=False, extension_scope=None, 357 | options=None), 358 | ], 359 | extensions=[ 360 | ], 361 | nested_types=[], 362 | enum_types=[ 363 | ], 364 | options=None, 365 | is_extendable=False, 366 | syntax='proto3', 367 | 
extension_ranges=[], 368 | oneofs=[ 369 | ], 370 | serialized_start=606, 371 | serialized_end=718, 372 | ) 373 | 374 | _DATABASE_LABELENCODER_ENCODING.containing_type = _DATABASE_LABELENCODER 375 | _DATABASE_LABELENCODER.fields_by_name['encoding'].message_type = _DATABASE_LABELENCODER_ENCODING 376 | _DATABASE_LABELENCODER.containing_type = _DATABASE 377 | _DATABASE.fields_by_name['kwarg'].message_type = _KEYVAL 378 | _DATABASE.fields_by_name['label_encoder'].message_type = _DATABASE_LABELENCODER 379 | _LABELPOOL.fields_by_name['db'].message_type = _DATABASE 380 | _PREDICTIONS_PREDICTION.containing_type = _PREDICTIONS 381 | _PREDICTIONS.fields_by_name['prediction'].message_type = _PREDICTIONS_PREDICTION 382 | _PREDICTIONS.fields_by_name['db'].message_type = _DATABASE 383 | _RECOMMENDATIONS.fields_by_name['db'].message_type = _DATABASE 384 | DESCRIPTOR.message_types_by_name['KeyVal'] = _KEYVAL 385 | DESCRIPTOR.message_types_by_name['Database'] = _DATABASE 386 | DESCRIPTOR.message_types_by_name['LabelPool'] = _LABELPOOL 387 | DESCRIPTOR.message_types_by_name['Predictions'] = _PREDICTIONS 388 | DESCRIPTOR.message_types_by_name['Recommendations'] = _RECOMMENDATIONS 389 | 390 | KeyVal = _reflection.GeneratedProtocolMessageType('KeyVal', (_message.Message,), dict( 391 | DESCRIPTOR=_KEYVAL, 392 | __module__='acton_pb2' 393 | # @@protoc_insertion_point(class_scope:acton.KeyVal) 394 | )) 395 | _sym_db.RegisterMessage(KeyVal) 396 | 397 | Database = _reflection.GeneratedProtocolMessageType('Database', (_message.Message,), dict( 398 | 399 | LabelEncoder=_reflection.GeneratedProtocolMessageType('LabelEncoder', (_message.Message,), dict( 400 | 401 | Encoding=_reflection.GeneratedProtocolMessageType('Encoding', (_message.Message,), dict( 402 | DESCRIPTOR=_DATABASE_LABELENCODER_ENCODING, 403 | __module__='acton_pb2' 404 | # @@protoc_insertion_point(class_scope:acton.Database.LabelEncoder.Encoding) 405 | )), 406 | DESCRIPTOR=_DATABASE_LABELENCODER, 407 | __module__='acton_pb2' 408 | # @@protoc_insertion_point(class_scope:acton.Database.LabelEncoder) 409 | )), 410 | DESCRIPTOR=_DATABASE, 411 | __module__='acton_pb2' 412 | # @@protoc_insertion_point(class_scope:acton.Database) 413 | )) 414 | _sym_db.RegisterMessage(Database) 415 | _sym_db.RegisterMessage(Database.LabelEncoder) 416 | _sym_db.RegisterMessage(Database.LabelEncoder.Encoding) 417 | 418 | LabelPool = _reflection.GeneratedProtocolMessageType('LabelPool', (_message.Message,), dict( 419 | DESCRIPTOR=_LABELPOOL, 420 | __module__='acton_pb2' 421 | # @@protoc_insertion_point(class_scope:acton.LabelPool) 422 | )) 423 | _sym_db.RegisterMessage(LabelPool) 424 | 425 | Predictions = _reflection.GeneratedProtocolMessageType('Predictions', (_message.Message,), dict( 426 | 427 | Prediction=_reflection.GeneratedProtocolMessageType('Prediction', (_message.Message,), dict( 428 | DESCRIPTOR=_PREDICTIONS_PREDICTION, 429 | __module__='acton_pb2' 430 | # @@protoc_insertion_point(class_scope:acton.Predictions.Prediction) 431 | )), 432 | DESCRIPTOR=_PREDICTIONS, 433 | __module__='acton_pb2' 434 | # @@protoc_insertion_point(class_scope:acton.Predictions) 435 | )) 436 | _sym_db.RegisterMessage(Predictions) 437 | _sym_db.RegisterMessage(Predictions.Prediction) 438 | 439 | Recommendations = _reflection.GeneratedProtocolMessageType('Recommendations', (_message.Message,), dict( 440 | DESCRIPTOR=_RECOMMENDATIONS, 441 | __module__='acton_pb2' 442 | # @@protoc_insertion_point(class_scope:acton.Recommendations) 443 | )) 444 | _sym_db.RegisterMessage(Recommendations) 445 | 446 | 
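# --- Editor's note (illustrative sketch, not part of the generated file) ---
# The classes registered above are ordinary protobuf messages, so they can
# be built and round-tripped with the standard protobuf Python API. A
# hypothetical example (the field values are made up for illustration):
#
#     >>> import acton.proto.acton_pb2 as acton_pb2
#     >>> rec = acton_pb2.Recommendations()
#     >>> rec.recommended_id.extend([3, 1, 4])     # repeated int64
#     >>> rec.labelled_id.append(59)               # repeated int64
#     >>> rec.recommender = 'RandomRecommender'    # string
#     >>> rec.db.path = 'tests/data/classification.h5'  # nested Database
#     >>> data = rec.SerializeToString()
#     >>> acton_pb2.Recommendations.FromString(data) == rec
#     True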
447 | # @@protoc_insertion_point(module_scope)
448 | 
--------------------------------------------------------------------------------
/acton/acton.py:
--------------------------------------------------------------------------------
"""Main processing script for Acton."""

import logging
import time
from typing import Iterable, List, Tuple, TypeVar

import acton.database
import acton.labellers
import acton.predictors
import acton.proto.io
import acton.proto.wrappers
import acton.recommenders
import numpy
import pandas
import sklearn.linear_model
import sklearn.metrics
import sklearn.model_selection
import sklearn.preprocessing
from sklearn.metrics import roc_auc_score

T = TypeVar('T')


def draw(n: int, lst: List[T], replace: bool = True) -> List[T]:
    """Draws n random elements from a list.

    Parameters
    ----------
    n
        Number of elements to draw.
    lst
        List of elements to draw from.
    replace
        Draw with replacement.

    Returns
    -------
    List[T]
        n random elements.
    """
    # numpy.random.choice draws with replacement by default, so this wrapper
    # defaults to replace=True as well, even though most call sites in this
    # codebase pass replace=False.
    return list(numpy.random.choice(lst, size=n, replace=replace))


def validate_predictor(predictor: str):
    """Raises an exception if the predictor is not valid.

    Parameters
    ----------
    predictor
        Name of predictor.

    Raises
    ------
    ValueError
    """
    if predictor not in acton.predictors.PREDICTORS:
        raise ValueError('Unknown predictor: {}. Predictors are one of '
                         '{}.'.format(predictor,
                                      acton.predictors.PREDICTORS.keys()))


def validate_recommender(recommender: str):
    """Raises an exception if the recommender is not valid.

    Parameters
    ----------
    recommender
        Name of recommender.

    Raises
    ------
    ValueError
    """
    if recommender not in acton.recommenders.RECOMMENDERS:
        raise ValueError('Unknown recommender: {}. Recommenders are one of '
                         '{}.'.format(recommender,
                                      acton.recommenders.RECOMMENDERS.keys()))
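# --- Editor's note (illustrative sketch, not part of the original module) ---
# The two validators above guard user input against the registry dicts in
# acton.predictors.PREDICTORS and acton.recommenders.RECOMMENDERS. A
# hypothetical session (the exact registry keys depend on those modules):
#
#     >>> validate_predictor('LogisticRegression')   # returns None if known
#     >>> validate_predictor('NotAPredictor')
#     Traceback (most recent call last):
#         ...
#     ValueError: Unknown predictor: NotAPredictor. Predictors are one of ...
#     >>> draw(3, list(range(10)), replace=False)    # e.g. [7, 2, 9]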
def simulate_active_learning(
        ids: Iterable[int],
        db: acton.database.Database,
        db_kwargs: dict,
        output_path: str,
        n_initial_labels: int = 10,
        n_epochs: int = 10,
        test_size: float = 0.2,
        recommender: str = 'RandomRecommender',
        predictor: str = 'LogisticRegression',
        labeller: str = 'DatabaseLabeller',
        n_recommendations: int = 1,
        diversity: float = 0.5,
        repeated_labelling: bool = True,
        inc_sub: bool = False,
        subn_entities: int = 0,
        subn_relations: int = 0):
    """Simulates an active learning task.

    Parameters
    ----------
    ids
        IDs of instances in the unlabelled pool.
    db
        Database with features and labels.
    db_kwargs
        Keyword arguments for the database constructor.
    output_path
        Path to output intermediate predictions to. Will be overwritten.
    n_initial_labels
        Number of initial labels to draw.
    n_epochs
        Number of epochs.
    test_size
        Fraction of instances to hold out for testing.
    recommender
        Name of recommender to make recommendations.
    predictor
        Name of predictor to make predictions.
    labeller
        Name of labeller used to label instances.
    n_recommendations
        Number of recommendations to make at once.
    diversity
        Diversity parameter passed through to the recommender.
    repeated_labelling
        Whether an instance may be labelled more than once.
    inc_sub
        Whether to increase the subsampling size as more labels are acquired.
    subn_entities
        Number of entities to subsample.
    subn_relations
        Number of relations to subsample.
    """
    validate_recommender(recommender)
    validate_predictor(predictor)

    # Seed RNG.
    numpy.random.seed(0)

    # Bytestring describing this run.
    metadata = '{} | {}'.format(recommender, predictor).encode('ascii')

    # Split into training and testing sets.
    logging.debug('Found {} instances.'.format(len(ids)))
    logging.debug('Splitting into training/testing sets.')
    train_ids, test_ids = sklearn.model_selection.train_test_split(
        ids, test_size=test_size)
    test_ids.sort()

    # Set up predictor, labeller, and recommender.
    # TODO(MatthewJA): Handle multiple labellers better than just averaging.
    predictor_name = predictor  # For saving.
    predictor = acton.predictors.PREDICTORS[predictor](db=db, n_jobs=-1)
    labeller_name = labeller
    labeller = acton.labellers.LABELLERS[labeller](db)
    recommender = acton.recommenders.RECOMMENDERS[recommender](db=db)

    # Draw some initial labels.
    logging.debug('Drawing initial labels.')
    recommendations = draw(n_initial_labels, train_ids, replace=False)

    if labeller_name == 'GraphDatabaseLabeller':

        assert subn_entities > 0
        assert subn_relations > 0

        tensor_ids = ids.reshape(
            (db.n_relations, db.n_entities, db.n_entities))

        rec_x, rec_y, rec_z = \
            numpy.unravel_index(recommendations, tensor_ids.shape)
        recommendations = list(zip(rec_x, rec_y, rec_z))

        train_x, train_y, train_z = \
            numpy.unravel_index(ids[train_ids], tensor_ids.shape)
        train_ids = list(zip(train_x, train_y, train_z))

        test_x, test_y, test_z = \
            numpy.unravel_index(ids[test_ids], tensor_ids.shape)
        test_ids = list(zip(test_x, test_y, test_z))

    logging.debug('Recommending: {}'.format(recommendations))
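    # --- Editor's note (illustrative sketch, not part of the original
    # module): the GraphDatabaseLabeller branch above maps flat instance IDs
    # into (relation, head entity, tail entity) coordinates of a tensor of
    # shape (n_relations, n_entities, n_entities) via numpy.unravel_index.
    # With hypothetical sizes n_relations=2 and n_entities=3:
    #
    #     >>> numpy.unravel_index([0, 4, 17], (2, 3, 3))
    #     (array([0, 0, 1]), array([0, 1, 2]), array([0, 1, 2]))
    #
    # i.e. flat ID 17 denotes relation 1 between entities 2 and 2.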
    # This will store all IDs of things we have already labelled.
    labelled_ids = []
    # This will store all the corresponding labels.
    labels = numpy.zeros((0, 1))

    # Simulation loop.
    logging.debug('Writing protobufs to {}.'.format(output_path))
    writer = acton.proto.io.write_protos(output_path, metadata=metadata)
    next(writer)  # Prime the coroutine.

    train_error_list = []
    test_error_list = []

    gain_ts = []
    run_time = []

    for epoch in range(n_epochs):
        begin_epoch = time.time()
        logging.info('Epoch {}/{}'.format(epoch + 1, n_epochs))
        # Label the recommendations.
        logging.debug('Labelling recommendations.')
        new_labels = numpy.array([
            labeller.query(id_) for id_ in recommendations]).reshape((-1, 1))

        labelled_ids.extend(recommendations)
        logging.debug('Sorting label IDs.')

        if labeller_name != 'GraphDatabaseLabeller':
            labelled_ids.sort()

        labels = numpy.concatenate([labels, new_labels], axis=0)

        # We would write the labels to the database here, but they are
        # already stored there - we read them from the database anyway.

        # Pass the labels to the predictor.
        logging.debug('Fitting predictor.')
        then = time.time()
        if labeller_name == 'GraphDatabaseLabeller':
            predictor.fit(labelled_ids,
                          inc_sub=inc_sub,
                          subn_entities=subn_entities,
                          subn_relations=subn_relations)
        else:
            predictor.fit(labelled_ids)
        logging.debug('(Took {:.02} s.)'.format(time.time() - then))

        # Evaluate the predictor.
        logging.debug(
            'Making predictions (reference, n = {}).'.format(len(test_ids)))
        then = time.time()
        test_pred, _test_var = predictor.reference_predict(test_ids)

        logging.debug('(Took {:.02} s.)'.format(time.time() - then))

        # Construct a protobuf for outputting predictions.
        if labeller_name != 'GraphDatabaseLabeller':
            proto = acton.proto.wrappers.Predictions.make(
                test_ids,
                labelled_ids,
                test_pred.transpose([1, 0, 2]),  # T x N x C -> N x T x C
                predictor=predictor_name,
                db=db)
            # Then write them to a file.
            logging.debug('Writing predictions.')
            writer.send(proto.proto)

        # Pass the predictions to the recommender.

        unlabelled_ids = list(set(train_ids) - set(labelled_ids))
        if not unlabelled_ids:
            logging.info('Labelled all instances.')
            break

        unlabelled_ids.sort()

        logging.debug(
            'Making predictions (unlabelled, n = {}).'.format(
                len(unlabelled_ids)))
        then = time.time()
        predictions, _variances = predictor.predict(unlabelled_ids)

        logging.debug('(Took {:.02} s.)'.format(time.time() - then))

        logging.debug('Making recommendations.')

        if labeller_name == 'GraphDatabaseLabeller':
            true_labels = db.read_labels([])
            # logging.debug(unlabelled_ids)

            recommendations = recommender.recommend(
                unlabelled_ids, predictions, n=n_recommendations,
                diversity=diversity, repeated_labelling=repeated_labelling)
            logging.debug('Recommending: {}'.format(recommendations))

            # Compute ROC AUC scores. (Higher is better, despite the
            # *_error names.)
            train_error = \
                roc_auc_score(true_labels[train_x, train_y, train_z].flatten(),
                              predictions[train_x, train_y, train_z].flatten())
            test_error = \
                roc_auc_score(true_labels[test_x, test_y, test_z].flatten(),
                              predictions[test_x, test_y, test_z].flatten())

            train_error_list.append(train_error)
            test_error_list.append(test_error)

            # Compute the cumulative gain: 1 if the top-ranked prediction is
            # a true positive, else 0.
            idx = numpy.unravel_index(
                predictions.argmax(),
                predictions.shape
            )
            if true_labels[idx] == 1:
                gain_ts.append(1)
            else:
                gain_ts.append(0)
            # regret_ts = compute_regret(true_labels, seq)
            # gain_ts = 1 - numpy.array(regret_ts)

            # return train_error_list, test_error_list, gain_ts
        else:
            recommendations = recommender.recommend(
                unlabelled_ids, predictions, n=n_recommendations)
            logging.debug('Recommending: {}'.format(recommendations))
        end_epoch = time.time()

        run_time.append(end_epoch - begin_epoch)

    if labeller_name == 'GraphDatabaseLabeller':
        return train_error_list, test_error_list, gain_ts, run_time
    else:
        return 0
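# --- Editor's note (illustrative sketch, not part of the original module) ---
# The evaluation inside the loop above uses sklearn.metrics.roc_auc_score,
# which takes true binary labels and predicted scores and returns the area
# under the ROC curve (1.0 = perfect ranking, 0.5 = chance). For example,
# with hypothetical values:
#
#     >>> roc_auc_score([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
#     0.75
#
# Three of the four positive-negative pairs are ranked correctly, so the
# AUC is 3/4 = 0.75.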
def try_pandas(data_path: str) -> bool:
    """Guesses whether a file is a pandas HDF5 file.

    Parameters
    ----------
    data_path
        Path to file.

    Returns
    -------
    bool
        True if the file can be read by pandas.
    """
    try:
        pandas.read_hdf(data_path)
    except ValueError:
        return False

    return True


def get_DB(
        data_path: str,
        pandas_key: str = None) -> Tuple[type, dict]:
    """Gets a Database class that will handle the given data table.

    Parameters
    ----------
    data_path
        Path to file.
    pandas_key
        Key for pandas HDF5. Specify iff using pandas.

    Returns
    -------
    type
        Database class that will handle the given data table.
    dict
        Keyword arguments for the Database constructor.
    """
    db_kwargs = {}

    is_fits = data_path.endswith('.fits')
    is_ascii = not data_path.endswith('.h5')
    if is_fits:
        logging.debug('Reading {} as FITS.'.format(data_path))
        DB = acton.database.FITSReader
    elif is_ascii:
        logging.debug('Reading {} as ASCII.'.format(data_path))
        DB = acton.database.ASCIIReader
    else:
        # Assume HDF5.
        is_pandas = bool(pandas_key)
        if is_pandas:
            logging.debug('Reading {} as pandas.'.format(data_path))
            DB = acton.database.PandasReader
            db_kwargs['key'] = pandas_key
        else:
            logging.debug('Reading {} as HDF5.'.format(data_path))
            DB = acton.database.HDF5Reader

    return DB, db_kwargs


def main(data_path: str, feature_cols: List[str], label_col: str,
         output_path: str, n_epochs: int = 10, initial_count: int = 10,
         recommender: str = 'RandomRecommender',
         predictor: str = 'LogisticRegression', pandas_key: str = '',
         n_recommendations: int = 1):
    """Simulates an active learning experiment.

    Parameters
    ----------
    data_path
        Path to data file.
    feature_cols
        List of column names of the features. If empty, all non-label and
        non-ID columns will be used.
    label_col
        Column name of the labels.
    output_path
        Path to output file. Will be overwritten.
    n_epochs
        Number of epochs to run.
    initial_count
        Number of random instances to label initially.
    recommender
        Name of recommender to make recommendations.
    predictor
        Name of predictor to make predictions.
    pandas_key
        Key for pandas HDF5. Specify iff using pandas.
    n_recommendations
        Number of recommendations to make at once.
    """
    DB, db_kwargs = get_DB(data_path, pandas_key=pandas_key)

    db_kwargs['feature_cols'] = feature_cols
    db_kwargs['label_col'] = label_col

    with DB(data_path, **db_kwargs) as reader:
        return simulate_active_learning(
            reader.get_known_instance_ids(),
            reader,
            db_kwargs,
            output_path,
            n_epochs=n_epochs,
            n_initial_labels=initial_count,
            recommender=recommender,
            predictor=predictor,
            n_recommendations=n_recommendations)
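# --- Editor's note (illustrative sketch, not part of the original module) ---
# get_DB dispatches on the file extension: .fits -> FITSReader, .h5 ->
# HDF5Reader (or PandasReader when pandas_key is given), anything else ->
# ASCIIReader. A hypothetical end-to-end run over the FITS fixture shipped
# with this repository might look like:
#
#     >>> main('tests/data/classification.fits',
#     ...      feature_cols=[], label_col='label',
#     ...      output_path='predictions.pb', n_epochs=2)
#
# 'label' and the empty feature_cols are assumptions about the fixture's
# schema; only the dispatch behaviour is taken from the code above.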
def predict(
        labels: acton.proto.wrappers.LabelPool,
        predictor: str) -> acton.proto.wrappers.Predictions:
    """Trains a predictor and predicts labels.

    Parameters
    ----------
    labels
        IDs of labelled instances.
    predictor
        Name of predictor to make predictions.

    Returns
    -------
    acton.proto.wrappers.Predictions
    """
    validate_predictor(predictor)

    with labels.DB() as db:
        ids = db.get_known_instance_ids()
        train_ids = labels.ids

        predictor_name = predictor
        predictor = acton.predictors.PREDICTORS[predictor](db=db, n_jobs=-1)

        logging.debug('Training predictor with IDs: {}'.format(train_ids))
        predictor.fit(train_ids)

        predictions, _variances = predictor.reference_predict(ids)

        # Construct a protobuf for outputting predictions.
        proto = acton.proto.wrappers.Predictions.make(
            ids,
            train_ids,
            predictions.transpose([1, 0, 2]),  # T x N x C -> N x T x C
            predictor=predictor_name,
            db=db)
        return proto


def recommend(
        predictions: acton.proto.wrappers.Predictions,
        recommender: str = 'RandomRecommender',
        n_recommendations: int = 1) -> acton.proto.wrappers.Recommendations:
    """Recommends instances to label based on predictions.

    Parameters
    ----------
    predictions
        Predictions made by a predictor.
    recommender
        Name of recommender to make recommendations.
    n_recommendations
        Number of recommendations to make at once. Default 1.

    Returns
    -------
    acton.proto.wrappers.Recommendations
    """
    validate_recommender(recommender)

    # Make a list of IDs that do not have labels and the indices of the
    # corresponding predictions.
    ids = []
    indices = []
    has_labels = set(predictions.labelled_ids)
    for pred_index, id_ in enumerate(predictions.predicted_ids):
        if id_ not in has_labels:
            ids.append(id_)
            indices.append(pred_index)
    # Array of predictions for unlabelled instances.
    predictions_array = predictions.predictions[:, indices]

    with predictions.DB() as db:
        recommender_name = recommender
        recommender = acton.recommenders.RECOMMENDERS[recommender](db=db)
        recommendations = recommender.recommend(
            ids, predictions_array, n=n_recommendations)

        logging.debug('Recommending: {}'.format(list(recommendations)))

        # Construct a protobuf for outputting recommendations.
        proto = acton.proto.wrappers.Recommendations.make(
            [int(r) for r in recommendations],
            predictions.labelled_ids,
            recommender=recommender_name,
            db=db)
        return proto
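# --- Editor's note (illustrative sketch, not part of the original module) ---
# predict(), recommend(), and label() (below) pass protobuf wrappers between
# one another, so a full active-learning round trip can be composed as a
# loop. A hypothetical sketch, assuming `pool` is an initial
# acton.proto.wrappers.LabelPool:
#
#     >>> for _ in range(10):
#     ...     preds = predict(pool, 'LogisticRegression')
#     ...     recs = recommend(preds, recommender='RandomRecommender')
#     ...     pool = label(recs)
#
# Each iteration trains on the labelled pool, recommends new instances, and
# folds them back into the pool.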
def label(recommendations: acton.proto.wrappers.Recommendations
          ) -> acton.proto.wrappers.LabelPool:
    """Simulates a labelling task.

    Parameters
    ----------
    recommendations
        Recommendations of instances to label, along with the IDs of
        instances that are already labelled.

    Returns
    -------
    acton.proto.wrappers.LabelPool
    """
    # We'd store the labels here, except that we just read them from the DB.
    # Instead, we'll record that we've labelled them.
    # # labeller = acton.labellers.DatabaseLabeller(db)
    # # labels = [labeller.query(id_) for id_ in ids]

    # TODO(MatthewJA): Consider optimising this (doesn't really need a sort).
    ids_to_label = recommendations.recommendations
    labelled_ids = recommendations.labelled_ids
    logging.debug('Recommended IDs: {}'.format(ids_to_label))
    logging.debug('Already labelled IDs: {}'.format(labelled_ids))
    ids = sorted(set(ids_to_label) | set(labelled_ids))
    logging.debug('Now labelled IDs: {}'.format(ids))

    # Return a protobuf.
    with recommendations.DB() as db:
        proto = acton.proto.wrappers.LabelPool.make(ids=ids, db=db)
        return proto
--------------------------------------------------------------------------------