├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── requirements.txt
├── run_experiment.py
├── sampling_methods
│   ├── __init__.py
│   ├── bandit_discrete.py
│   ├── constants.py
│   ├── graph_density.py
│   ├── hierarchical_clustering_AL.py
│   ├── informative_diverse.py
│   ├── kcenter_greedy.py
│   ├── margin_AL.py
│   ├── mixture_of_samplers.py
│   ├── represent_cluster_centers.py
│   ├── sampling_def.py
│   ├── simulate_batch.py
│   ├── uniform_sampling.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── tree.py
│   │   └── tree_test.py
│   └── wrapper_sampler_def.py
└── utils
    ├── __init__.py
    ├── allconv.py
    ├── chart_data.py
    ├── create_data.py
    ├── kernel_block_solver.py
    ├── small_cnn.py
    └── utils.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 | 
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 | 
6 | ## Contributor License Agreement
7 | 
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 | 
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 | 
18 | ## Code reviews
19 | 
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | 
2 |                                  Apache License
3 |                            Version 2.0, January 2004
4 |                         http://www.apache.org/licenses/
5 | 
6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 | 
8 |    1. Definitions.
9 | 
10 |       "License" shall mean the terms and conditions for use, reproduction,
11 |       and distribution as defined by Sections 1 through 9 of this document.
12 | 
13 |       "Licensor" shall mean the copyright owner or entity authorized by
14 |       the copyright owner that is granting the License.
15 | 
16 |       "Legal Entity" shall mean the union of the acting entity and all
17 |       other entities that control, are controlled by, or are under common
18 |       control with that entity. For the purposes of this definition,
19 |       "control" means (i) the power, direct or indirect, to cause the
20 |       direction or management of such entity, whether by contract or
21 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 |       outstanding shares, or (iii) beneficial ownership of such entity.
23 | 
24 |       "You" (or "Your") shall mean an individual or Legal Entity
25 |       exercising permissions granted by this License.
26 | 
27 |       "Source" form shall mean the preferred form for making modifications,
28 |       including but not limited to software source code, documentation
29 |       source, and configuration files.
30 | 
31 |       "Object" form shall mean any form resulting from mechanical
32 |       transformation or translation of a Source form, including but
33 |       not limited to compiled object code, generated documentation,
34 |       and conversions to other media types.
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Active Learning Playground 2 | 3 | ## Introduction 4 | 5 | This is a python module for experimenting with different active learning 6 | algorithms. There are a few key components to running active learning 7 | experiments: 8 | 9 | * Main experiment script is 10 | [`run_experiment.py`](run_experiment.py) 11 | with many flags for different run options. 
12 | 
13 | * Supported datasets can be downloaded to a specified directory by running
14 |   [`utils/create_data.py`](utils/create_data.py).
15 | 
16 | * Supported active learning methods are in
17 |   [`sampling_methods`](sampling_methods/).
18 | 
19 | Below I will go into each component in more detail.
20 | 
21 | DISCLAIMER: This is not an official Google product.
22 | 
23 | ## Setup
24 | The dependencies are in [`requirements.txt`](requirements.txt). Please make sure these packages are
25 | installed before running experiments. If GPU-capable `tensorflow` is desired, please follow
26 | the instructions [here](https://www.tensorflow.org/install/).
27 | 
28 | It is highly suggested that you install all dependencies into a separate `virtualenv` for
29 | easy package management.
30 | 
31 | ## Getting benchmark datasets
32 | 
33 | By default the datasets are saved to `/tmp/data`. You can specify another directory via the
34 | `--save_dir` flag.
35 | 
36 | Downloading all the datasets will be very time consuming, so please be patient.
37 | You can specify a subset of the data to download by passing in a comma-separated
38 | string of datasets via the `--datasets` flag.
39 | 
40 | ## Running experiments
41 | 
42 | There are a few key flags for
43 | [`run_experiment.py`](run_experiment.py):
44 | 
45 | * `dataset`: name of the dataset; must match the save name used in
46 |   `create_data.py`. Must also exist in the data_dir.
47 | 
48 | * `sampling_method`: active learning method to use. Must be specified in
49 |   [`sampling_methods/constants.py`](sampling_methods/constants.py).
50 | 
51 | * `warmstart_size`: initial batch of uniformly sampled examples to use as seed
52 |   data. A float indicates a percentage of the total training data and an
53 |   integer indicates the raw size.
54 | 
55 | * `batch_size`: number of datapoints to request in each batch. A float indicates
56 |   a percentage of the total training data and an integer indicates the raw size.
57 | 
58 | * `score_method`: model used to evaluate the performance of the sampling
59 |   method. Must be in the `get_model` method of
60 |   [`utils/utils.py`](utils/utils.py).
61 | 
62 | * `data_dir`: directory with saved datasets.
63 | 
64 | * `save_dir`: directory to save results.
65 | 
66 | This is just a subset of all the flags. There are also options for
67 | preprocessing, introducing labeling noise, dataset subsampling, and using a
68 | different model to select than to score/evaluate.
69 | 
70 | ## Available active learning methods
71 | 
72 | All named active learning methods are in
73 | [`sampling_methods/constants.py`](sampling_methods/constants.py).
74 | 
75 | You can also specify a mixture of active learning methods by following the
76 | pattern of `[sampling_method]-[mixture_weight]` separated by dashes; e.g.
77 | `mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34`.
78 | 
79 | Some supported sampling methods include:
80 | 
81 | * Uniform: samples are selected via uniform sampling.
82 | 
83 | * Margin: uncertainty-based sampling method.
84 | 
85 | * Informative and diverse: margin- and cluster-based sampling method.
86 | 
87 | * k-center greedy: representative strategy that greedily forms a batch of
88 |   points to minimize the maximum distance from a labeled point.
89 | 
90 | * Graph density: representative strategy that selects points in dense regions
91 |   of the pool.
92 | 
93 | * Exp3 bandit: meta-active learning method that tries to learn the optimal
94 |   sampling method using a popular multi-armed bandit algorithm.
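The dash-separated mixture pattern above is parsed by `get_AL_sampler` in
[`sampling_methods/constants.py`](sampling_methods/constants.py). The sketch below
(not code from this repository) shows how a plain or mixture sampler name resolves
to a sampler class; the mixture weights must sum to 1, and `run_experiment.py` then
instantiates the class as `cls(X_train, y_train, seed)` before calling `select_batch`.

```python
# Minimal sketch of resolving sampler names; assumes the package root is on PYTHONPATH.
from sampling_methods.constants import get_base_AL_mapping
from sampling_methods.constants import get_all_possible_arms
from sampling_methods.constants import get_AL_sampler

get_base_AL_mapping()    # registers margin, uniform, kcenter, etc. in AL_MAPPING
get_all_possible_arms()  # registers the mixture_of_samplers wrapper

margin_cls = get_AL_sampler("margin")
mixture_cls = get_AL_sampler(
    "mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34")
# Either class is then constructed as cls(X_train, y_train, seed) and queried
# via select_batch(...), as done in run_experiment.py.
```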
95 | 
96 | ### Adding new active learning methods
97 | 
98 | Implement either a base sampler that inherits from
99 | [`SamplingMethod`](sampling_methods/sampling_def.py)
100 | or a meta-sampler that calls base samplers and inherits from
101 | [`WrapperSamplingMethod`](sampling_methods/wrapper_sampler_def.py).
102 | 
103 | The only method that must be implemented by any sampler is `select_batch_`,
104 | which can have arbitrary named arguments. The only restriction is that the name
105 | for the same input must be consistent across all the samplers (i.e. the indices
106 | for already selected examples all have the same name across samplers). Adding a
107 | new named argument that hasn't been used in other sampling methods will require
108 | feeding it into the `select_batch` call in
109 | [`run_experiment.py`](run_experiment.py).
110 | 
111 | After implementing your sampler, be sure to add it to
112 | [`constants.py`](sampling_methods/constants.py)
113 | so that it can be called from
114 | [`run_experiment.py`](run_experiment.py).
115 | 
116 | ## Available models
117 | 
118 | All available models are in the `get_model` method of
119 | [`utils/utils.py`](utils/utils.py).
120 | 
121 | Supported methods:
122 | 
123 | * Linear SVM: scikit-learn method with a grid search wrapper for the
124 |   regularization parameter.
125 | 
126 | * Kernel SVM: scikit-learn method with a grid search wrapper for the
127 |   regularization parameter.
128 | 
129 | * Logistic Regression: scikit-learn method with a grid search wrapper for the
130 |   regularization parameter.
131 | 
132 | * Small CNN: 4-layer CNN optimized using rmsprop, implemented in Keras with a
133 |   tensorflow backend.
134 | 
135 | * Kernel Least Squares Classification: block gradient descent solver that can
136 |   use multiple cores, so it is often faster than the scikit-learn Kernel SVM.
137 | 
138 | ### Adding new models
139 | 
140 | New models must follow the scikit-learn API and implement the following methods:
141 | 
142 | * `fit(X, y[, sample_weight])`: fit the model to the input features and
143 |   target.
144 | 
145 | * `predict(X)`: predict the value of the input features.
146 | 
147 | * `score(X, y)`: returns the target metric given test features and test targets.
148 | 
149 | * `decision_function(X)` (optional): return class probabilities, distance to
150 |   decision boundaries, or another metric that can be used by the margin sampler
151 |   as a measure of uncertainty.
152 | 
153 | See
154 | [`small_cnn.py`](utils/small_cnn.py)
155 | for an example.
156 | 
157 | After implementing your new model, be sure to add it to the `get_model` method of
158 | [`utils/utils.py`](utils/utils.py).
159 | 
160 | Currently models must be added on a one-off basis, and not all scikit-learn
161 | classifiers are supported due to the need for user input on whether and how to
162 | tune the hyperparameters of the model. However, it is very easy to add a
163 | scikit-learn model with a hyperparameter search wrapped around it as a supported
164 | model.
165 | 
166 | ## Collecting results and charting
167 | 
168 | The
169 | [`utils/chart_data.py`](utils/chart_data.py)
170 | script handles processing of data and charting for a specified dataset and
171 | source directory.
172 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | numpy>=1.13 3 | scipy>=0.19 4 | pandas>=0.20 5 | scikit-learn>=0.19 6 | matplotlib>=2.0.2 7 | tensorflow>=1.3 8 | keras>=2.0.8 9 | -------------------------------------------------------------------------------- /run_experiment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Run active learner on classification tasks. 16 | 17 | Supported datasets include mnist, letter, cifar10, newsgroup20, rcv1, 18 | wikipedia attack, and select classification datasets from mldata. 19 | See utils/create_data.py for all available datasets. 20 | 21 | For binary classification, mnist_4_9 indicates mnist filtered down to just 4 and 22 | 9. 23 | By default uses logistic regression but can also train using kernel SVM. 24 | 2 fold cv is used to tune regularization parameter over a exponential grid. 25 | 26 | """ 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | 32 | import os 33 | import pickle 34 | import sys 35 | from time import gmtime 36 | from time import strftime 37 | 38 | import numpy as np 39 | from sklearn.preprocessing import normalize 40 | from sklearn.preprocessing import StandardScaler 41 | 42 | from absl import app 43 | from absl import flags 44 | from tensorflow import gfile 45 | 46 | from sampling_methods.constants import AL_MAPPING 47 | from sampling_methods.constants import get_AL_sampler 48 | from sampling_methods.constants import get_wrapper_AL_mapping 49 | from utils import utils 50 | 51 | flags.DEFINE_string("dataset", "letter", "Dataset name") 52 | flags.DEFINE_string("sampling_method", "margin", 53 | ("Name of sampling method to use, can be any defined in " 54 | "AL_MAPPING in sampling_methods.constants")) 55 | flags.DEFINE_float( 56 | "warmstart_size", 0.02, 57 | ("Can be float or integer. Float indicates percentage of training data " 58 | "to use in the initial warmstart model") 59 | ) 60 | flags.DEFINE_float( 61 | "batch_size", 0.02, 62 | ("Can be float or integer. 
Float indicates batch size as a percentage " 63 | "of training data size.") 64 | ) 65 | flags.DEFINE_integer("trials", 1, 66 | "Number of curves to create using different seeds") 67 | flags.DEFINE_integer("seed", 1, "Seed to use for rng and random state") 68 | # TODO(lisha): add feature noise to simulate data outliers 69 | flags.DEFINE_string("confusions", "0.", "Percentage of labels to randomize") 70 | flags.DEFINE_string("active_sampling_percentage", "1.0", 71 | "Mixture weights on active sampling.") 72 | flags.DEFINE_string( 73 | "score_method", "logistic", 74 | "Method to use to calculate accuracy.") 75 | flags.DEFINE_string( 76 | "select_method", "None", 77 | "Method to use for selecting points.") 78 | flags.DEFINE_string("normalize_data", "False", "Whether to normalize the data.") 79 | flags.DEFINE_string("standardize_data", "True", 80 | "Whether to standardize the data.") 81 | flags.DEFINE_string("save_dir", "/tmp/toy_experiments", 82 | "Where to save outputs") 83 | flags.DEFINE_string("data_dir", "/tmp/data", 84 | "Directory with predownloaded and saved datasets.") 85 | flags.DEFINE_string("max_dataset_size", "15000", 86 | ("maximum number of datapoints to include in data " 87 | "zero indicates no limit")) 88 | flags.DEFINE_float("train_horizon", "1.0", 89 | "how far to extend learning curve as a percent of train") 90 | flags.DEFINE_string("do_save", "True", 91 | "whether to save log and results") 92 | FLAGS = flags.FLAGS 93 | 94 | 95 | get_wrapper_AL_mapping() 96 | 97 | 98 | def generate_one_curve(X, 99 | y, 100 | sampler, 101 | score_model, 102 | seed, 103 | warmstart_size, 104 | batch_size, 105 | select_model=None, 106 | confusion=0., 107 | active_p=1.0, 108 | max_points=None, 109 | standardize_data=False, 110 | norm_data=False, 111 | train_horizon=0.5): 112 | """Creates one learning curve for both active and passive learning. 113 | 114 | Will calculate accuracy on validation set as the number of training data 115 | points increases for both PL and AL. 116 | Caveats: training method used is sensitive to sorting of the data so we 117 | resort all intermediate datasets 118 | 119 | Args: 120 | X: training data 121 | y: training labels 122 | sampler: sampling class from sampling_methods, assumes reference 123 | passed in and sampler not yet instantiated. 124 | score_model: model used to score the samplers. Expects fit and predict 125 | methods to be implemented. 126 | seed: seed used for data shuffle and other sources of randomness in sampler 127 | or model training 128 | warmstart_size: float or int. float indicates percentage of train data 129 | to use for initial model 130 | batch_size: float or int. float indicates batch size as a percent of 131 | training data 132 | select_model: defaults to None, in which case the score model will be 133 | used to select new datapoints to label. Model must implement fit, predict 134 | and depending on AL method may also need decision_function. 135 | confusion: percentage of labels of one class to flip to the other 136 | active_p: percent of batch to allocate to active learning 137 | max_points: limit dataset size for preliminary 138 | standardize_data: wheter to standardize the data to 0 mean unit variance 139 | norm_data: whether to normalize the data. Default is False for logistic 140 | regression. 141 | train_horizon: how long to draw the curve for. Percent of training data. 
142 | 143 | Returns: 144 | results: dictionary of results for all samplers 145 | sampler_states: dictionary of sampler objects for debugging 146 | """ 147 | # TODO(lishal): add option to find best hyperparameter setting first on 148 | # full dataset and fix the hyperparameter for the rest of the routine 149 | # This will save computation and also lead to more stable behavior for the 150 | # test accuracy 151 | 152 | # TODO(lishal): remove mixture parameter and have the mixture be specified as 153 | # a mixture of samplers strategy 154 | def select_batch(sampler, uniform_sampler, mixture, N, already_selected, 155 | **kwargs): 156 | n_active = int(mixture * N) 157 | n_passive = N - n_active 158 | kwargs["N"] = n_active 159 | kwargs["already_selected"] = already_selected 160 | batch_AL = sampler.select_batch(**kwargs) 161 | already_selected = already_selected + batch_AL 162 | kwargs["N"] = n_passive 163 | kwargs["already_selected"] = already_selected 164 | batch_PL = uniform_sampler.select_batch(**kwargs) 165 | return batch_AL + batch_PL 166 | 167 | np.random.seed(seed) 168 | data_splits = [2./3, 1./6, 1./6] 169 | 170 | # 2/3 of data for training 171 | if max_points is None: 172 | max_points = len(y) 173 | train_size = int(min(max_points, len(y)) * data_splits[0]) 174 | if batch_size < 1: 175 | batch_size = int(batch_size * train_size) 176 | else: 177 | batch_size = int(batch_size) 178 | if warmstart_size < 1: 179 | # Set seed batch to provide enough samples to get at least 4 per class 180 | # TODO(lishal): switch to sklearn stratified sampler 181 | seed_batch = int(warmstart_size * train_size) 182 | else: 183 | seed_batch = int(warmstart_size) 184 | seed_batch = max(seed_batch, 6 * len(np.unique(y))) 185 | 186 | indices, X_train, y_train, X_val, y_val, X_test, y_test, y_noise = ( 187 | utils.get_train_val_test_splits(X,y,max_points,seed,confusion, 188 | seed_batch, split=data_splits)) 189 | 190 | # Preprocess data 191 | if norm_data: 192 | print("Normalizing data") 193 | X_train = normalize(X_train) 194 | X_val = normalize(X_val) 195 | X_test = normalize(X_test) 196 | if standardize_data: 197 | print("Standardizing data") 198 | scaler = StandardScaler().fit(X_train) 199 | X_train = scaler.transform(X_train) 200 | X_val = scaler.transform(X_val) 201 | X_test = scaler.transform(X_test) 202 | print("active percentage: " + str(active_p) + " warmstart batch: " + 203 | str(seed_batch) + " batch size: " + str(batch_size) + " confusion: " + 204 | str(confusion) + " seed: " + str(seed)) 205 | 206 | # Initialize samplers 207 | uniform_sampler = AL_MAPPING["uniform"](X_train, y_train, seed) 208 | sampler = sampler(X_train, y_train, seed) 209 | 210 | results = {} 211 | data_sizes = [] 212 | accuracy = [] 213 | selected_inds = range(seed_batch) 214 | 215 | # If select model is None, use score_model 216 | same_score_select = False 217 | if select_model is None: 218 | select_model = score_model 219 | same_score_select = True 220 | 221 | n_batches = int(np.ceil((train_horizon * train_size - seed_batch) * 222 | 1.0 / batch_size)) + 1 223 | for b in range(n_batches): 224 | n_train = seed_batch + min(train_size - seed_batch, b * batch_size) 225 | print("Training model on " + str(n_train) + " datapoints") 226 | 227 | assert n_train == len(selected_inds) 228 | data_sizes.append(n_train) 229 | 230 | # Sort active_ind so that the end results matches that of uniform sampling 231 | partial_X = X_train[sorted(selected_inds)] 232 | partial_y = y_train[sorted(selected_inds)] 233 | score_model.fit(partial_X, partial_y) 
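    # If a separate selection model was specified, keep it in sync with the
    # same labeled pool; otherwise the scoring model doubles as the selector.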
234 | if not same_score_select: 235 | select_model.fit(partial_X, partial_y) 236 | acc = score_model.score(X_test, y_test) 237 | accuracy.append(acc) 238 | print("Sampler: %s, Accuracy: %.2f%%" % (sampler.name, accuracy[-1]*100)) 239 | 240 | n_sample = min(batch_size, train_size - len(selected_inds)) 241 | select_batch_inputs = { 242 | "model": select_model, 243 | "labeled": dict(zip(selected_inds, y_train[selected_inds])), 244 | "eval_acc": accuracy[-1], 245 | "X_test": X_val, 246 | "y_test": y_val, 247 | "y": y_train 248 | } 249 | new_batch = select_batch(sampler, uniform_sampler, active_p, n_sample, 250 | selected_inds, **select_batch_inputs) 251 | selected_inds.extend(new_batch) 252 | print('Requested: %d, Selected: %d' % (n_sample, len(new_batch))) 253 | assert len(new_batch) == n_sample 254 | assert len(list(set(selected_inds))) == len(selected_inds) 255 | 256 | # Check that the returned indice are correct and will allow mapping to 257 | # training set from original data 258 | assert all(y_noise[indices[selected_inds]] == y_train[selected_inds]) 259 | results["accuracy"] = accuracy 260 | results["selected_inds"] = selected_inds 261 | results["data_sizes"] = data_sizes 262 | results["indices"] = indices 263 | results["noisy_targets"] = y_noise 264 | return results, sampler 265 | 266 | 267 | def main(argv): 268 | del argv 269 | 270 | if not gfile.Exists(FLAGS.save_dir): 271 | try: 272 | gfile.MkDir(FLAGS.save_dir) 273 | except: 274 | print(('WARNING: error creating save directory, ' 275 | 'directory most likely already created.')) 276 | 277 | save_dir = os.path.join( 278 | FLAGS.save_dir, 279 | FLAGS.dataset + "_" + FLAGS.sampling_method) 280 | do_save = FLAGS.do_save == "True" 281 | 282 | if do_save: 283 | if not gfile.Exists(save_dir): 284 | try: 285 | gfile.MkDir(save_dir) 286 | except: 287 | print(('WARNING: error creating save directory, ' 288 | 'directory most likely already created.')) 289 | # Set up logging 290 | filename = os.path.join( 291 | save_dir, "log-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime()) + ".txt") 292 | sys.stdout = utils.Logger(filename) 293 | 294 | confusions = [float(t) for t in FLAGS.confusions.split(" ")] 295 | mixtures = [float(t) for t in FLAGS.active_sampling_percentage.split(" ")] 296 | all_results = {} 297 | max_dataset_size = None if FLAGS.max_dataset_size == "0" else int( 298 | FLAGS.max_dataset_size) 299 | normalize_data = FLAGS.normalize_data == "True" 300 | standardize_data = FLAGS.standardize_data == "True" 301 | X, y = utils.get_mldata(FLAGS.data_dir, FLAGS.dataset) 302 | starting_seed = FLAGS.seed 303 | 304 | for c in confusions: 305 | for m in mixtures: 306 | for seed in range(starting_seed, starting_seed + FLAGS.trials): 307 | sampler = get_AL_sampler(FLAGS.sampling_method) 308 | score_model = utils.get_model(FLAGS.score_method, seed) 309 | if (FLAGS.select_method == "None" or 310 | FLAGS.select_method == FLAGS.score_method): 311 | select_model = None 312 | else: 313 | select_model = utils.get_model(FLAGS.select_method, seed) 314 | results, sampler_state = generate_one_curve( 315 | X, y, sampler, score_model, seed, FLAGS.warmstart_size, 316 | FLAGS.batch_size, select_model, c, m, max_dataset_size, 317 | standardize_data, normalize_data, FLAGS.train_horizon) 318 | key = (FLAGS.dataset, FLAGS.sampling_method, FLAGS.score_method, 319 | FLAGS.select_method, m, FLAGS.warmstart_size, FLAGS.batch_size, 320 | c, standardize_data, normalize_data, seed) 321 | sampler_output = sampler_state.to_dict() 322 | results["sampler_output"] = sampler_output 323 | 
all_results[key] = results 324 | fields = [ 325 | "dataset", "sampler", "score_method", "select_method", 326 | "active percentage", "warmstart size", "batch size", "confusion", 327 | "standardize", "normalize", "seed" 328 | ] 329 | all_results["tuple_keys"] = fields 330 | 331 | if do_save: 332 | filename = ("results_score_" + FLAGS.score_method + 333 | "_select_" + FLAGS.select_method + 334 | "_norm_" + str(normalize_data) + 335 | "_stand_" + str(standardize_data)) 336 | existing_files = gfile.Glob(os.path.join(save_dir, filename + "*.pkl")) 337 | filename = os.path.join(save_dir, 338 | filename + "_" + str(1000+len(existing_files))[1:] + ".pkl") 339 | pickle.dump(all_results, gfile.GFile(filename, "w")) 340 | sys.stdout.flush_file() 341 | 342 | 343 | if __name__ == "__main__": 344 | app.run(main) 345 | -------------------------------------------------------------------------------- /sampling_methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /sampling_methods/bandit_discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Bandit wrapper around base AL sampling methods. 16 | 17 | Assumes adversarial multi-armed bandit setting where arms correspond to 18 | mixtures of different AL methods. 19 | 20 | Uses EXP3 algorithm to decide which AL method to use to create the next batch. 21 | Similar to Hsu & Lin 2015, Active Learning by Learning. 22 | https://www.csie.ntu.edu.tw/~htlin/paper/doc/aaai15albl.pdf 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import numpy as np 30 | 31 | from sampling_methods.wrapper_sampler_def import AL_MAPPING, WrapperSamplingMethod 32 | 33 | 34 | class BanditDiscreteSampler(WrapperSamplingMethod): 35 | """Wraps EXP3 around mixtures of indicated methods. 36 | 37 | Uses EXP3 mult-armed bandit algorithm to select sampler methods. 
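  Arm probabilities follow the standard EXP3 update (see update_vars): the
  observed reward is importance-weighted by the probability of the pulled arm,
  the arm weights are scaled exponentially by that estimate, and the sampling
  distribution mixes the normalized weights with a uniform distribution
  controlled by gamma.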
38 | """ 39 | def __init__(self, 40 | X, 41 | y, 42 | seed, 43 | reward_function = lambda AL_acc: AL_acc[-1], 44 | gamma=0.5, 45 | samplers=[{'methods':('margin','uniform'),'weights':(0,1)}, 46 | {'methods':('margin','uniform'),'weights':(1,0)}]): 47 | """Initializes sampler with indicated gamma and arms. 48 | 49 | Args: 50 | X: training data 51 | y: labels, may need to be input into base samplers 52 | seed: seed to use for random sampling 53 | reward_function: reward based on previously observed accuracies. Assumes 54 | that the input is a sequence of observed accuracies. Will ultimately be 55 | a class method and may need access to other class properties. 56 | gamma: weight on uniform mixture. Arm probability updates are a weighted 57 | mixture of uniform and an exponentially weighted distribution. 58 | Lower gamma more aggressively updates based on observed rewards. 59 | samplers: list of dicts with two fields 60 | 'samplers': list of named samplers 61 | 'weights': percentage of batch to allocate to each sampler 62 | """ 63 | 64 | self.name = 'bandit_discrete' 65 | np.random.seed(seed) 66 | self.X = X 67 | self.y = y 68 | self.seed = seed 69 | self.initialize_samplers(samplers) 70 | 71 | self.gamma = gamma 72 | self.n_arms = len(samplers) 73 | self.reward_function = reward_function 74 | 75 | self.pull_history = [] 76 | self.acc_history = [] 77 | self.w = np.ones(self.n_arms) 78 | self.x = np.zeros(self.n_arms) 79 | self.p = self.w / (1.0 * self.n_arms) 80 | self.probs = [] 81 | 82 | def update_vars(self, arm_pulled): 83 | reward = self.reward_function(self.acc_history) 84 | self.x = np.zeros(self.n_arms) 85 | self.x[arm_pulled] = reward / self.p[arm_pulled] 86 | self.w = self.w * np.exp(self.gamma * self.x / self.n_arms) 87 | self.p = ((1.0 - self.gamma) * self.w / sum(self.w) 88 | + self.gamma / self.n_arms) 89 | print(self.p) 90 | self.probs.append(self.p) 91 | 92 | def select_batch_(self, already_selected, N, eval_acc, **kwargs): 93 | """Returns batch of datapoints sampled using mixture of AL_methods. 94 | 95 | Assumes that data has already been shuffled. 96 | 97 | Args: 98 | already_selected: index of datapoints already selected 99 | N: batch size 100 | eval_acc: accuracy of model trained after incorporating datapoints from 101 | last recommended batch 102 | 103 | Returns: 104 | indices of points selected to label 105 | """ 106 | # Update observed reward and arm probabilities 107 | self.acc_history.append(eval_acc) 108 | if len(self.pull_history) > 0: 109 | self.update_vars(self.pull_history[-1]) 110 | # Sample an arm 111 | arm = np.random.choice(range(self.n_arms), p=self.p) 112 | self.pull_history.append(arm) 113 | kwargs['N'] = N 114 | kwargs['already_selected'] = already_selected 115 | sample = self.samplers[arm].select_batch(**kwargs) 116 | return sample 117 | 118 | def to_dict(self): 119 | output = {} 120 | output['samplers'] = self.base_samplers 121 | output['arm_probs'] = self.probs 122 | output['pull_history'] = self.pull_history 123 | output['rewards'] = self.acc_history 124 | return output 125 | 126 | -------------------------------------------------------------------------------- /sampling_methods/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Controls imports to fill up dictionary of different sampling methods. 16 | """ 17 | 18 | from functools import partial 19 | AL_MAPPING = {} 20 | 21 | 22 | def get_base_AL_mapping(): 23 | from sampling_methods.margin_AL import MarginAL 24 | from sampling_methods.informative_diverse import InformativeClusterDiverseSampler 25 | from sampling_methods.hierarchical_clustering_AL import HierarchicalClusterAL 26 | from sampling_methods.uniform_sampling import UniformSampling 27 | from sampling_methods.represent_cluster_centers import RepresentativeClusterMeanSampling 28 | from sampling_methods.graph_density import GraphDensitySampler 29 | from sampling_methods.kcenter_greedy import kCenterGreedy 30 | AL_MAPPING['margin'] = MarginAL 31 | AL_MAPPING['informative_diverse'] = InformativeClusterDiverseSampler 32 | AL_MAPPING['hierarchical'] = HierarchicalClusterAL 33 | AL_MAPPING['uniform'] = UniformSampling 34 | AL_MAPPING['margin_cluster_mean'] = RepresentativeClusterMeanSampling 35 | AL_MAPPING['graph_density'] = GraphDensitySampler 36 | AL_MAPPING['kcenter'] = kCenterGreedy 37 | 38 | 39 | def get_all_possible_arms(): 40 | from sampling_methods.mixture_of_samplers import MixtureOfSamplers 41 | AL_MAPPING['mixture_of_samplers'] = MixtureOfSamplers 42 | 43 | 44 | def get_wrapper_AL_mapping(): 45 | from sampling_methods.bandit_discrete import BanditDiscreteSampler 46 | from sampling_methods.simulate_batch import SimulateBatchSampler 47 | AL_MAPPING['bandit_mixture'] = partial( 48 | BanditDiscreteSampler, 49 | samplers=[{ 50 | 'methods': ['margin', 'uniform'], 51 | 'weights': [0, 1] 52 | }, { 53 | 'methods': ['margin', 'uniform'], 54 | 'weights': [0.25, 0.75] 55 | }, { 56 | 'methods': ['margin', 'uniform'], 57 | 'weights': [0.5, 0.5] 58 | }, { 59 | 'methods': ['margin', 'uniform'], 60 | 'weights': [0.75, 0.25] 61 | }, { 62 | 'methods': ['margin', 'uniform'], 63 | 'weights': [1, 0] 64 | }]) 65 | AL_MAPPING['bandit_discrete'] = partial( 66 | BanditDiscreteSampler, 67 | samplers=[{ 68 | 'methods': ['margin', 'uniform'], 69 | 'weights': [0, 1] 70 | }, { 71 | 'methods': ['margin', 'uniform'], 72 | 'weights': [1, 0] 73 | }]) 74 | AL_MAPPING['simulate_batch_mixture'] = partial( 75 | SimulateBatchSampler, 76 | samplers=({ 77 | 'methods': ['margin', 'uniform'], 78 | 'weights': [1, 0] 79 | }, { 80 | 'methods': ['margin', 'uniform'], 81 | 'weights': [0.5, 0.5] 82 | }, { 83 | 'methods': ['margin', 'uniform'], 84 | 'weights': [0, 1] 85 | }), 86 | n_sims=5, 87 | train_per_sim=10, 88 | return_best_sim=False) 89 | AL_MAPPING['simulate_batch_best_sim'] = partial( 90 | SimulateBatchSampler, 91 | samplers=[{ 92 | 'methods': ['margin', 'uniform'], 93 | 'weights': [1, 0] 94 | }], 95 | n_sims=10, 96 | train_per_sim=10, 97 | return_type='best_sim') 98 | AL_MAPPING['simulate_batch_frequency'] = partial( 99 | SimulateBatchSampler, 100 | samplers=[{ 101 | 'methods': ['margin', 'uniform'], 102 | 'weights': [1, 0] 103 | }], 104 | n_sims=10, 105 | train_per_sim=10, 106 | return_type='frequency') 107 | 108 | def get_mixture_of_samplers(name): 109 | assert 'mixture_of_samplers' in 
name 110 | if 'mixture_of_samplers' not in AL_MAPPING: 111 | raise KeyError('Mixture of Samplers not yet loaded.') 112 | args = name.split('-')[1:] 113 | samplers = args[0::2] 114 | weights = args[1::2] 115 | weights = [float(w) for w in weights] 116 | assert sum(weights) == 1 117 | mixture = {'methods': samplers, 'weights': weights} 118 | print(mixture) 119 | return partial(AL_MAPPING['mixture_of_samplers'], mixture=mixture) 120 | 121 | 122 | def get_AL_sampler(name): 123 | if name in AL_MAPPING and name != 'mixture_of_samplers': 124 | return AL_MAPPING[name] 125 | if 'mixture_of_samplers' in name: 126 | return get_mixture_of_samplers(name) 127 | raise NotImplementedError('The specified sampler is not available.') 128 | -------------------------------------------------------------------------------- /sampling_methods/graph_density.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Diversity promoting sampling method that uses graph density to determine 16 | most representative points. 17 | 18 | This is an implementation of the method described in 19 | https://www.mpi-inf.mpg.de/fileadmin/inf/d2/Research_projects_files/EbertCVPR2012.pdf 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import copy 27 | 28 | from sklearn.neighbors import kneighbors_graph 29 | from sklearn.metrics import pairwise_distances 30 | import numpy as np 31 | from sampling_methods.sampling_def import SamplingMethod 32 | 33 | 34 | class GraphDensitySampler(SamplingMethod): 35 | """Diversity promoting sampling method that uses graph density to determine 36 | most representative points. 37 | """ 38 | 39 | def __init__(self, X, y, seed): 40 | self.name = 'graph_density' 41 | self.X = X 42 | self.flat_X = self.flatten_X() 43 | # Set gamma for gaussian kernel to be equal to 1/n_features 44 | self.gamma = 1. / self.X.shape[1] 45 | self.compute_graph_density() 46 | 47 | def compute_graph_density(self, n_neighbor=10): 48 | # kneighbors graph is constructed using k=10 49 | connect = kneighbors_graph(self.flat_X, n_neighbor,p=1) 50 | # Make connectivity matrix symmetric, if a point is a k nearest neighbor of 51 | # another point, make it vice versa 52 | neighbors = connect.nonzero() 53 | inds = zip(neighbors[0],neighbors[1]) 54 | # Graph edges are weighted by applying gaussian kernel to manhattan dist. 55 | # By default, gamma for rbf kernel is equal to 1/n_features but may 56 | # get better results if gamma is tuned. 
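    # Each edge (i, j) below gets weight exp(-gamma * manhattan_distance(i, j)),
    # so nearby points produce edge weights near 1 and distant points produce
    # weights near 0.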
57 | for entry in inds: 58 | i = entry[0] 59 | j = entry[1] 60 | distance = pairwise_distances(self.flat_X[[i]],self.flat_X[[j]],metric='manhattan') 61 | distance = distance[0,0] 62 | weight = np.exp(-distance * self.gamma) 63 | connect[i,j] = weight 64 | connect[j,i] = weight 65 | self.connect = connect 66 | # Define graph density for an observation to be sum of weights for all 67 | # edges to the node representing the datapoint. Normalize sum weights 68 | # by total number of neighbors. 69 | self.graph_density = np.zeros(self.X.shape[0]) 70 | for i in np.arange(self.X.shape[0]): 71 | self.graph_density[i] = connect[i,:].sum() / (connect[i,:]>0).sum() 72 | self.starting_density = copy.deepcopy(self.graph_density) 73 | 74 | def select_batch_(self, N, already_selected, **kwargs): 75 | # If a neighbor has already been sampled, reduce the graph density 76 | # for its direct neighbors to promote diversity. 77 | batch = set() 78 | self.graph_density[already_selected] = min(self.graph_density) - 1 79 | while len(batch) < N: 80 | selected = np.argmax(self.graph_density) 81 | neighbors = (self.connect[selected,:] > 0).nonzero()[1] 82 | self.graph_density[neighbors] = self.graph_density[neighbors] - self.graph_density[selected] 83 | batch.add(selected) 84 | self.graph_density[already_selected] = min(self.graph_density) - 1 85 | self.graph_density[list(batch)] = min(self.graph_density) - 1 86 | return list(batch) 87 | 88 | def to_dict(self): 89 | output = {} 90 | output['connectivity'] = self.connect 91 | output['graph_density'] = self.starting_density 92 | return output -------------------------------------------------------------------------------- /sampling_methods/hierarchical_clustering_AL.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Hierarchical cluster AL method. 16 | 17 | Implements algorithm described in Dasgupta, S and Hsu, D, 18 | "Hierarchical Sampling for Active Learning, 2008 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import numpy as np 26 | from sklearn.cluster import AgglomerativeClustering 27 | from sklearn.decomposition import PCA 28 | from sklearn.neighbors import kneighbors_graph 29 | from sampling_methods.sampling_def import SamplingMethod 30 | from sampling_methods.utils.tree import Tree 31 | 32 | 33 | class HierarchicalClusterAL(SamplingMethod): 34 | """Implements hierarchical cluster AL based method. 35 | 36 | All methods are internal. select_batch_ is called via abstract classes 37 | outward facing method select_batch. 38 | 39 | Default affininity is euclidean and default linkage is ward which links 40 | cluster based on variance reduction. Hence, good results depend on 41 | having normalized and standardized data. 
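  Each call to select_batch_ updates node scores with the newly observed
  labels, refines the current pruning (a set of tree nodes whose leaves
  partition the data), and then samples unlabeled leaves from pruning nodes
  weighted by their size and remaining label uncertainty.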
42 | """ 43 | 44 | def __init__(self, X, y, seed, beta=2, affinity='euclidean', linkage='ward', 45 | clustering=None, max_features=None): 46 | """Initializes AL method and fits hierarchical cluster to data. 47 | 48 | Args: 49 | X: data 50 | y: labels for determinining number of clusters as an input to 51 | AgglomerativeClustering 52 | seed: random seed used for sampling datapoints for batch 53 | beta: width of error used to decide admissble labels, higher value of beta 54 | corresponds to wider confidence and less stringent definition of 55 | admissibility 56 | See scikit Aggloerative clustering method for more info 57 | affinity: distance metric used for hierarchical clustering 58 | linkage: linkage method used to determine when to join clusters 59 | clustering: can provide an AgglomerativeClustering that is already fit 60 | max_features: limit number of features used to construct hierarchical 61 | cluster. If specified, PCA is used to perform feature reduction and 62 | the hierarchical clustering is performed using transformed features. 63 | """ 64 | self.name = 'hierarchical' 65 | self.seed = seed 66 | np.random.seed(seed) 67 | # Variables for the hierarchical cluster 68 | self.already_clustered = False 69 | if clustering is not None: 70 | self.model = clustering 71 | self.already_clustered = True 72 | self.n_leaves = None 73 | self.n_components = None 74 | self.children_list = None 75 | self.node_dict = None 76 | self.root = None # Node name, all node instances access through self.tree 77 | self.tree = None 78 | # Variables for the AL algorithm 79 | self.initialized = False 80 | self.beta = beta 81 | self.labels = {} 82 | self.pruning = [] 83 | self.admissible = {} 84 | self.selected_nodes = None 85 | # Data variables 86 | self.classes = None 87 | self.X = X 88 | 89 | classes = list(set(y)) 90 | self.n_classes = len(classes) 91 | if max_features is not None: 92 | transformer = PCA(n_components=max_features) 93 | transformer.fit(X) 94 | self.transformed_X = transformer.fit_transform(X) 95 | #connectivity = kneighbors_graph(self.transformed_X,max_features) 96 | self.model = AgglomerativeClustering( 97 | affinity=affinity, linkage=linkage, n_clusters=len(classes)) 98 | self.fit_cluster(self.transformed_X) 99 | else: 100 | self.model = AgglomerativeClustering( 101 | affinity=affinity, linkage=linkage, n_clusters=len(classes)) 102 | self.fit_cluster(self.X) 103 | self.y = y 104 | 105 | self.y_labels = {} 106 | # Fit cluster and update cluster variables 107 | 108 | self.create_tree() 109 | print('Finished creating hierarchical cluster') 110 | 111 | def fit_cluster(self, X): 112 | if not self.already_clustered: 113 | self.model.fit(X) 114 | self.already_clustered = True 115 | self.n_leaves = self.model.n_leaves_ 116 | self.n_components = self.model.n_components_ 117 | self.children_list = self.model.children_ 118 | 119 | def create_tree(self): 120 | node_dict = {} 121 | for i in range(self.n_leaves): 122 | node_dict[i] = [None, None] 123 | for i in range(len(self.children_list)): 124 | node_dict[self.n_leaves + i] = self.children_list[i] 125 | self.node_dict = node_dict 126 | # The sklearn hierarchical clustering algo numbers leaves which correspond 127 | # to actual datapoints 0 to n_points - 1 and all internal nodes have 128 | # ids greater than n_points - 1 with the root having the highest node id 129 | self.root = max(self.node_dict.keys()) 130 | self.tree = Tree(self.root, self.node_dict) 131 | self.tree.create_child_leaves_mapping(range(self.n_leaves)) 132 | for v in node_dict: 133 | 
self.admissible[v] = set() 134 | 135 | def get_child_leaves(self, node): 136 | return self.tree.get_child_leaves(node) 137 | 138 | def get_node_leaf_counts(self, node_list): 139 | node_counts = [] 140 | for v in node_list: 141 | node_counts.append(len(self.get_child_leaves(v))) 142 | return np.array(node_counts) 143 | 144 | def get_class_counts(self, y): 145 | """Gets the count of all classes in a sample. 146 | 147 | Args: 148 | y: sample vector for which to perform the count 149 | Returns: 150 | count of classes for the sample vector y, the class order for count will 151 | be the same as that of self.classes 152 | """ 153 | unique, counts = np.unique(y, return_counts=True) 154 | complete_counts = [] 155 | for c in self.classes: 156 | if c not in unique: 157 | complete_counts.append(0) 158 | else: 159 | index = np.where(unique == c)[0][0] 160 | complete_counts.append(counts[index]) 161 | return np.array(complete_counts) 162 | 163 | def observe_labels(self, labeled): 164 | for i in labeled: 165 | self.y_labels[i] = labeled[i] 166 | self.classes = np.array( 167 | sorted(list(set([self.y_labels[k] for k in self.y_labels])))) 168 | self.n_classes = len(self.classes) 169 | 170 | def initialize_algo(self): 171 | self.pruning = [self.root] 172 | self.labels[self.root] = np.random.choice(self.classes) 173 | node = self.tree.get_node(self.root) 174 | node.best_label = self.labels[self.root] 175 | self.selected_nodes = [self.root] 176 | 177 | def get_node_class_probabilities(self, node, y=None): 178 | children = self.get_child_leaves(node) 179 | if y is None: 180 | y_dict = self.y_labels 181 | else: 182 | y_dict = dict(zip(range(len(y)), y)) 183 | labels = [y_dict[c] for c in children if c in y_dict] 184 | # If no labels have been observed, simply return uniform distribution 185 | if len(labels) == 0: 186 | return 0, np.ones(self.n_classes)/self.n_classes 187 | return len(labels), self.get_class_counts(labels) / (len(labels) * 1.0) 188 | 189 | def get_node_upper_lower_bounds(self, node): 190 | n_v, p_v = self.get_node_class_probabilities(node) 191 | # If no observations, return worst possible upper lower bounds 192 | if n_v == 0: 193 | return np.zeros(len(p_v)), np.ones(len(p_v)) 194 | delta = 1. / n_v + np.sqrt(p_v * (1 - p_v) / (1. 
* n_v)) 195 | return (np.maximum(p_v - delta, np.zeros(len(p_v))), 196 | np.minimum(p_v + delta, np.ones(len(p_v)))) 197 | 198 | def get_node_admissibility(self, node): 199 | p_lb, p_up = self.get_node_upper_lower_bounds(node) 200 | all_other_min = np.vectorize( 201 | lambda i:min([1 - p_up[c] for c in range(len(self.classes)) if c != i])) 202 | lowest_alternative_error = self.beta * all_other_min( 203 | np.arange(len(self.classes))) 204 | return 1 - p_lb < lowest_alternative_error 205 | 206 | def get_adjusted_error(self, node): 207 | _, prob = self.get_node_class_probabilities(node) 208 | error = 1 - prob 209 | admissible = self.get_node_admissibility(node) 210 | not_admissible = np.where(admissible != True)[0] 211 | error[not_admissible] = 1.0 212 | return error 213 | 214 | def get_class_probability_pruning(self, method='lower'): 215 | prob_pruning = [] 216 | for v in self.pruning: 217 | label = self.labels[v] 218 | label_ind = np.where(self.classes == label)[0][0] 219 | if method == 'empirical': 220 | _, v_prob = self.get_node_class_probabilities(v) 221 | else: 222 | lower, upper = self.get_node_upper_lower_bounds(v) 223 | if method == 'lower': 224 | v_prob = lower 225 | elif method == 'upper': 226 | v_prob = upper 227 | else: 228 | raise NotImplementedError 229 | prob = v_prob[label_ind] 230 | prob_pruning.append(prob) 231 | return np.array(prob_pruning) 232 | 233 | def get_pruning_impurity(self, y): 234 | impurity = [] 235 | for v in self.pruning: 236 | _, prob = self.get_node_class_probabilities(v, y) 237 | impurity.append(1-max(prob)) 238 | impurity = np.array(impurity) 239 | weights = self.get_node_leaf_counts(self.pruning) 240 | weights = weights / sum(weights) 241 | return sum(impurity*weights) 242 | 243 | def update_scores(self): 244 | node_list = set(range(self.n_leaves)) 245 | # Loop through generations from bottom to top 246 | while len(node_list) > 0: 247 | parents = set() 248 | for v in node_list: 249 | node = self.tree.get_node(v) 250 | # Update admissible labels for node 251 | admissible = self.get_node_admissibility(v) 252 | admissable_indices = np.where(admissible)[0] 253 | for l in self.classes[admissable_indices]: 254 | self.admissible[v].add(l) 255 | # Calculate score 256 | v_error = self.get_adjusted_error(v) 257 | best_label_ind = np.argmin(v_error) 258 | if admissible[best_label_ind]: 259 | node.best_label = self.classes[best_label_ind] 260 | score = v_error[best_label_ind] 261 | node.split = False 262 | 263 | # Determine if node should be split 264 | if v >= self.n_leaves: # v is not a leaf 265 | if len(admissable_indices) > 0: # There exists an admissible label 266 | # Make sure label set for node so that we can flow to children 267 | # if necessary 268 | assert node.best_label is not None 269 | # Only split if all ancestors are admissible nodes 270 | # This is part of definition of admissible pruning 271 | admissible_ancestors = [len(self.admissible[a]) > 0 for a in 272 | self.tree.get_ancestor(node)] 273 | if all(admissible_ancestors): 274 | left = self.node_dict[v][0] 275 | left_node = self.tree.get_node(left) 276 | right = self.node_dict[v][1] 277 | right_node = self.tree.get_node(right) 278 | node_counts = self.get_node_leaf_counts([v, left, right]) 279 | split_score = (node_counts[1] / node_counts[0] * 280 | left_node.score + node_counts[2] / 281 | node_counts[0] * right_node.score) 282 | if split_score < score: 283 | score = split_score 284 | node.split = True 285 | node.score = score 286 | if node.parent: 287 | parents.add(node.parent.name) 288 | 
node_list = parents 289 | 290 | def update_pruning_labels(self): 291 | for v in self.selected_nodes: 292 | node = self.tree.get_node(v) 293 | pruning = self.tree.get_pruning(node) 294 | self.pruning.remove(v) 295 | self.pruning.extend(pruning) 296 | # Check that pruning covers all leave nodes 297 | node_counts = self.get_node_leaf_counts(self.pruning) 298 | assert sum(node_counts) == self.n_leaves 299 | # Fill in labels 300 | for v in self.pruning: 301 | node = self.tree.get_node(v) 302 | if node.best_label is None: 303 | node.best_label = node.parent.best_label 304 | self.labels[v] = node.best_label 305 | 306 | def get_fake_labels(self): 307 | fake_y = np.zeros(self.X.shape[0]) 308 | for p in self.pruning: 309 | indices = self.get_child_leaves(p) 310 | fake_y[indices] = self.labels[p] 311 | return fake_y 312 | 313 | def train_using_fake_labels(self, model, X_test, y_test): 314 | classes_labeled = set([self.labels[p] for p in self.pruning]) 315 | if len(classes_labeled) == self.n_classes: 316 | fake_y = self.get_fake_labels() 317 | model.fit(self.X, fake_y) 318 | test_acc = model.score(X_test, y_test) 319 | return test_acc 320 | return 0 321 | 322 | def select_batch_(self, N, already_selected, labeled, y, **kwargs): 323 | # Observe labels for previously recommended batches 324 | self.observe_labels(labeled) 325 | 326 | if not self.initialized: 327 | self.initialize_algo() 328 | self.initialized = True 329 | print('Initialized algo') 330 | 331 | print('Updating scores and pruning for labels from last batch') 332 | self.update_scores() 333 | self.update_pruning_labels() 334 | print('Nodes in pruning: %d' % (len(self.pruning))) 335 | print('Actual impurity for pruning is: %.2f' % 336 | (self.get_pruning_impurity(y))) 337 | 338 | # TODO(lishal): implement multiple selection methods 339 | selected_nodes = set() 340 | weights = self.get_node_leaf_counts(self.pruning) 341 | probs = 1 - self.get_class_probability_pruning() 342 | weights = weights * probs 343 | weights = weights / sum(weights) 344 | batch = [] 345 | 346 | print('Sampling batch') 347 | while len(batch) < N: 348 | node = np.random.choice(list(self.pruning), p=weights) 349 | children = self.get_child_leaves(node) 350 | children = [ 351 | c for c in children if c not in self.y_labels and c not in batch 352 | ] 353 | if len(children) > 0: 354 | selected_nodes.add(node) 355 | batch.append(np.random.choice(children)) 356 | self.selected_nodes = selected_nodes 357 | return batch 358 | 359 | def to_dict(self): 360 | output = {} 361 | output['node_dict'] = self.node_dict 362 | return output 363 | -------------------------------------------------------------------------------- /sampling_methods/informative_diverse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
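A minimal usage sketch for the hierarchical sampler defined above (here assumed to be exported as HierarchicalClusterAL; the real experiment driver lives in run_experiment.py). It assumes X and y are numpy feature/label arrays and that a handful of points are already labeled so the sampler has at least one observed class before its first batch:

    # Sketch only -- data, seed labels, and the round/batch sizes are hypothetical.
    import numpy as np
    from sampling_methods.hierarchical_clustering_AL import HierarchicalClusterAL

    sampler = HierarchicalClusterAL(X, y, seed=13, beta=2)
    labeled = {i: y[i] for i in range(10)}          # seed the run with a few labels
    for _ in range(5):                              # five AL rounds of 100 points each
      batch = sampler.select_batch_(N=100, already_selected=list(labeled),
                                    labeled=labeled, y=y)
      labeled.update({i: y[i] for i in batch})      # oracle lookup for the new batch

Each round the newly observed labels are passed back in as the `labeled` dict, which is what lets the sampler tighten its per-node label bounds and refine the pruning before proposing the next batch.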
14 | 
15 | """Informative and diverse batch sampler that samples points with small margin
16 | while maintaining the same distribution over clusters as the entire training data.
17 | 
18 | Batch is created by sorting datapoints by increasing margin and then growing
19 | the batch greedily. A point is added to the batch if the resulting batch still
20 | respects the constraint that the cluster distribution of the batch will
21 | match the cluster distribution of the entire training set.
22 | """
23 | 
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 | 
28 | from sklearn.cluster import MiniBatchKMeans
29 | import numpy as np
30 | from sampling_methods.sampling_def import SamplingMethod
31 | 
32 | 
33 | class InformativeClusterDiverseSampler(SamplingMethod):
34 | """Selects batch based on informative and diverse criteria.
35 | 
36 | Returns the highest-uncertainty (smallest-margin) points while maintaining
37 | the same distribution over clusters as the entire dataset.
38 | """
39 | 
40 | def __init__(self, X, y, seed):
41 | self.name = 'informative_and_diverse'
42 | self.X = X
43 | self.flat_X = self.flatten_X()
44 | # y is only used for determining how many clusters there should be;
45 | # probably not practical to assume we know the # of classes beforehand,
46 | # and it should also probably scale with the dimensionality of the data.
47 | self.y = y
48 | self.n_clusters = len(list(set(y)))
49 | self.cluster_model = MiniBatchKMeans(n_clusters=self.n_clusters)
50 | self.cluster_data()
51 | 
52 | def cluster_data(self):
53 | # Probably okay to always use MiniBatchKMeans.
54 | # Should standardize data before clustering.
55 | # Can cluster on standardized data but train on raw features if desired.
56 | self.cluster_model.fit(self.flat_X)
57 | unique, counts = np.unique(self.cluster_model.labels_, return_counts=True)
58 | self.cluster_prob = counts/sum(counts)
59 | self.cluster_labels = self.cluster_model.labels_
60 | 
61 | def select_batch_(self, model, already_selected, N, **kwargs):
62 | """Returns a batch of size N using informative and diverse selection.
63 | 64 | Args: 65 | model: scikit learn model with decision_function implemented 66 | already_selected: index of datapoints already selected 67 | N: batch size 68 | 69 | Returns: 70 | indices of points selected to add using margin active learner 71 | """ 72 | # TODO(lishal): have MarginSampler and this share margin function 73 | try: 74 | distances = model.decision_function(self.X) 75 | except: 76 | distances = model.predict_proba(self.X) 77 | if len(distances.shape) < 2: 78 | min_margin = abs(distances) 79 | else: 80 | sort_distances = np.sort(distances, 1)[:, -2:] 81 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 82 | rank_ind = np.argsort(min_margin) 83 | rank_ind = [i for i in rank_ind if i not in already_selected] 84 | new_batch_cluster_counts = [0 for _ in range(self.n_clusters)] 85 | new_batch = [] 86 | for i in rank_ind: 87 | if len(new_batch) == N: 88 | break 89 | label = self.cluster_labels[i] 90 | if new_batch_cluster_counts[label] / N < self.cluster_prob[label]: 91 | new_batch.append(i) 92 | new_batch_cluster_counts[label] += 1 93 | n_slot_remaining = N - len(new_batch) 94 | batch_filler = list(set(rank_ind) - set(already_selected) - set(new_batch)) 95 | new_batch.extend(batch_filler[0:n_slot_remaining]) 96 | return new_batch 97 | 98 | def to_dict(self): 99 | output = {} 100 | output['cluster_membership'] = self.cluster_labels 101 | return output 102 | -------------------------------------------------------------------------------- /sampling_methods/kcenter_greedy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Returns points that minimizes the maximum distance of any point to a center. 16 | 17 | Implements the k-Center-Greedy method in 18 | Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for 19 | Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017 20 | 21 | Distance metric defaults to l2 distance. Features used to calculate distance 22 | are either raw features or if a model has transform method then uses the output 23 | of model.transform(X). 24 | 25 | Can be extended to a robust k centers algorithm that ignores a certain number of 26 | outlier datapoints. Resulting centers are solution to multiple integer program. 
27 | """ 28 | 29 | from __future__ import absolute_import 30 | from __future__ import division 31 | from __future__ import print_function 32 | 33 | import numpy as np 34 | from sklearn.metrics import pairwise_distances 35 | from sampling_methods.sampling_def import SamplingMethod 36 | 37 | 38 | class kCenterGreedy(SamplingMethod): 39 | 40 | def __init__(self, X, y, seed, metric='euclidean'): 41 | self.X = X 42 | self.y = y 43 | self.flat_X = self.flatten_X() 44 | self.name = 'kcenter' 45 | self.features = self.flat_X 46 | self.metric = metric 47 | self.min_distances = None 48 | self.n_obs = self.X.shape[0] 49 | self.already_selected = [] 50 | 51 | def update_distances(self, cluster_centers, only_new=True, reset_dist=False): 52 | """Update min distances given cluster centers. 53 | 54 | Args: 55 | cluster_centers: indices of cluster centers 56 | only_new: only calculate distance for newly selected points and update 57 | min_distances. 58 | rest_dist: whether to reset min_distances. 59 | """ 60 | 61 | if reset_dist: 62 | self.min_distances = None 63 | if only_new: 64 | cluster_centers = [d for d in cluster_centers 65 | if d not in self.already_selected] 66 | if cluster_centers: 67 | # Update min_distances for all examples given new cluster center. 68 | x = self.features[cluster_centers] 69 | dist = pairwise_distances(self.features, x, metric=self.metric) 70 | 71 | if self.min_distances is None: 72 | self.min_distances = np.min(dist, axis=1).reshape(-1,1) 73 | else: 74 | self.min_distances = np.minimum(self.min_distances, dist) 75 | 76 | def select_batch_(self, model, already_selected, N, **kwargs): 77 | """ 78 | Diversity promoting active learning method that greedily forms a batch 79 | to minimize the maximum distance to a cluster center among all unlabeled 80 | datapoints. 81 | 82 | Args: 83 | model: model with scikit-like API with decision_function implemented 84 | already_selected: index of datapoints already selected 85 | N: batch size 86 | 87 | Returns: 88 | indices of points selected to minimize distance to cluster centers 89 | """ 90 | 91 | try: 92 | # Assumes that the transform function takes in original data and not 93 | # flattened data. 94 | print('Getting transformed features...') 95 | self.features = model.transform(self.X) 96 | print('Calculating distances...') 97 | self.update_distances(already_selected, only_new=False, reset_dist=True) 98 | except: 99 | print('Using flat_X as features.') 100 | self.update_distances(already_selected, only_new=True, reset_dist=False) 101 | 102 | new_batch = [] 103 | 104 | for _ in range(N): 105 | if self.already_selected is None: 106 | # Initialize centers with a randomly selected datapoint 107 | ind = np.random.choice(np.arange(self.n_obs)) 108 | else: 109 | ind = np.argmax(self.min_distances) 110 | # New examples should not be in already selected since those points 111 | # should have min_distance of zero to a cluster center. 112 | assert ind not in already_selected 113 | 114 | self.update_distances([ind], only_new=True, reset_dist=False) 115 | new_batch.append(ind) 116 | print('Maximum distance from cluster centers is %0.2f' 117 | % max(self.min_distances)) 118 | 119 | 120 | self.already_selected = already_selected 121 | 122 | return new_batch 123 | 124 | -------------------------------------------------------------------------------- /sampling_methods/margin_AL.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Margin based AL method. 16 | 17 | Samples in batches based on margin scores. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import numpy as np 25 | from sampling_methods.sampling_def import SamplingMethod 26 | 27 | 28 | class MarginAL(SamplingMethod): 29 | def __init__(self, X, y, seed): 30 | self.X = X 31 | self.y = y 32 | self.name = 'margin' 33 | 34 | def select_batch_(self, model, already_selected, N, **kwargs): 35 | """Returns batch of datapoints with smallest margin/highest uncertainty. 36 | 37 | For binary classification, can just take the absolute distance to decision 38 | boundary for each point. 39 | For multiclass classification, must consider the margin between distance for 40 | top two most likely classes. 41 | 42 | Args: 43 | model: scikit learn model with decision_function implemented 44 | already_selected: index of datapoints already selected 45 | N: batch size 46 | 47 | Returns: 48 | indices of points selected to add using margin active learner 49 | """ 50 | 51 | try: 52 | distances = model.decision_function(self.X) 53 | except: 54 | distances = model.predict_proba(self.X) 55 | if len(distances.shape) < 2: 56 | min_margin = abs(distances) 57 | else: 58 | sort_distances = np.sort(distances, 1)[:, -2:] 59 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 60 | rank_ind = np.argsort(min_margin) 61 | rank_ind = [i for i in rank_ind if i not in already_selected] 62 | active_samples = rank_ind[0:N] 63 | return active_samples 64 | 65 | -------------------------------------------------------------------------------- /sampling_methods/mixture_of_samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Mixture of base sampling strategies 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import copy 24 | 25 | from sampling_methods.sampling_def import SamplingMethod 26 | from sampling_methods.constants import AL_MAPPING, get_base_AL_mapping 27 | 28 | get_base_AL_mapping() 29 | 30 | 31 | class MixtureOfSamplers(SamplingMethod): 32 | """Samples according to mixture of base sampling methods. 
33 | 
34 | If the mixed strategies select duplicate points while forming the batch,
35 | the remaining slots are divided according to the mixture weights and
36 | another partial batch is requested until the batch is full.
37 | """
38 | def __init__(self,
39 | X,
40 | y,
41 | seed,
42 | mixture={'methods': ('margin', 'uniform'),
43 | 'weights': (0.5, 0.5)},
44 | samplers=None):
45 | self.X = X
46 | self.y = y
47 | self.name = 'mixture_of_samplers'
48 | self.sampling_methods = mixture['methods']
49 | self.sampling_weights = dict(zip(mixture['methods'], mixture['weights']))
50 | self.seed = seed
51 | # A list of initialized samplers is allowed as an input because
52 | # AL methods that search over different mixtures may want the mixtures to
53 | # share samplers so that initialization is only performed once for
54 | # computation-intensive methods like HierarchicalClusteringAL and
55 | # states are shared between mixtures.
56 | # If initialized samplers are not provided, initialize them ourselves.
57 | if samplers is None:
58 | self.samplers = {}
59 | self.initialize(self.sampling_methods)
60 | else:
61 | self.samplers = samplers
62 | self.history = []
63 | 
64 | def initialize(self, samplers):
65 | self.samplers = {}
66 | for s in samplers:
67 | self.samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed)
68 | 
69 | def select_batch_(self, already_selected, N, **kwargs):
70 | """Returns batch of datapoints selected according to mixture weights.
71 | 
72 | Args:
73 | already_selected: index of datapoints already selected
74 | N: batch size
75 | 
76 | Returns:
77 | indices of points selected by the mixture of samplers
78 | """
79 | kwargs['already_selected'] = copy.copy(already_selected)
80 | inds = set()
81 | self.selected_by_sampler = {}
82 | for s in self.sampling_methods:
83 | self.selected_by_sampler[s] = []
84 | effective_N = 0
85 | while len(inds) < N:
86 | effective_N += N - len(inds)
87 | for s in self.sampling_methods:
88 | if len(inds) < N:
89 | batch_size = min(max(int(self.sampling_weights[s] * effective_N), 1), N)
90 | sampler = self.samplers[s]
91 | kwargs['N'] = batch_size
92 | s_inds = sampler.select_batch(**kwargs)
93 | for ind in s_inds:
94 | if ind not in self.selected_by_sampler[s]:
95 | self.selected_by_sampler[s].append(ind)
96 | s_inds = [d for d in s_inds if d not in inds]
97 | s_inds = s_inds[0 : min(len(s_inds), N-len(inds))]
98 | inds.update(s_inds)
99 | self.history.append(copy.deepcopy(self.selected_by_sampler))
100 | return list(inds)
101 | 
102 | def to_dict(self):
103 | output = {}
104 | output['history'] = self.history
105 | output['samplers'] = self.sampling_methods
106 | output['mixture_weights'] = self.sampling_weights
107 | for s in self.samplers:
108 | s_output = self.samplers[s].to_dict()
109 | output[s] = s_output
110 | return output
111 | 
--------------------------------------------------------------------------------
/sampling_methods/represent_cluster_centers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Another informative and diverse sampler that mirrors the algorithm described 16 | in Xu, et. al., Representative Sampling for Text Classification Using 17 | Support Vector Machines, 2003 18 | 19 | Batch is created by clustering points within the margin of the classifier and 20 | choosing points closest to the k centroids. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | from sklearn.cluster import MiniBatchKMeans 28 | import numpy as np 29 | from sampling_methods.sampling_def import SamplingMethod 30 | 31 | 32 | class RepresentativeClusterMeanSampling(SamplingMethod): 33 | """Selects batch based on informative and diverse criteria. 34 | 35 | Returns points within the margin of the classifier that are closest to the 36 | k-means centers of those points. 37 | """ 38 | 39 | def __init__(self, X, y, seed): 40 | self.name = 'cluster_mean' 41 | self.X = X 42 | self.flat_X = self.flatten_X() 43 | self.y = y 44 | self.seed = seed 45 | 46 | def select_batch_(self, model, N, already_selected, **kwargs): 47 | # Probably okay to always use MiniBatchKMeans 48 | # Should standardize data before clustering 49 | # Can cluster on standardized data but train on raw features if desired 50 | try: 51 | distances = model.decision_function(self.X) 52 | except: 53 | distances = model.predict_proba(self.X) 54 | if len(distances.shape) < 2: 55 | min_margin = abs(distances) 56 | else: 57 | sort_distances = np.sort(distances, 1)[:, -2:] 58 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 59 | rank_ind = np.argsort(min_margin) 60 | rank_ind = [i for i in rank_ind if i not in already_selected] 61 | 62 | distances = abs(model.decision_function(self.X)) 63 | min_margin_by_class = np.min(abs(distances[already_selected]),axis=0) 64 | unlabeled_in_margin = np.array([i for i in range(len(self.y)) 65 | if i not in already_selected and 66 | any(distances[i] 2: 42 | flat_X = np.reshape(self.X, (shape[0],np.product(shape[1:]))) 43 | return flat_X 44 | 45 | 46 | @abc.abstractmethod 47 | def select_batch_(self): 48 | return 49 | 50 | def select_batch(self, **kwargs): 51 | return self.select_batch_(**kwargs) 52 | 53 | def to_dict(self): 54 | return None -------------------------------------------------------------------------------- /sampling_methods/simulate_batch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
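As a reference for the SamplingMethod base class shown above, here is a minimal sketch of a custom sampler (the class name and its ranking rule are made up for illustration): a subclass stores X, y, and seed and implements select_batch_, which callers reach through the base class's select_batch(**kwargs).

    # Toy sampler sketch: prefers points with the largest feature norm.
    import numpy as np
    from sampling_methods.sampling_def import SamplingMethod

    class NormSampler(SamplingMethod):
      def __init__(self, X, y, seed):
        self.X = X
        self.y = y
        self.seed = seed
        self.name = 'norm'

      def select_batch_(self, already_selected, N, **kwargs):
        norms = np.linalg.norm(self.flatten_X(), axis=1)   # flatten_X comes from the base class
        ranked = np.argsort(-norms)                        # largest norm first
        return [i for i in ranked if i not in already_selected][:N]

The registered samplers follow this same pattern, which is why wrappers such as MixtureOfSamplers can drive any of them through AL_MAPPING with a uniform select_batch call.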
14 | 15 | """ Select a new batch based on results of simulated trajectories.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import copy 22 | import math 23 | 24 | import numpy as np 25 | 26 | from sampling_methods.wrapper_sampler_def import AL_MAPPING 27 | from sampling_methods.wrapper_sampler_def import WrapperSamplingMethod 28 | 29 | 30 | class SimulateBatchSampler(WrapperSamplingMethod): 31 | """Creates batch based on trajectories simulated using smaller batch sizes. 32 | 33 | Current support use case: simulate smaller batches than the batch size 34 | actually indicated to emulate which points would be selected in a 35 | smaller batch setting. This method can do better than just selecting 36 | a batch straight out if smaller batches perform better and the simulations 37 | are informative enough and are not hurt too much by labeling noise. 38 | """ 39 | 40 | def __init__(self, 41 | X, 42 | y, 43 | seed, 44 | samplers=[{'methods': ('margin', 'uniform'),'weight': (1, 0)}], 45 | n_sims=10, 46 | train_per_sim=10, 47 | return_type='best_sim'): 48 | """ Initialize sampler with options. 49 | 50 | Args: 51 | X: training data 52 | y: labels may be used by base sampling methods 53 | seed: seed for np.random 54 | samplers: list of dicts with two fields 55 | 'samplers': list of named samplers 56 | 'weights': percentage of batch to allocate to each sampler 57 | n_sims: number of total trajectories to simulate 58 | train_per_sim: number of minibatches to split the batch into 59 | return_type: two return types supported right now 60 | best_sim: return points selected by the best trajectory 61 | frequency: returns points selected the most over all trajectories 62 | """ 63 | self.name = 'simulate_batch' 64 | self.X = X 65 | self.y = y 66 | self.seed = seed 67 | self.n_sims = n_sims 68 | self.train_per_sim = train_per_sim 69 | self.return_type = return_type 70 | self.samplers_list = samplers 71 | self.initialize_samplers(self.samplers_list) 72 | self.trace = [] 73 | self.selected = [] 74 | np.random.seed(seed) 75 | 76 | def simulate_batch(self, sampler, N, already_selected, y, model, X_test, 77 | y_test, **kwargs): 78 | """Simulates smaller batches by using hallucinated y to select next batch. 79 | 80 | Assumes that select_batch is only dependent on already_selected and not on 81 | any other states internal to the sampler. i.e. this would not work with 82 | BanditDiscreteSampler but will work with margin, hierarchical, and uniform. 83 | 84 | Args: 85 | sampler: dict with two fields 86 | 'samplers': list of named samplers 87 | 'weights': percentage of batch to allocate to each sampler 88 | N: batch size 89 | already_selected: indices already labeled 90 | y: y to use for training 91 | model: model to use for margin calc 92 | X_test: validaiton data 93 | y_test: validation labels 94 | 95 | Returns: 96 | - mean accuracy 97 | - indices selected by best hallucinated trajectory 98 | - best accuracy achieved by one of the trajectories 99 | """ 100 | minibatch = max(int(math.ceil(N / self.train_per_sim)), 1) 101 | results = [] 102 | best_acc = 0 103 | best_inds = [] 104 | self.selected = [] 105 | n_minibatch = int(N/minibatch) + (N % minibatch > 0) 106 | 107 | for _ in range(self.n_sims): 108 | inds = [] 109 | hallucinated_y = [] 110 | 111 | # Copy these objects to make sure they are not modified while simulating 112 | # trajectories as they are used later by the main run_experiment script. 
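# Each simulated trajectory below then repeats the following n_minibatch times:
# select a minibatch with the base sampler, "hallucinate" labels for it by
# sampling from the copied model's predict_proba output, and refit the copied
# model on the enlarged labeled set; the trajectory is finally scored on
# X_test/y_test and the best-scoring trajectory's indices are kept.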
113 | kwargs['already_selected'] = copy.copy(already_selected) 114 | kwargs['y'] = copy.copy(y) 115 | # Assumes that model has already by fit using all labeled data so 116 | # the probabilities can be used immediately to hallucinate labels 117 | kwargs['model'] = copy.deepcopy(model) 118 | 119 | for _ in range(n_minibatch): 120 | batch_size = min(minibatch, N-len(inds)) 121 | if batch_size > 0: 122 | kwargs['N'] = batch_size 123 | new_inds = sampler.select_batch(**kwargs) 124 | inds.extend(new_inds) 125 | 126 | # All models need to have predict_proba method 127 | probs = kwargs['model'].predict_proba(self.X[new_inds]) 128 | # Hallucinate labels for selected datapoints to be label 129 | # using class probabilities from model 130 | try: 131 | classes = kwargs['model'].best_estimator_.classes_ 132 | except: 133 | classes = kwargs['model'].classes_ 134 | new_y = ([ 135 | np.random.choice(classes, p=probs[i, :]) 136 | for i in range(batch_size) 137 | ]) 138 | hallucinated_y.extend(new_y) 139 | # Not saving already_selected here, if saving then should sort 140 | # only for the input to fit but preserve ordering of indices in 141 | # already_selected 142 | kwargs['already_selected'] = sorted(kwargs['already_selected'] 143 | + new_inds) 144 | kwargs['y'][new_inds] = new_y 145 | kwargs['model'].fit(self.X[kwargs['already_selected']], 146 | kwargs['y'][kwargs['already_selected']]) 147 | acc_hallucinated = kwargs['model'].score(X_test, y_test) 148 | if acc_hallucinated > best_acc: 149 | best_acc = acc_hallucinated 150 | best_inds = inds 151 | kwargs['model'].fit(self.X[kwargs['already_selected']], 152 | y[kwargs['already_selected']]) 153 | # Useful to know how accuracy compares for model trained on hallucinated 154 | # labels vs trained on true labels. But can remove this train to speed 155 | # up simulations. Won't speed up significantly since many more models 156 | # are being trained inside the loop above. 157 | acc_true = kwargs['model'].score(X_test, y_test) 158 | results.append([acc_hallucinated, acc_true]) 159 | print('Hallucinated acc: %.3f, Actual Acc: %.3f' % (acc_hallucinated, 160 | acc_true)) 161 | 162 | # Save trajectory for reference 163 | t = {} 164 | t['arm'] = sampler 165 | t['data_size'] = len(kwargs['already_selected']) 166 | t['inds'] = inds 167 | t['y_hal'] = hallucinated_y 168 | t['acc_hal'] = acc_hallucinated 169 | t['acc_true'] = acc_true 170 | self.trace.append(t) 171 | self.selected.extend(inds) 172 | # Delete created copies 173 | del kwargs['model'] 174 | del kwargs['already_selected'] 175 | results = np.array(results) 176 | return np.mean(results, axis=0), best_inds, best_acc 177 | 178 | def sampler_select_batch(self, sampler, N, already_selected, y, model, X_test, y_test, **kwargs): 179 | """Calculate the performance of the model if the batch had been selected using the base method without simulation. 
180 | 181 | Args: 182 | sampler: dict with two fields 183 | 'samplers': list of named samplers 184 | 'weights': percentage of batch to allocate to each sampler 185 | N: batch size 186 | already_selected: indices already selected 187 | y: labels to use for training 188 | model: model to use for training 189 | X_test, y_test: validation set 190 | 191 | Returns: 192 | - indices selected by base method 193 | - validation accuracy of model trained on new batch 194 | """ 195 | m = copy.deepcopy(model) 196 | kwargs['y'] = y 197 | kwargs['model'] = m 198 | kwargs['already_selected'] = copy.copy(already_selected) 199 | inds = [] 200 | kwargs['N'] = N 201 | inds.extend(sampler.select_batch(**kwargs)) 202 | kwargs['already_selected'] = sorted(kwargs['already_selected'] + inds) 203 | 204 | m.fit(self.X[kwargs['already_selected']], y[kwargs['already_selected']]) 205 | acc = m.score(X_test, y_test) 206 | del m 207 | del kwargs['already_selected'] 208 | return inds, acc 209 | 210 | def select_batch_(self, N, already_selected, y, model, 211 | X_test, y_test, **kwargs): 212 | """ Returns a batch of size N selected by using the best sampler in simulation 213 | 214 | Args: 215 | samplers: list of sampling methods represented by dict with two fields 216 | 'samplers': list of named samplers 217 | 'weights': percentage of batch to allocate to each sampler 218 | N: batch size 219 | already_selected: indices of datapoints already labeled 220 | y: actual labels, used to compare simulation with actual 221 | model: training model to use to evaluate different samplers. Model must 222 | have a predict_proba method with same signature as that in sklearn 223 | n_sims: the number of simulations to perform for each sampler 224 | minibatch: batch size to use for simulation 225 | """ 226 | 227 | results = [] 228 | 229 | # THE INPUTS CANNOT BE MODIFIED SO WE MAKE COPIES FOR THE CHECK LATER 230 | # Should check model but kernel_svm does not have coef_ so need better 231 | # handling here 232 | copy_selected = copy.copy(already_selected) 233 | copy_y = copy.copy(y) 234 | 235 | for s in self.samplers: 236 | sim_results, sim_inds, sim_acc = self.simulate_batch( 237 | s, N, already_selected, y, model, X_test, y_test, **kwargs) 238 | real_inds, acc = self.sampler_select_batch( 239 | s, N, already_selected, y, model, X_test, y_test, **kwargs) 240 | print('Best simulated acc: %.3f, Actual acc: %.3f' % (sim_acc, acc)) 241 | results.append([sim_results, sim_inds, real_inds, acc]) 242 | best_s = np.argmax([r[0][0] for r in results]) 243 | 244 | # Make sure that model object fed in did not change during simulations 245 | assert all(y == copy_y) 246 | assert all([copy_selected[i] == already_selected[i] 247 | for i in range(len(already_selected))]) 248 | 249 | # Return indices based on return type specified 250 | if self.return_type == 'best_sim': 251 | return results[best_s][1] 252 | elif self.return_type == 'frequency': 253 | unique, counts = np.unique(self.selected, return_counts=True) 254 | argcount = np.argsort(-counts) 255 | return list(unique[argcount[0:N]]) 256 | return results[best_s][2] 257 | 258 | def to_dict(self): 259 | output = {} 260 | output['simulated_trajectories'] = self.trace 261 | return output 262 | -------------------------------------------------------------------------------- /sampling_methods/uniform_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Uniform sampling method. 16 | 17 | Samples in batches. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import numpy as np 25 | 26 | from sampling_methods.sampling_def import SamplingMethod 27 | 28 | 29 | class UniformSampling(SamplingMethod): 30 | 31 | def __init__(self, X, y, seed): 32 | self.X = X 33 | self.y = y 34 | self.name = 'uniform' 35 | np.random.seed(seed) 36 | 37 | def select_batch_(self, already_selected, N, **kwargs): 38 | """Returns batch of randomly sampled datapoints. 39 | 40 | Assumes that data has already been shuffled. 41 | 42 | Args: 43 | already_selected: index of datapoints already selected 44 | N: batch size 45 | 46 | Returns: 47 | indices of points selected to label 48 | """ 49 | 50 | # This is uniform given the remaining pool but biased wrt the entire pool. 51 | sample = [i for i in range(self.X.shape[0]) if i not in already_selected] 52 | return sample[0:N] 53 | -------------------------------------------------------------------------------- /sampling_methods/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /sampling_methods/utils/tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Node and Tree class to support hierarchical clustering AL method. 16 | 17 | Assumed to be binary tree. 18 | 19 | Node class is used to represent each node in a hierarchical clustering. 20 | Each node has certain properties that are used in the AL method. 
21 | 22 | Tree class is used to traverse a hierarchical clustering. 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import copy 30 | 31 | 32 | class Node(object): 33 | """Node class for hierarchical clustering. 34 | 35 | Initialized with name and left right children. 36 | """ 37 | 38 | def __init__(self, name, left=None, right=None): 39 | self.name = name 40 | self.left = left 41 | self.right = right 42 | self.is_leaf = left is None and right is None 43 | self.parent = None 44 | # Fields for hierarchical clustering AL 45 | self.score = 1.0 46 | self.split = False 47 | self.best_label = None 48 | self.weight = None 49 | 50 | def set_parent(self, parent): 51 | self.parent = parent 52 | 53 | 54 | class Tree(object): 55 | """Tree object for traversing a binary tree. 56 | 57 | Most methods apply to trees in general with the exception of get_pruning 58 | which is specific to the hierarchical clustering AL method. 59 | """ 60 | 61 | def __init__(self, root, node_dict): 62 | """Initializes tree and creates all nodes in node_dict. 63 | 64 | Args: 65 | root: id of the root node 66 | node_dict: dictionary with node_id as keys and entries indicating 67 | left and right child of node respectively. 68 | """ 69 | self.node_dict = node_dict 70 | self.root = self.make_tree(root) 71 | self.nodes = {} 72 | self.leaves_mapping = {} 73 | self.fill_parents() 74 | self.n_leaves = None 75 | 76 | def print_tree(self, node, max_depth): 77 | """Helper function to print out tree for debugging.""" 78 | node_list = [node] 79 | output = "" 80 | level = 0 81 | while level < max_depth and len(node_list): 82 | children = set() 83 | for n in node_list: 84 | node = self.get_node(n) 85 | output += ("\t"*level+"node %d: score %.2f, weight %.2f" % 86 | (node.name, node.score, node.weight)+"\n") 87 | if node.left: 88 | children.add(node.left.name) 89 | if node.right: 90 | children.add(node.right.name) 91 | level += 1 92 | node_list = children 93 | return print(output) 94 | 95 | def make_tree(self, node_id): 96 | if node_id is not None: 97 | return Node(node_id, 98 | self.make_tree(self.node_dict[node_id][0]), 99 | self.make_tree(self.node_dict[node_id][1])) 100 | 101 | def fill_parents(self): 102 | # Setting parent and storing nodes in dict for fast access 103 | def rec(pointer, parent): 104 | if pointer is not None: 105 | self.nodes[pointer.name] = pointer 106 | pointer.set_parent(parent) 107 | rec(pointer.left, pointer) 108 | rec(pointer.right, pointer) 109 | rec(self.root, None) 110 | 111 | def get_node(self, node_id): 112 | return self.nodes[node_id] 113 | 114 | def get_ancestor(self, node): 115 | ancestors = [] 116 | if isinstance(node, int): 117 | node = self.get_node(node) 118 | while node.name != self.root.name: 119 | node = node.parent 120 | ancestors.append(node.name) 121 | return ancestors 122 | 123 | def fill_weights(self): 124 | for v in self.node_dict: 125 | node = self.get_node(v) 126 | node.weight = len(self.leaves_mapping[v]) / (1.0 * self.n_leaves) 127 | 128 | def create_child_leaves_mapping(self, leaves): 129 | """DP for creating child leaves mapping. 130 | 131 | Storing in dict to save recompute. 
132 | """ 133 | self.n_leaves = len(leaves) 134 | for v in leaves: 135 | self.leaves_mapping[v] = [v] 136 | node_list = set([self.get_node(v).parent for v in leaves]) 137 | while node_list: 138 | to_fill = copy.copy(node_list) 139 | for v in node_list: 140 | if (v.left.name in self.leaves_mapping 141 | and v.right.name in self.leaves_mapping): 142 | to_fill.remove(v) 143 | self.leaves_mapping[v.name] = (self.leaves_mapping[v.left.name] + 144 | self.leaves_mapping[v.right.name]) 145 | if v.parent is not None: 146 | to_fill.add(v.parent) 147 | node_list = to_fill 148 | self.fill_weights() 149 | 150 | def get_child_leaves(self, node): 151 | return self.leaves_mapping[node] 152 | 153 | def get_pruning(self, node): 154 | if node.split: 155 | return self.get_pruning(node.left) + self.get_pruning(node.right) 156 | else: 157 | return [node.name] 158 | 159 | -------------------------------------------------------------------------------- /sampling_methods/utils/tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sampling_methods.utils.tree.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import unittest 22 | from sampling_methods.utils import tree 23 | 24 | 25 | class TreeTest(unittest.TestCase): 26 | 27 | def setUp(self): 28 | node_dict = { 29 | 1: (2, 3), 30 | 2: (4, 5), 31 | 3: (6, 7), 32 | 4: [None, None], 33 | 5: [None, None], 34 | 6: [None, None], 35 | 7: [None, None] 36 | } 37 | self.tree = tree.Tree(1, node_dict) 38 | self.tree.create_child_leaves_mapping([4, 5, 6, 7]) 39 | node = self.tree.get_node(1) 40 | node.split = True 41 | node = self.tree.get_node(2) 42 | node.split = True 43 | 44 | def assertNode(self, node, name, left, right): 45 | self.assertEqual(node.name, name) 46 | self.assertEqual(node.left.name, left) 47 | self.assertEqual(node.right.name, right) 48 | 49 | def testTreeRootSetCorrectly(self): 50 | self.assertNode(self.tree.root, 1, 2, 3) 51 | 52 | def testGetNode(self): 53 | node = self.tree.get_node(1) 54 | assert isinstance(node, tree.Node) 55 | self.assertEqual(node.name, 1) 56 | 57 | def testFillParent(self): 58 | node = self.tree.get_node(3) 59 | self.assertEqual(node.parent.name, 1) 60 | 61 | def testGetAncestors(self): 62 | ancestors = self.tree.get_ancestor(5) 63 | self.assertTrue(all([a in ancestors for a in [1, 2]])) 64 | 65 | def testChildLeaves(self): 66 | leaves = self.tree.get_child_leaves(3) 67 | self.assertTrue(all([c in leaves for c in [6, 7]])) 68 | 69 | def testFillWeights(self): 70 | node = self.tree.get_node(3) 71 | self.assertEqual(node.weight, 0.5) 72 | 73 | def testGetPruning(self): 74 | node = self.tree.get_node(1) 75 | pruning = self.tree.get_pruning(node) 76 | self.assertTrue(all([n in pruning for n in [3, 4, 5]])) 77 | 78 | if __name__ == '__main__': 79 | unittest.main() 80 | 
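The unit test above doubles as a usage example for the Tree helper; a condensed sketch using the same toy node ids as tree_test.py (node ids are arbitrary integers, leaves correspond to datapoint indices):

    from sampling_methods.utils.tree import Tree

    # A balanced 7-node binary tree: 1 is the root, 4-7 are leaves.
    node_dict = {1: (2, 3), 2: (4, 5), 3: (6, 7),
                 4: [None, None], 5: [None, None], 6: [None, None], 7: [None, None]}
    t = Tree(1, node_dict)
    t.create_child_leaves_mapping([4, 5, 6, 7])   # also fills per-node weights
    t.get_node(1).split = True                    # splitting the root refines the pruning
    print(t.get_pruning(t.get_node(1)))           # [2, 3]
    print(t.get_child_leaves(3))                  # [6, 7]
    print(t.get_node(3).weight)                   # 0.5 (two of the four leaves)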
-------------------------------------------------------------------------------- /sampling_methods/wrapper_sampler_def.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Abstract class for wrapper sampling methods that call base sampling methods. 16 | 17 | Provides interface to sampling methods that allow same signature 18 | for select_batch. Each subclass implements select_batch_ with the desired 19 | signature for readability. 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import abc 27 | 28 | from sampling_methods.constants import AL_MAPPING 29 | from sampling_methods.constants import get_all_possible_arms 30 | from sampling_methods.sampling_def import SamplingMethod 31 | 32 | get_all_possible_arms() 33 | 34 | 35 | class WrapperSamplingMethod(SamplingMethod): 36 | __metaclass__ = abc.ABCMeta 37 | 38 | def initialize_samplers(self, mixtures): 39 | methods = [] 40 | for m in mixtures: 41 | methods += m['methods'] 42 | methods = set(methods) 43 | self.base_samplers = {} 44 | for s in methods: 45 | self.base_samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed) 46 | self.samplers = [] 47 | for m in mixtures: 48 | self.samplers.append( 49 | AL_MAPPING['mixture_of_samplers'](self.X, self.y, self.seed, m, 50 | self.base_samplers)) 51 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /utils/allconv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implements allconv model in keras using tensorflow backend.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import copy 21 | 22 | import keras 23 | import keras.backend as K 24 | from keras.layers import Activation 25 | from keras.layers import Conv2D 26 | from keras.layers import Dropout 27 | from keras.layers import GlobalAveragePooling2D 28 | from keras.models import Sequential 29 | 30 | import numpy as np 31 | import tensorflow as tf 32 | 33 | 34 | class AllConv(object): 35 | """allconv network that matches sklearn api.""" 36 | 37 | def __init__(self, 38 | random_state=1, 39 | epochs=50, 40 | batch_size=32, 41 | solver='rmsprop', 42 | learning_rate=0.001, 43 | lr_decay=0.): 44 | # params 45 | self.solver = solver 46 | self.epochs = epochs 47 | self.batch_size = batch_size 48 | self.learning_rate = learning_rate 49 | self.lr_decay = lr_decay 50 | # data 51 | self.encode_map = None 52 | self.decode_map = None 53 | self.model = None 54 | self.random_state = random_state 55 | self.n_classes = None 56 | 57 | def build_model(self, X): 58 | # assumes that data axis order is same as the backend 59 | input_shape = X.shape[1:] 60 | np.random.seed(self.random_state) 61 | tf.set_random_seed(self.random_state) 62 | 63 | model = Sequential() 64 | model.add(Conv2D(96, (3, 3), padding='same', 65 | input_shape=input_shape, name='conv1')) 66 | model.add(Activation('relu')) 67 | model.add(Conv2D(96, (3, 3), name='conv2', padding='same')) 68 | model.add(Activation('relu')) 69 | model.add(Conv2D(96, (3, 3), strides=(2, 2), padding='same', name='conv3')) 70 | model.add(Activation('relu')) 71 | model.add(Dropout(0.5)) 72 | 73 | model.add(Conv2D(192, (3, 3), name='conv4', padding='same')) 74 | model.add(Activation('relu')) 75 | model.add(Conv2D(192, (3, 3), name='conv5', padding='same')) 76 | model.add(Activation('relu')) 77 | model.add(Conv2D(192, (3, 3), strides=(2, 2), name='conv6', padding='same')) 78 | model.add(Activation('relu')) 79 | model.add(Dropout(0.5)) 80 | 81 | model.add(Conv2D(192, (3, 3), name='conv7', padding='same')) 82 | model.add(Activation('relu')) 83 | model.add(Conv2D(192, (1, 1), name='conv8', padding='valid')) 84 | model.add(Activation('relu')) 85 | model.add(Conv2D(10, (1, 1), name='conv9', padding='valid')) 86 | 87 | model.add(GlobalAveragePooling2D()) 88 | model.add(Activation('softmax', name='activation_top')) 89 | model.summary() 90 | 91 | try: 92 | optimizer = getattr(keras.optimizers, self.solver) 93 | except: 94 | raise NotImplementedError('optimizer not implemented in keras') 95 | # All optimizers with the exception of nadam take decay as named arg 96 | try: 97 | opt = optimizer(lr=self.learning_rate, decay=self.lr_decay) 98 | except: 99 | opt = optimizer(lr=self.learning_rate, schedule_decay=self.lr_decay) 100 | 101 | model.compile(loss='categorical_crossentropy', 102 | optimizer=opt, 103 | metrics=['accuracy']) 104 | # Save initial weights so that model can be retrained with same 105 | # initialization 106 | self.initial_weights = copy.deepcopy(model.get_weights()) 107 | 108 | self.model = model 109 | 110 | def create_y_mat(self, y): 111 | y_encode = self.encode_y(y) 112 | y_encode = np.reshape(y_encode, (len(y_encode), 1)) 113 | y_mat = keras.utils.to_categorical(y_encode, self.n_classes) 114 | return y_mat 115 | 116 | # Add handling for classes that do not 
start counting from 0 117 | def encode_y(self, y): 118 | if self.encode_map is None: 119 | self.classes_ = sorted(list(set(y))) 120 | self.n_classes = len(self.classes_) 121 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_)))) 122 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_)) 123 | mapper = lambda x: self.encode_map[x] 124 | transformed_y = np.array(map(mapper, y)) 125 | return transformed_y 126 | 127 | def decode_y(self, y): 128 | mapper = lambda x: self.decode_map[x] 129 | transformed_y = np.array(map(mapper, y)) 130 | return transformed_y 131 | 132 | def fit(self, X_train, y_train, sample_weight=None): 133 | y_mat = self.create_y_mat(y_train) 134 | 135 | if self.model is None: 136 | self.build_model(X_train) 137 | 138 | # We don't want incremental fit so reset learning rate and weights 139 | K.set_value(self.model.optimizer.lr, self.learning_rate) 140 | self.model.set_weights(self.initial_weights) 141 | self.model.fit( 142 | X_train, 143 | y_mat, 144 | batch_size=self.batch_size, 145 | epochs=self.epochs, 146 | shuffle=True, 147 | sample_weight=sample_weight, 148 | verbose=0) 149 | 150 | def predict(self, X_val): 151 | predicted = self.model.predict(X_val) 152 | return predicted 153 | 154 | def score(self, X_val, val_y): 155 | y_mat = self.create_y_mat(val_y) 156 | val_acc = self.model.evaluate(X_val, y_mat)[1] 157 | return val_acc 158 | 159 | def decision_function(self, X): 160 | return self.predict(X) 161 | 162 | def transform(self, X): 163 | model = self.model 164 | inp = [model.input] 165 | activations = [] 166 | 167 | # Get activations of the last conv layer. 168 | output = [layer.output for layer in model.layers if 169 | layer.name == 'conv9'][0] 170 | func = K.function(inp + [K.learning_phase()], [output]) 171 | for i in range(int(X.shape[0]/self.batch_size) + 1): 172 | minibatch = X[i * self.batch_size 173 | : min(X.shape[0], (i+1) * self.batch_size)] 174 | list_inputs = [minibatch, 0.] 175 | # Learning phase. 0 = Test mode (no dropout or batch normalization) 176 | layer_output = func(list_inputs)[0] 177 | activations.append(layer_output) 178 | output = np.vstack(tuple(activations)) 179 | output = np.reshape(output, (output.shape[0],np.product(output.shape[1:]))) 180 | return output 181 | 182 | def get_params(self, deep = False): 183 | params = {} 184 | params['solver'] = self.solver 185 | params['epochs'] = self.epochs 186 | params['batch_size'] = self.batch_size 187 | params['learning_rate'] = self.learning_rate 188 | params['weight_decay'] = self.lr_decay 189 | if deep: 190 | return copy.deepcopy(params) 191 | return copy.copy(params) 192 | 193 | def set_params(self, **parameters): 194 | for parameter, value in parameters.items(): 195 | setattr(self, parameter, value) 196 | return self 197 | -------------------------------------------------------------------------------- /utils/chart_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Experiment charting script. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import pickle 24 | 25 | import numpy as np 26 | import matplotlib.pyplot as plt 27 | from matplotlib.backends.backend_pdf import PdfPages 28 | 29 | from absl import app 30 | from absl import flags 31 | from tensorflow import gfile 32 | 33 | flags.DEFINE_string('source_dir', 34 | '/tmp/toy_experiments', 35 | 'Directory with the output to analyze.') 36 | flags.DEFINE_string('save_dir', '/tmp/active_learning', 37 | 'Directory to save charts.') 38 | flags.DEFINE_string('dataset', 'letter', 'Dataset to analyze.') 39 | flags.DEFINE_string( 40 | 'sampling_methods', 41 | ('uniform,margin,informative_diverse,' 42 | 'pred_expert_advice_trip_agg,' 43 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34'), 44 | 'Comma separated string of sampling methods to include in chart.') 45 | flags.DEFINE_string('scoring_methods', 'logistic,kernel_ls', 46 | 'Comma separated string of scoring methods to chart.') 47 | flags.DEFINE_bool('normalize', False, 'Chart runs using normalized data.') 48 | flags.DEFINE_bool('standardize', True, 'Chart runs using standardized data.') 49 | 50 | FLAGS = flags.FLAGS 51 | 52 | 53 | def combine_results(files, diff=False): 54 | all_results = {} 55 | for f in files: 56 | data = pickle.load(gfile.FastGFile(f, 'r')) 57 | for k in data: 58 | if isinstance(k, tuple): 59 | data[k].pop('noisy_targets') 60 | data[k].pop('indices') 61 | data[k].pop('selected_inds') 62 | data[k].pop('sampler_output') 63 | key = list(k) 64 | seed = key[-1] 65 | key = key[0:10] 66 | key = tuple(key) 67 | if key in all_results: 68 | if seed not in all_results[key]['random_seeds']: 69 | all_results[key]['random_seeds'].append(seed) 70 | for field in [f for f in data[k] if f != 'n_points']: 71 | all_results[key][field] = np.vstack( 72 | (all_results[key][field], data[k][field])) 73 | else: 74 | all_results[key] = data[k] 75 | all_results[key]['random_seeds'] = [seed] 76 | else: 77 | all_results[k] = data[k] 78 | return all_results 79 | 80 | 81 | def plot_results(all_results, score_method, norm, stand, sampler_filter): 82 | colors = { 83 | 'margin': 84 | 'gold', 85 | 'uniform': 86 | 'k', 87 | 'informative_diverse': 88 | 'r', 89 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34': 90 | 'b', 91 | 'pred_expert_advice_trip_agg': 92 | 'g' 93 | } 94 | labels = { 95 | 'margin': 96 | 'margin', 97 | 'uniform': 98 | 'uniform', 99 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34': 100 | 'margin:0.33,informative_diverse:0.33, uniform:0.34', 101 | 'informative_diverse': 102 | 'informative and diverse', 103 | 'pred_expert_advice_trip_agg': 104 | 'expert: margin,informative_diverse,uniform' 105 | } 106 | markers = { 107 | 'margin': 108 | 'None', 109 | 'uniform': 110 | 'None', 111 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34': 112 | '>', 113 | 'informative_diverse': 114 | 'None', 115 | 'pred_expert_advice_trip_agg': 116 | 'p' 117 | } 118 | fields = all_results['tuple_keys'] 119 | fields = dict(zip(fields, range(len(fields)))) 120 | 121 | for k in sorted(all_results.keys()): 122 | sampler = k[fields['sampler']] 123 | if (isinstance(k, tuple) and 124 | k[fields['score_method']] == score_method and 125 | k[fields['standardize']] 
== stand and 126 | k[fields['normalize']] == norm and 127 | (sampler_filter is None or sampler in sampler_filter)): 128 | results = all_results[k] 129 | n_trials = results['accuracy'].shape[0] 130 | x = results['data_sizes'][0] 131 | mean_acc = np.mean(results['accuracy'], axis=0) 132 | CI_acc = np.std(results['accuracy'], axis=0) / np.sqrt(n_trials) * 2.96 133 | if sampler == 'uniform': 134 | plt.plot( 135 | x, 136 | mean_acc, 137 | linewidth=1, 138 | label=labels[sampler], 139 | color=colors[sampler], 140 | linestyle='--' 141 | ) 142 | plt.fill_between( 143 | x, 144 | mean_acc - CI_acc, 145 | mean_acc + CI_acc, 146 | color=colors[sampler], 147 | alpha=0.2 148 | ) 149 | else: 150 | plt.plot( 151 | x, 152 | mean_acc, 153 | linewidth=1, 154 | label=labels[sampler], 155 | color=colors[sampler], 156 | marker=markers[sampler], 157 | markeredgecolor=colors[sampler] 158 | ) 159 | plt.fill_between( 160 | x, 161 | mean_acc - CI_acc, 162 | mean_acc + CI_acc, 163 | color=colors[sampler], 164 | alpha=0.2 165 | ) 166 | plt.legend(loc=4) 167 | 168 | 169 | def get_between(filename, start, end): 170 | start_ind = filename.find(start) + len(start) 171 | end_ind = filename.rfind(end) 172 | return filename[start_ind:end_ind] 173 | 174 | 175 | def get_sampling_method(dataset, filename): 176 | return get_between(filename, dataset + '_', '/') 177 | 178 | 179 | def get_scoring_method(filename): 180 | return get_between(filename, 'results_score_', '_select_') 181 | 182 | 183 | def get_normalize(filename): 184 | return get_between(filename, '_norm_', '_stand_') == 'True' 185 | 186 | 187 | def get_standardize(filename): 188 | return get_between( 189 | filename, '_stand_', filename[filename.rfind('_'):]) == 'True' 190 | 191 | 192 | def main(argv): 193 | del argv # Unused. 194 | if not gfile.Exists(FLAGS.save_dir): 195 | gfile.MkDir(FLAGS.save_dir) 196 | charting_filepath = os.path.join(FLAGS.save_dir, 197 | FLAGS.dataset + '_charts.pdf') 198 | sampling_methods = FLAGS.sampling_methods.split(',') 199 | scoring_methods = FLAGS.scoring_methods.split(',') 200 | files = gfile.Glob( 201 | os.path.join(FLAGS.source_dir, FLAGS.dataset + '*/results*.pkl')) 202 | files = [ 203 | f for f in files 204 | if (get_sampling_method(FLAGS.dataset, f) in sampling_methods and 205 | get_scoring_method(f) in scoring_methods and 206 | get_normalize(f) == FLAGS.normalize and 207 | get_standardize(f) == FLAGS.standardize) 208 | ] 209 | 210 | print('Reading in %d files...' % len(files)) 211 | all_results = combine_results(files) 212 | pdf = PdfPages(charting_filepath) 213 | 214 | print('Plotting charts...') 215 | plt.style.use('ggplot') 216 | for m in scoring_methods: 217 | plot_results( 218 | all_results, 219 | m, 220 | FLAGS.normalize, 221 | FLAGS.standardize, 222 | sampler_filter=sampling_methods) 223 | plt.title('Dataset: %s, Score Method: %s' % (FLAGS.dataset, m)) 224 | pdf.savefig() 225 | plt.close() 226 | pdf.close() 227 | 228 | 229 | if __name__ == '__main__': 230 | app.run(main) 231 | -------------------------------------------------------------------------------- /utils/create_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Make datasets and save specified directory. 16 | 17 | Downloads datasets using scikit datasets and can also parse csv file 18 | to save into pickle format. 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | from io import BytesIO 26 | import os 27 | import pickle 28 | import StringIO 29 | import tarfile 30 | import urllib2 31 | 32 | import keras.backend as K 33 | from keras.datasets import cifar10 34 | from keras.datasets import cifar100 35 | from keras.datasets import mnist 36 | 37 | import numpy as np 38 | import pandas as pd 39 | from sklearn.datasets import fetch_20newsgroups_vectorized 40 | from sklearn.datasets import fetch_mldata 41 | from sklearn.datasets import load_breast_cancer 42 | from sklearn.datasets import load_iris 43 | import sklearn.datasets.rcv1 44 | from sklearn.feature_extraction.text import CountVectorizer 45 | from sklearn.feature_extraction.text import TfidfTransformer 46 | 47 | from absl import app 48 | from absl import flags 49 | from tensorflow import gfile 50 | 51 | flags.DEFINE_string('save_dir', '/tmp/data', 52 | 'Where to save outputs') 53 | flags.DEFINE_string('datasets', '', 54 | 'Which datasets to download, comma separated.') 55 | FLAGS = flags.FLAGS 56 | 57 | 58 | class Dataset(object): 59 | 60 | def __init__(self, X, y): 61 | self.data = X 62 | self.target = y 63 | 64 | 65 | def get_csv_data(filename): 66 | """Parse csv and return Dataset object with data and targets. 67 | 68 | Create pickle data from csv, assumes the first column contains the targets 69 | Args: 70 | filename: complete path of the csv file 71 | Returns: 72 | Dataset object 73 | """ 74 | f = gfile.GFile(filename, 'r') 75 | mat = [] 76 | for l in f: 77 | row = l.strip() 78 | row = row.replace('"', '') 79 | row = row.split(',') 80 | row = [float(x) for x in row] 81 | mat.append(row) 82 | mat = np.array(mat) 83 | y = mat[:, 0] 84 | X = mat[:, 1:] 85 | data = Dataset(X, y) 86 | return data 87 | 88 | 89 | def get_wikipedia_talk_data(): 90 | """Get wikipedia talk dataset. 91 | 92 | See here for more information about the dataset: 93 | https://figshare.com/articles/Wikipedia_Detox_Data/4054689 94 | Downloads annotated comments and annotations. 
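    |   A comment is labeled an attack when the majority of its annotators
    |   flagged it as one; comments are featurized with unigram/bigram counts
    |   re-weighted by TF-IDF (see the preprocessing below).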
95 | """ 96 | 97 | ANNOTATED_COMMENTS_URL = 'https://ndownloader.figshare.com/files/7554634' 98 | ANNOTATIONS_URL = 'https://ndownloader.figshare.com/files/7554637' 99 | 100 | def download_file(url): 101 | req = urllib2.Request(url) 102 | response = urllib2.urlopen(req) 103 | return response 104 | 105 | # Process comments 106 | comments = pd.read_table( 107 | download_file(ANNOTATED_COMMENTS_URL), index_col=0, sep='\t') 108 | # remove newline and tab tokens 109 | comments['comment'] = comments['comment'].apply( 110 | lambda x: x.replace('NEWLINE_TOKEN', ' ')) 111 | comments['comment'] = comments['comment'].apply( 112 | lambda x: x.replace('TAB_TOKEN', ' ')) 113 | 114 | # Process labels 115 | annotations = pd.read_table(download_file(ANNOTATIONS_URL), sep='\t') 116 | # labels a comment as an atack if the majority of annoatators did so 117 | labels = annotations.groupby('rev_id')['attack'].mean() > 0.5 118 | 119 | # Perform data preprocessing, should probably tune these hyperparameters 120 | vect = CountVectorizer(max_features=30000, ngram_range=(1, 2)) 121 | tfidf = TfidfTransformer(norm='l2') 122 | X = tfidf.fit_transform(vect.fit_transform(comments['comment'])) 123 | y = np.array(labels) 124 | data = Dataset(X, y) 125 | return data 126 | 127 | 128 | def get_keras_data(dataname): 129 | """Get datasets using keras API and return as a Dataset object.""" 130 | if dataname == 'cifar10_keras': 131 | train, test = cifar10.load_data() 132 | elif dataname == 'cifar100_coarse_keras': 133 | train, test = cifar100.load_data('coarse') 134 | elif dataname == 'cifar100_keras': 135 | train, test = cifar100.load_data() 136 | elif dataname == 'mnist_keras': 137 | train, test = mnist.load_data() 138 | else: 139 | raise NotImplementedError('dataset not supported') 140 | 141 | X = np.concatenate((train[0], test[0])) 142 | y = np.concatenate((train[1], test[1])) 143 | 144 | if dataname == 'mnist_keras': 145 | # Add extra dimension for channel 146 | num_rows = X.shape[1] 147 | num_cols = X.shape[2] 148 | X = X.reshape(X.shape[0], 1, num_rows, num_cols) 149 | if K.image_data_format() == 'channels_last': 150 | X = X.transpose(0, 2, 3, 1) 151 | 152 | y = y.flatten() 153 | data = Dataset(X, y) 154 | return data 155 | 156 | 157 | # TODO(lishal): remove regular cifar10 dataset and only use dataset downloaded 158 | # from keras to maintain image dims to create tensor for tf models 159 | # Requires adding handling in run_experiment.py for handling of different 160 | # training methods that require either 2d or tensor data. 161 | def get_cifar10(): 162 | """Get CIFAR-10 dataset from source dir. 163 | 164 | Slightly redundant with keras function to get cifar10 but this returns 165 | in flat format instead of keras numpy image tensor. 166 | """ 167 | url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' 168 | def download_file(url): 169 | req = urllib2.Request(url) 170 | response = urllib2.urlopen(req) 171 | return response 172 | response = download_file(url) 173 | tmpfile = BytesIO() 174 | while True: 175 | # Download a piece of the file from the connection 176 | s = response.read(16384) 177 | # Once the entire file has been downloaded, tarfile returns b'' 178 | # (the empty bytes) which is a falsey value 179 | if not s: 180 | break 181 | # Otherwise, write the piece of the file to the temporary file. 
182 | tmpfile.write(s) 183 | response.close() 184 | 185 | tmpfile.seek(0) 186 | tar_dir = tarfile.open(mode='r:gz', fileobj=tmpfile) 187 | X = None 188 | y = None 189 | for member in tar_dir.getnames(): 190 | if '_batch' in member: 191 | filestream = tar_dir.extractfile(member).read() 192 | batch = pickle.load(StringIO.StringIO(filestream)) 193 | if X is None: 194 | X = np.array(batch['data'], dtype=np.uint8) 195 | y = np.array(batch['labels']) 196 | else: 197 | X = np.concatenate((X, np.array(batch['data'], dtype=np.uint8))) 198 | y = np.concatenate((y, np.array(batch['labels']))) 199 | data = Dataset(X, y) 200 | return data 201 | 202 | 203 | def get_mldata(dataset): 204 | # Use scikit to grab datasets and save them to save_dir. 205 | save_dir = FLAGS.save_dir 206 | filename = os.path.join(save_dir, dataset[1]+'.pkl') 207 | 208 | if not gfile.Exists(save_dir): 209 | gfile.MkDir(save_dir) 210 | if not gfile.Exists(filename): 211 | if dataset[0][-3:] == 'csv': 212 | data = get_csv_data(dataset[0]) 213 | elif dataset[0] == 'breast_cancer': 214 | data = load_breast_cancer() 215 | elif dataset[0] == 'iris': 216 | data = load_iris() 217 | elif dataset[0] == 'newsgroup': 218 | # Removing header information to make sure that no newsgroup identifying 219 | # information is included in data. 220 | data = fetch_20newsgroups_vectorized(subset='all', remove=('headers',)) 221 | tfidf = TfidfTransformer(norm='l2') 222 | X = tfidf.fit_transform(data.data) 223 | data.data = X 224 | elif dataset[0] == 'rcv1': 225 | sklearn.datasets.rcv1.URL = ( 226 | 'http://www.ai.mit.edu/projects/jmlr/papers/' 227 | 'volume5/lewis04a/a13-vector-files/lyrl2004_vectors') 228 | sklearn.datasets.rcv1.URL_topics = ( 229 | 'http://www.ai.mit.edu/projects/jmlr/papers/' 230 | 'volume5/lewis04a/a08-topic-qrels/rcv1-v2.topics.qrels.gz') 231 | data = sklearn.datasets.fetch_rcv1( 232 | data_home='/tmp') 233 | elif dataset[0] == 'wikipedia_attack': 234 | data = get_wikipedia_talk_data() 235 | elif dataset[0] == 'cifar10': 236 | data = get_cifar10() 237 | elif 'keras' in dataset[0]: 238 | data = get_keras_data(dataset[0]) 239 | else: 240 | try: 241 | data = fetch_mldata(dataset[0]) 242 | except: 243 | raise Exception('ERROR: failed to fetch data from mldata.org') 244 | X = data.data 245 | y = data.target 246 | if X.shape[0] != y.shape[0]: 247 | X = np.transpose(X) 248 | assert X.shape[0] == y.shape[0] 249 | 250 | data = {'data': X, 'target': y} 251 | pickle.dump(data, gfile.GFile(filename, 'w')) 252 | 253 | 254 | def main(argv): 255 | del argv # Unused. 256 | # First entry of tuple is mldata.org name, second is the name that we'll use 257 | # to reference the data.
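    |   # For example, a run with --datasets=mnist,letter would restrict the
    |   # loop below to just those two entries; an empty flag keeps them all.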
258 | datasets = [('mnist (original)', 'mnist'), ('australian', 'australian'), 259 | ('heart', 'heart'), ('breast_cancer', 'breast_cancer'), 260 | ('iris', 'iris'), ('vehicle', 'vehicle'), ('wine', 'wine'), 261 | ('waveform ida', 'waveform'), ('german ida', 'german'), 262 | ('splice ida', 'splice'), ('ringnorm ida', 'ringnorm'), 263 | ('twonorm ida', 'twonorm'), ('diabetes_scale', 'diabetes'), 264 | ('mushrooms', 'mushrooms'), ('letter', 'letter'), ('dna', 'dna'), 265 | ('banana-ida', 'banana'), 266 | ('newsgroup', 'newsgroup'), ('cifar10', 'cifar10'), 267 | ('cifar10_keras', 'cifar10_keras'), 268 | ('cifar100_keras', 'cifar100_keras'), 269 | ('cifar100_coarse_keras', 'cifar100_coarse_keras'), 270 | ('mnist_keras', 'mnist_keras'), 271 | ('wikipedia_attack', 'wikipedia_attack'), 272 | ('rcv1', 'rcv1')] 273 | 274 | if FLAGS.datasets: 275 | subset = FLAGS.datasets.split(',') 276 | datasets = [d for d in datasets if d[1] in subset] 277 | 278 | for d in datasets: 279 | print(d[1]) 280 | get_mldata(d) 281 | 282 | 283 | if __name__ == '__main__': 284 | app.run(main) 285 | -------------------------------------------------------------------------------- /utils/kernel_block_solver.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Block kernel lsqr solver for multi-class classification.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import copy 21 | import math 22 | 23 | import numpy as np 24 | import scipy.linalg as linalg 25 | from scipy.sparse.linalg import spsolve 26 | from sklearn import metrics 27 | 28 | 29 | class BlockKernelSolver(object): 30 | """Inspired by algorithm from https://arxiv.org/pdf/1602.05310.pdf.""" 31 | # TODO(lishal): save pre-formed kernel matrix and reuse if possible 32 | # perhaps not possible if want to keep scikit-learn signature 33 | 34 | def __init__(self, 35 | random_state=1, 36 | C=0.1, 37 | block_size=4000, 38 | epochs=3, 39 | verbose=False, 40 | gamma=None): 41 | self.block_size = block_size 42 | self.epochs = epochs 43 | self.C = C 44 | self.kernel = 'rbf' 45 | self.coef_ = None 46 | self.verbose = verbose 47 | self.encode_map = None 48 | self.decode_map = None 49 | self.gamma = gamma 50 | self.X_train = None 51 | self.random_state = random_state 52 | 53 | def encode_y(self, y): 54 | # Handles classes that do not start counting from 0.
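    |     # For example, y = [2, 7, 7, 2] gives classes_ = [2, 7],
    |     # encode_map = {2: 0, 7: 1}, and an encoded vector [0, 1, 1, 0].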
55 | if self.encode_map is None: 56 | self.classes_ = sorted(list(set(y))) 57 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_)))) 58 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_)) 59 | mapper = lambda x: self.encode_map[x] 60 | transformed_y = np.array(map(mapper, y)) 61 | return transformed_y 62 | 63 | def decode_y(self, y): 64 | mapper = lambda x: self.decode_map[x] 65 | transformed_y = np.array(map(mapper, y)) 66 | return transformed_y 67 | 68 | def fit(self, X_train, y_train, sample_weight=None): 69 | """Form K and solve (K + lambda * I)x = y in a block-wise fashion.""" 70 | np.random.seed(self.random_state) 71 | self.X_train = X_train 72 | n_features = X_train.shape[1] 73 | y = self.encode_y(y_train) 74 | if self.gamma is None: 75 | self.gamma = 1./n_features 76 | K = metrics.pairwise.pairwise_kernels( 77 | X_train, metric=self.kernel, gamma=self.gamma) 78 | if self.verbose: 79 | print('Finished forming kernel matrix.') 80 | 81 | # compute some constants 82 | num_classes = len(list(set(y))) 83 | num_samples = K.shape[0] 84 | num_blocks = math.ceil(num_samples*1.0/self.block_size) 85 | x = np.zeros((K.shape[0], num_classes)) 86 | y_hat = np.zeros((K.shape[0], num_classes)) 87 | onehot = lambda x: np.eye(num_classes)[x] 88 | y_onehot = np.array(map(onehot, y)) 89 | idxes = np.diag_indices(num_samples) 90 | if sample_weight is not None: 91 | weights = np.sqrt(sample_weight) 92 | weights = weights[:, np.newaxis] 93 | y_onehot = weights * y_onehot 94 | K *= np.outer(weights, weights) 95 | if num_blocks == 1: 96 | epochs = 1 97 | else: 98 | epochs = self.epochs 99 | 100 | for e in range(epochs): 101 | shuffled_coords = np.random.choice( 102 | num_samples, num_samples, replace=False) 103 | for b in range(int(num_blocks)): 104 | residuals = y_onehot - y_hat 105 | 106 | # Form a block of K. 107 | K[idxes] += (self.C * num_samples) 108 | block = shuffled_coords[b*self.block_size: 109 | min((b+1)*self.block_size, num_samples)] 110 | K_block = K[:, block] 111 | # Dim should be block size x block size 112 | KbTKb = K_block.T.dot(K_block) 113 | 114 | if self.verbose: 115 | print('solving block {0}'.format(b)) 116 | # Try linalg solve then sparse solve for handling of sparse input. 
117 | try: 118 | x_block = linalg.solve(KbTKb, K_block.T.dot(residuals)) 119 | except: 120 | try: 121 | x_block = spsolve(KbTKb, K_block.T.dot(residuals)) 122 | except: 123 | return None 124 | 125 | # update model 126 | x[block] = x[block] + x_block 127 | K[idxes] = K[idxes] - (self.C * num_samples) 128 | y_hat = K.dot(x) 129 | 130 | y_pred = np.argmax(y_hat, axis=1) 131 | train_acc = metrics.accuracy_score(y, y_pred) 132 | if self.verbose: 133 | print('Epoch: {0}, Block: {1}, Train Accuracy: {2}' 134 | .format(e, b, train_acc)) 135 | self.coef_ = x 136 | 137 | def predict(self, X_val): 138 | val_K = metrics.pairwise.pairwise_kernels( 139 | X_val, self.X_train, metric=self.kernel, gamma=self.gamma) 140 | val_pred = np.argmax(val_K.dot(self.coef_), axis=1) 141 | return self.decode_y(val_pred) 142 | 143 | def score(self, X_val, val_y): 144 | val_pred = self.predict(X_val) 145 | val_acc = metrics.accuracy_score(val_y, val_pred) 146 | return val_acc 147 | 148 | def decision_function(self, X, type='predicted'): 149 | # Return the predicted value of the best class 150 | # Margin_AL will see that a vector is returned and not a matrix and 151 | # simply select the points that have the lowest predicted value to label 152 | K = metrics.pairwise.pairwise_kernels( 153 | X, self.X_train, metric=self.kernel, gamma=self.gamma) 154 | predicted = K.dot(self.coef_) 155 | if type == 'scores': 156 | val_best = np.max(K.dot(self.coef_), axis=1) 157 | return val_best 158 | elif type == 'predicted': 159 | return predicted 160 | else: 161 | raise NotImplementedError('Invalid return type for decision function.') 162 | 163 | def get_params(self, deep=False): 164 | params = {} 165 | params['C'] = self.C 166 | params['gamma'] = self.gamma 167 | if deep: 168 | return copy.deepcopy(params) 169 | return copy.copy(params) 170 | 171 | def set_params(self, **parameters): 172 | for parameter, value in parameters.items(): 173 | setattr(self, parameter, value) 174 | return self 175 | 176 | def softmax_over_predicted(self, X): 177 | val_K = metrics.pairwise.pairwise_kernels( 178 | X, self.X_train, metric=self.kernel, gamma=self.gamma) 179 | val_pred = val_K.dot(self.coef_) 180 | row_min = np.min(val_pred, axis=1) 181 | val_pred = val_pred - row_min[:, None] 182 | val_pred = np.exp(val_pred) 183 | sum_exp = np.sum(val_pred, axis=1) 184 | val_pred = val_pred/sum_exp[:, None] 185 | return val_pred 186 | -------------------------------------------------------------------------------- /utils/small_cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Implements Small CNN model in keras using tensorflow backend.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import copy 21 | 22 | import keras 23 | import keras.backend as K 24 | from keras.layers import Activation 25 | from keras.layers import Conv2D 26 | from keras.layers import Dense 27 | from keras.layers import Dropout 28 | from keras.layers import Flatten 29 | from keras.layers import MaxPooling2D 30 | from keras.models import Sequential 31 | 32 | import numpy as np 33 | import tensorflow as tf 34 | 35 | 36 | class SmallCNN(object): 37 | """Small convnet that matches sklearn api. 38 | 39 | Implements model from 40 | https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py 41 | Adapts for inputs of variable size, expects data to be 4d tensor, with 42 | # of obserations as first dimension and other dimensions to correspond to 43 | length width and # of channels in image. 44 | """ 45 | 46 | def __init__(self, 47 | random_state=1, 48 | epochs=50, 49 | batch_size=32, 50 | solver='rmsprop', 51 | learning_rate=0.001, 52 | lr_decay=0.): 53 | # params 54 | self.solver = solver 55 | self.epochs = epochs 56 | self.batch_size = batch_size 57 | self.learning_rate = learning_rate 58 | self.lr_decay = lr_decay 59 | # data 60 | self.encode_map = None 61 | self.decode_map = None 62 | self.model = None 63 | self.random_state = random_state 64 | self.n_classes = None 65 | 66 | def build_model(self, X): 67 | # assumes that data axis order is same as the backend 68 | input_shape = X.shape[1:] 69 | np.random.seed(self.random_state) 70 | tf.set_random_seed(self.random_state) 71 | 72 | model = Sequential() 73 | model.add(Conv2D(32, (3, 3), padding='same', 74 | input_shape=input_shape, name='conv1')) 75 | model.add(Activation('relu')) 76 | model.add(Conv2D(32, (3, 3), name='conv2')) 77 | model.add(Activation('relu')) 78 | model.add(MaxPooling2D(pool_size=(2, 2))) 79 | model.add(Dropout(0.25)) 80 | 81 | model.add(Conv2D(64, (3, 3), padding='same', name='conv3')) 82 | model.add(Activation('relu')) 83 | model.add(Conv2D(64, (3, 3), name='conv4')) 84 | model.add(Activation('relu')) 85 | model.add(MaxPooling2D(pool_size=(2, 2))) 86 | model.add(Dropout(0.25)) 87 | 88 | model.add(Flatten()) 89 | model.add(Dense(512, name='dense1')) 90 | model.add(Activation('relu')) 91 | model.add(Dropout(0.5)) 92 | model.add(Dense(self.n_classes, name='dense2')) 93 | model.add(Activation('softmax')) 94 | 95 | try: 96 | optimizer = getattr(keras.optimizers, self.solver) 97 | except: 98 | raise NotImplementedError('optimizer not implemented in keras') 99 | # All optimizers with the exception of nadam take decay as named arg 100 | try: 101 | opt = optimizer(lr=self.learning_rate, decay=self.lr_decay) 102 | except: 103 | opt = optimizer(lr=self.learning_rate, schedule_decay=self.lr_decay) 104 | 105 | model.compile(loss='categorical_crossentropy', 106 | optimizer=opt, 107 | metrics=['accuracy']) 108 | # Save initial weights so that model can be retrained with same 109 | # initialization 110 | self.initial_weights = copy.deepcopy(model.get_weights()) 111 | 112 | self.model = model 113 | 114 | def create_y_mat(self, y): 115 | y_encode = self.encode_y(y) 116 | y_encode = np.reshape(y_encode, (len(y_encode), 1)) 117 | y_mat = keras.utils.to_categorical(y_encode, self.n_classes) 118 | return y_mat 119 | 120 | # Add handling for classes that do not start counting from 0 121 | def encode_y(self, y): 122 | if self.encode_map is 
None: 123 | self.classes_ = sorted(list(set(y))) 124 | self.n_classes = len(self.classes_) 125 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_)))) 126 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_)) 127 | mapper = lambda x: self.encode_map[x] 128 | transformed_y = np.array(map(mapper, y)) 129 | return transformed_y 130 | 131 | def decode_y(self, y): 132 | mapper = lambda x: self.decode_map[x] 133 | transformed_y = np.array(map(mapper, y)) 134 | return transformed_y 135 | 136 | def fit(self, X_train, y_train, sample_weight=None): 137 | y_mat = self.create_y_mat(y_train) 138 | 139 | if self.model is None: 140 | self.build_model(X_train) 141 | 142 | # We don't want incremental fit so reset learning rate and weights 143 | K.set_value(self.model.optimizer.lr, self.learning_rate) 144 | self.model.set_weights(self.initial_weights) 145 | self.model.fit( 146 | X_train, 147 | y_mat, 148 | batch_size=self.batch_size, 149 | epochs=self.epochs, 150 | shuffle=True, 151 | sample_weight=sample_weight, 152 | verbose=0) 153 | 154 | def predict(self, X_val): 155 | predicted = self.model.predict(X_val) 156 | return predicted 157 | 158 | def score(self, X_val, val_y): 159 | y_mat = self.create_y_mat(val_y) 160 | val_acc = self.model.evaluate(X_val, y_mat)[1] 161 | return val_acc 162 | 163 | def decision_function(self, X): 164 | return self.predict(X) 165 | 166 | def transform(self, X): 167 | model = self.model 168 | inp = [model.input] 169 | activations = [] 170 | 171 | # Get activations of the first dense layer. 172 | output = [layer.output for layer in model.layers if 173 | layer.name == 'dense1'][0] 174 | func = K.function(inp + [K.learning_phase()], [output]) 175 | for i in range(int(X.shape[0]/self.batch_size) + 1): 176 | minibatch = X[i * self.batch_size 177 | : min(X.shape[0], (i+1) * self.batch_size)] 178 | list_inputs = [minibatch, 0.] 179 | # Learning phase. 0 = Test mode (no dropout or batch normalization) 180 | layer_output = func(list_inputs)[0] 181 | activations.append(layer_output) 182 | output = np.vstack(tuple(activations)) 183 | return output 184 | 185 | def get_params(self, deep = False): 186 | params = {} 187 | params['solver'] = self.solver 188 | params['epochs'] = self.epochs 189 | params['batch_size'] = self.batch_size 190 | params['learning_rate'] = self.learning_rate 191 | params['weight_decay'] = self.lr_decay 192 | if deep: 193 | return copy.deepcopy(params) 194 | return copy.copy(params) 195 | 196 | def set_params(self, **parameters): 197 | for parameter, value in parameters.items(): 198 | setattr(self, parameter, value) 199 | return self 200 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Utility functions for run_experiment.py.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import copy 21 | import os 22 | import pickle 23 | import sys 24 | 25 | import numpy as np 26 | import scipy 27 | 28 | from sklearn.linear_model import LogisticRegression 29 | from sklearn.model_selection import GridSearchCV 30 | from sklearn.svm import LinearSVC 31 | from sklearn.svm import SVC 32 | 33 | from tensorflow import gfile 34 | 35 | 36 | from utils.kernel_block_solver import BlockKernelSolver 37 | from utils.small_cnn import SmallCNN 38 | from utils.allconv import AllConv 39 | 40 | 41 | class Logger(object): 42 | """Logging object to write to file and stdout.""" 43 | 44 | def __init__(self, filename): 45 | self.terminal = sys.stdout 46 | self.log = gfile.GFile(filename, "w") 47 | 48 | def write(self, message): 49 | self.terminal.write(message) 50 | self.log.write(message) 51 | 52 | def flush(self): 53 | self.terminal.flush() 54 | 55 | def flush_file(self): 56 | self.log.flush() 57 | 58 | 59 | def create_checker_unbalanced(split, n, grid_size): 60 | """Creates a dataset with two classes that occupy one color of checkboard. 61 | 62 | Args: 63 | split: splits to use for class imbalance. 64 | n: number of datapoints to sample. 65 | grid_size: checkerboard size. 66 | Returns: 67 | X: 2d features. 68 | y: binary class. 69 | """ 70 | y = np.zeros(0) 71 | X = np.zeros((0, 2)) 72 | for i in range(grid_size): 73 | for j in range(grid_size): 74 | label = 0 75 | n_0 = int(n/(grid_size*grid_size) * split[0] * 2) 76 | if (i-j) % 2 == 0: 77 | label = 1 78 | n_0 = int(n/(grid_size*grid_size) * split[1] * 2) 79 | x_1 = np.random.uniform(i, i+1, n_0) 80 | x_2 = np.random.uniform(j, j+1, n_0) 81 | x = np.vstack((x_1, x_2)) 82 | x = x.T 83 | X = np.concatenate((X, x)) 84 | y_0 = label * np.ones(n_0) 85 | y = np.concatenate((y, y_0)) 86 | return X, y 87 | 88 | 89 | def flatten_X(X): 90 | shape = X.shape 91 | flat_X = X 92 | if len(shape) > 2: 93 | flat_X = np.reshape(X, (shape[0], np.product(shape[1:]))) 94 | return flat_X 95 | 96 | 97 | def get_mldata(data_dir, name): 98 | """Loads data from data_dir. 99 | 100 | Looks for the file in data_dir. 101 | Assumes that data is in pickle format with dictionary fields data and target. 102 | 103 | 104 | Args: 105 | data_dir: directory to look in 106 | name: dataset name, assumes data is saved in the save_dir with filename 107 | .pkl 108 | Returns: 109 | data and targets 110 | Raises: 111 | NameError: dataset not found in data folder. 112 | """ 113 | dataname = name 114 | if dataname == "checkerboard": 115 | X, y = create_checker_unbalanced(split=[1./5, 4./5], n=10000, grid_size=4) 116 | else: 117 | filename = os.path.join(data_dir, dataname + ".pkl") 118 | if not gfile.Exists(filename): 119 | raise NameError("ERROR: dataset not available") 120 | data = pickle.load(gfile.GFile(filename, "r")) 121 | X = data["data"] 122 | y = data["target"] 123 | if "keras" in dataname: 124 | X = X / 255 125 | y = y.flatten() 126 | return X, y 127 | 128 | 129 | def filter_data(X, y, keep=None): 130 | """Filters data by class indicated in keep. 
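    | 
    |   For example, keep=[0, 1] retains only the points whose label is 0 or 1.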
131 | 132 | Args: 133 | X: train data 134 | y: train targets 135 | keep: defaults to None which will keep everything, otherwise takes a list 136 | of classes to keep 137 | 138 | Returns: 139 | filtered data and targets 140 | """ 141 | if keep is None: 142 | return X, y 143 | keep_ind = [i for i in range(len(y)) if y[i] in keep] 144 | return X[keep_ind], y[keep_ind] 145 | 146 | 147 | def get_class_counts(y_full, y): 148 | """Gets the count of all classes in a sample. 149 | 150 | Args: 151 | y_full: full target vector containing all classes 152 | y: sample vector for which to perform the count 153 | Returns: 154 | count of classes for the sample vector y, the class order for count will 155 | be the same as long as same y_full is fed in 156 | """ 157 | classes = np.unique(y_full) 158 | classes = np.sort(classes) 159 | unique, counts = np.unique(y, return_counts=True) 160 | complete_counts = [] 161 | for c in classes: 162 | if c not in unique: 163 | complete_counts.append(0) 164 | else: 165 | index = np.where(unique == c)[0][0] 166 | complete_counts.append(counts[index]) 167 | return np.array(complete_counts) 168 | 169 | 170 | def flip_label(y, percent_random): 171 | """Flips a percentage of labels for one class to the other. 172 | 173 | Randomly sample a percent of points and randomly label the sampled points as 174 | one of the other classes. 175 | Does not introduce bias. 176 | 177 | Args: 178 | y: labels of all datapoints 179 | percent_random: percent of datapoints to corrupt the labels 180 | 181 | Returns: 182 | new labels with noisy labels for indicated percent of data 183 | """ 184 | classes = np.unique(y) 185 | y_orig = copy.copy(y) 186 | indices = range(y_orig.shape[0]) 187 | np.random.shuffle(indices) 188 | sample = indices[0:int(len(indices) * 1.0 * percent_random)] 189 | fake_labels = [] 190 | for s in sample: 191 | label = y[s] 192 | class_ind = np.where(classes == label)[0][0] 193 | other_classes = np.delete(classes, class_ind) 194 | np.random.shuffle(other_classes) 195 | fake_label = other_classes[0] 196 | assert fake_label != label 197 | fake_labels.append(fake_label) 198 | y[sample] = np.array(fake_labels) 199 | assert all(y[indices[len(sample):]] == y_orig[indices[len(sample):]]) 200 | return y 201 | 202 | 203 | def get_model(method, seed=13): 204 | """Construct sklearn model using either logistic regression or linear svm. 205 | 206 | Wraps grid search on regularization parameter over either logistic regression 207 | or svm, returns constructed model 208 | 209 | Args: 210 | method: string indicating scikit method to use, currently accepts logistic 211 | and linear svm. 212 | seed: int or rng to use for random state fed to scikit method 213 | 214 | Returns: 215 | scikit learn model 216 | """ 217 | # TODO(lishal): extend to include any scikit model that implements 218 | # a decision function. 219 | # TODO(lishal): for kernel methods, currently using default value for gamma 220 | # but should probably tune. 
221 | if method == "logistic": 222 | model = LogisticRegression(random_state=seed, multi_class="multinomial", 223 | solver="lbfgs", max_iter=200) 224 | params = {"C": [10.0**(i) for i in range(-4, 5)]} 225 | elif method == "logistic_ovr": 226 | model = LogisticRegression(random_state=seed) 227 | params = {"C": [10.0**(i) for i in range(-5, 4)]} 228 | elif method == "linear_svm": 229 | model = LinearSVC(random_state=seed) 230 | params = {"C": [10.0**(i) for i in range(-4, 5)]} 231 | elif method == "kernel_svm": 232 | model = SVC(random_state=seed) 233 | params = {"C": [10.0**(i) for i in range(-4, 5)]} 234 | elif method == "kernel_ls": 235 | model = BlockKernelSolver(random_state=seed) 236 | params = {"C": [10.0**(i) for i in range(-6, 1)]} 237 | elif method == "small_cnn": 238 | # Model does not work with weighted_expert or simulate_batch 239 | model = SmallCNN(random_state=seed) 240 | return model 241 | elif method == "allconv": 242 | # Model does not work with weighted_expert or simulate_batch 243 | model = AllConv(random_state=seed) 244 | return model 245 | 246 | else: 247 | raise NotImplementedError("ERROR: " + method + " not implemented") 248 | 249 | model = GridSearchCV(model, params, cv=3) 250 | return model 251 | 252 | 253 | def calculate_entropy(batch_size, y_s): 254 | """Calculates KL div between training targets and targets selected by AL. 255 | 256 | Args: 257 | batch_size: batch size of datapoints selected by AL 258 | y_s: vector of datapoints selected by AL. Assumes that the order of the 259 | data is the order in which points were labeled by AL. Also assumes 260 | that in the offline setting y_s will eventually overlap completely with 261 | original training targets. 262 | Returns: 263 | entropy between actual distribution of classes and distribution of 264 | samples selected by AL 265 | """ 266 | n_batches = int(np.ceil(len(y_s) * 1.0 / batch_size)) 267 | counts = get_class_counts(y_s, y_s) 268 | true_dist = counts / (len(y_s) * 1.0) 269 | entropy = [] 270 | for b in range(n_batches): 271 | sample = y_s[b * batch_size:(b + 1) * batch_size] 272 | counts = get_class_counts(y_s, sample) 273 | sample_dist = counts / (1.0 * len(sample)) 274 | entropy.append(scipy.stats.entropy(true_dist, sample_dist)) 275 | return entropy 276 | 277 | 278 | def get_train_val_test_splits(X, y, max_points, seed, confusion, seed_batch, 279 | split=(2./3, 1./6, 1./6)): 280 | """Return training, validation, and test splits for X and y. 281 | 282 | Args: 283 | X: features 284 | y: targets 285 | max_points: # of points to use when creating splits. 286 | seed: seed for shuffling. 287 | confusion: labeling noise to introduce. 0.1 means randomize 10% of labels. 288 | seed_batch: # of initial datapoints to ensure sufficient class membership. 289 | split: percent splits for train, val, and test. 290 | Returns: 291 | indices: shuffled indices to recreate splits given original input data X. 292 | y_noise: y with noise injected, needed to reproduce results outside of 293 | run_experiments using original data. 
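    |     X_train, y_train, X_val, y_val, X_test, y_test: the train, validation,
    |       and test splits themselves, returned between indices and y_noise.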
294 | """ 295 | np.random.seed(seed) 296 | X_copy = copy.copy(X) 297 | y_copy = copy.copy(y) 298 | 299 | # Introduce labeling noise 300 | y_noise = flip_label(y_copy, confusion) 301 | 302 | indices = np.arange(len(y)) 303 | 304 | if max_points is None: 305 | max_points = len(y_noise) 306 | else: 307 | max_points = min(len(y_noise), max_points) 308 | train_split = int(max_points * split[0]) 309 | val_split = train_split + int(max_points * split[1]) 310 | assert seed_batch <= train_split 311 | 312 | # Do this to make sure that the initial batch has examples from all classes 313 | min_shuffle = 3 314 | n_shuffle = 0 315 | y_tmp = y_noise 316 | 317 | # Need at least 4 obs of each class for 2 fold CV to work in grid search step 318 | while (any(get_class_counts(y_tmp, y_tmp[0:seed_batch]) < 4) 319 | or n_shuffle < min_shuffle): 320 | np.random.shuffle(indices) 321 | y_tmp = y_noise[indices] 322 | n_shuffle += 1 323 | 324 | X_train = X_copy[indices[0:train_split]] 325 | X_val = X_copy[indices[train_split:val_split]] 326 | X_test = X_copy[indices[val_split:max_points]] 327 | y_train = y_noise[indices[0:train_split]] 328 | y_val = y_noise[indices[train_split:val_split]] 329 | y_test = y_noise[indices[val_split:max_points]] 330 | # Make sure that we have enough observations of each class for 2-fold cv 331 | assert all(get_class_counts(y_noise, y_train[0:seed_batch]) >= 4) 332 | # Make sure that returned shuffled indices are correct 333 | assert all(y_noise[indices[0:max_points]] == 334 | np.concatenate((y_train, y_val, y_test), axis=0)) 335 | return (indices[0:max_points], X_train, y_train, 336 | X_val, y_val, X_test, y_test, y_noise) 337 | --------------------------------------------------------------------------------