├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── __init__.py
├── requirements.txt
├── run_experiment.py
├── sampling_methods
│   ├── __init__.py
│   ├── bandit_discrete.py
│   ├── constants.py
│   ├── graph_density.py
│   ├── hierarchical_clustering_AL.py
│   ├── informative_diverse.py
│   ├── kcenter_greedy.py
│   ├── margin_AL.py
│   ├── mixture_of_samplers.py
│   ├── represent_cluster_centers.py
│   ├── sampling_def.py
│   ├── simulate_batch.py
│   ├── uniform_sampling.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── tree.py
│   │   └── tree_test.py
│   └── wrapper_sampler_def.py
└── utils
    ├── __init__.py
    ├── allconv.py
    ├── chart_data.py
    ├── create_data.py
    ├── kernel_block_solver.py
    ├── small_cnn.py
    └── utils.py
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Active Learning Playground
2 |
3 | ## Introduction
4 |
5 | This is a python module for experimenting with different active learning
6 | algorithms. There are a few key components to running active learning
7 | experiments:
8 |
9 | * The main experiment script is
10 |   [`run_experiment.py`](run_experiment.py),
11 |   which has many flags for different run options.
12 |
13 | * Supported datasets can be downloaded to a specified directory by running
14 | [`utils/create_data.py`](utils/create_data.py).
15 |
16 | * Supported active learning methods are in
17 | [`sampling_methods`](sampling_methods/).
18 |
19 | Below I will go into each component in more detail.
20 |
21 | DISCLAIMER: This is not an official Google product.
22 |
23 | ## Setup
24 | The dependencies are in [`requirements.txt`](requirements.txt). Please make sure these packages are
25 | installed before running experiments. If GPU capable `tensorflow` is desired, please follow
26 | instructions [here](https://www.tensorflow.org/install/).
27 |
28 | It is highly suggested that you install all dependencies into a separate `virtualenv` for
29 | easy package management.
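
For example (one typical way to set this up; any virtualenv workflow works):

```
virtualenv env
source env/bin/activate
pip install -r requirements.txt
```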
30 |
31 | ## Getting benchmark datasets
32 |
33 | By default the datasets are saved to `/tmp/data`. You can specify another directory via the
34 | `--save_dir` flag.
35 |
36 | Redownloading all the datasets will be very time-consuming, so please be patient.
37 | You can specify a subset of the data to download by passing in a comma separated
38 | string of datasets via the `--datasets` flag.
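
For example, to fetch just a couple of datasets into the default location (an
illustrative command; dataset names must match those defined in
[`utils/create_data.py`](utils/create_data.py)):

```
python utils/create_data.py --datasets=mnist,letter --save_dir=/tmp/data
```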
39 |
40 | ## Running experiments
41 |
42 | There are a few key flags for
43 | [`run_experiment.py`](run_experiment.py):
44 |
45 | * `dataset`: name of the dataset, must match the save name used in
46 | `create_data.py`. Must also exist in the data_dir.
47 |
48 | * `sampling_method`: active learning method to use. Must be specified in
49 | [`sampling_methods/constants.py`](sampling_methods/constants.py).
50 |
51 | * `warmstart_size`: initial batch of uniformly sampled examples to use as seed
52 |   data. A float is interpreted as a fraction of the total training data; an
53 |   integer as a raw number of points.
54 |
55 | * `batch_size`: number of datapoints to request in each batch. A float is
56 |   interpreted as a fraction of the total training data; an integer as a raw count.
57 |
58 | * `score_method`: model to use to evaluate the performance of the sampling
59 | method. Must be in `get_model` method of
60 | [`utils/utils.py`](utils/utils.py).
61 |
62 | * `data_dir`: directory with saved datasets.
63 |
64 | * `save_dir`: directory to save results.
65 |
66 | This is just a subset of all the flags. There are also options for
67 | preprocessing, introducing labeling noise, dataset subsampling, and using a
68 | different model to select than to score/evaluate.
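
For example, a single run using the margin sampler on the `letter` dataset might
look like this (illustrative flag values):

```
python run_experiment.py --dataset=letter --sampling_method=margin \
  --warmstart_size=0.02 --batch_size=0.02 --score_method=logistic \
  --data_dir=/tmp/data --save_dir=/tmp/toy_experiments
```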
69 |
70 | ## Available active learning methods
71 |
72 | All named active learning methods are in
73 | [`sampling_methods/constants.py`](sampling_methods/constants.py).
74 |
75 | You can also specify a mixture of active learning methods by chaining
76 | `[sampling_method]-[mixture_weight]` pairs separated by dashes, where the weights must sum to 1; e.g.
77 | `mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34`.
78 |
79 | Some supported sampling methods include:
80 |
81 | * Uniform: samples are selected via uniform sampling.
82 |
83 | * Margin: uncertainty based sampling method.
84 |
85 | * Informative and diverse: margin and cluster based sampling method.
86 |
87 | * k-center greedy: representative strategy that greedily forms a batch of
88 | points to minimize maximum distance from a labeled point.
89 |
90 | * Graph density: representative strategy that selects points in dense regions
91 |   of the pool.
92 |
93 | * Exp3 bandit: meta-active learning method that tries to learn the optimal
94 |   sampling method using a popular multi-armed bandit algorithm.
95 |
96 | ### Adding new active learning methods
97 |
98 | Implement either a base sampler that inherits from
99 | [`SamplingMethod`](sampling_methods/sampling_def.py)
100 | or a meta-sampler that calls base samplers which inherits from
101 | [`WrapperSamplingMethod`](sampling_methods/wrapper_sampler_def.py).
102 |
103 | The only method that must be implemented by any sampler is `select_batch_`,
104 | which can have arbitrary named arguments. The only restriction is that the name
105 | for the same input must be consistent across all the samplers (i.e. the indices
106 | for already selected examples all have the same name across samplers). Adding a
107 | new named argument that hasn't been used in other sampling methods will require
108 | feeding that into the `select_batch` call in
109 | [`run_experiment.py`](run_experiment.py).
110 |
111 | After implementing your sampler, be sure to add it to
112 | [`constants.py`](sampling_methods/constants.py)
113 | so that it can be called from
114 | [`run_experiment.py`](run_experiment.py).
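
As a rough sketch (hypothetical class and file names; see
[`sampling_methods/sampling_def.py`](sampling_methods/sampling_def.py) for the
actual base-class interface), a new base sampler and its registration might look
like:

```python
# sampling_methods/random_tie_break.py -- hypothetical new base sampler
import numpy as np

from sampling_methods.sampling_def import SamplingMethod


class RandomTieBreakSampler(SamplingMethod):
  """Toy sampler: scores points at random and returns the top N unlabeled ones."""

  def __init__(self, X, y, seed):
    self.name = 'random_tie_break'
    self.X = X
    self.y = y
    np.random.seed(seed)

  def select_batch_(self, N, already_selected, **kwargs):
    # Argument names must be consistent with the other samplers, e.g.
    # `already_selected` always holds the indices that are already labeled.
    scores = np.random.rand(self.X.shape[0])
    scores[already_selected] = -np.inf  # never re-select labeled points
    return list(np.argsort(-scores)[:N])

  def to_dict(self):
    return None


# In sampling_methods/constants.py (inside get_base_AL_mapping):
#   AL_MAPPING['random_tie_break'] = RandomTieBreakSampler
```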
115 |
116 | ## Available models
117 |
118 | All available models are in the `get_model` method of
119 | [`utils/utils.py`](utils/utils.py).
120 |
121 | Supported methods:
122 |
123 | * Linear SVM: scikit method with grid search wrapper for regularization
124 | parameter.
125 |
126 | * Kernel SVM: scikit method with grid search wrapper for regularization
127 | parameter.
128 |
129 | * Logistic Regression: scikit method with grid search wrapper for
130 | regularization parameter.
131 |
132 | * Small CNN: 4-layer CNN optimized using rmsprop, implemented in Keras with a
133 |   tensorflow backend.
134 |
135 | * Kernel Least Squares Classification: block gradient descent solver that can
136 |   use multiple cores, so it is often faster than the scikit kernel SVM.
137 |
138 | ### Adding new models
139 |
140 | New models must follow the scikit-learn API and implement the following methods:
141 |
142 | * `fit(X, y[, sample_weight])`: fit the model to the input features and
143 | target.
144 |
145 | * `predict(X)`: predict the value of the input features.
146 |
147 | * `score(X, y)`: returns target metric given test features and test targets.
148 |
149 | * `decision_function(X)` (optional): return class probabilities, distance to
150 |   decision boundaries, or another metric that the margin sampler can use as a
151 |   measure of uncertainty.
152 |
153 | See
154 | [`small_cnn.py`](utils/small_cnn.py)
155 | for an example.
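
As a minimal sketch (hypothetical names, shown only to illustrate the interface
above rather than an existing model in this repository):

```python
# Hypothetical scikit-learn-style classifier usable as a score or select model.
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import GridSearchCV


class TunedSGD(object):
  """Linear SGD classifier with a small grid search over regularization strength."""

  def __init__(self, random_state=1):
    base = SGDClassifier(random_state=random_state)
    self.model = GridSearchCV(base, {'alpha': np.logspace(-5, 0, 6)}, cv=2)

  def fit(self, X, y, sample_weight=None):
    # sample_weight is accepted for interface compatibility but unused here.
    self.model.fit(X, y)
    return self

  def predict(self, X):
    return self.model.predict(X)

  def score(self, X, y):
    return self.model.score(X, y)

  def decision_function(self, X):
    # Signed distance to the decision boundary; the margin sampler can use this
    # as its measure of uncertainty.
    return self.model.decision_function(X)
```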
156 |
157 | After implementing your new model, be sure to add it to `get_model` method of
158 | [`utils/utils.py`](utils/utils.py).
159 |
160 | Currently, models must be added on a one-off basis, and not all scikit-learn
161 | classifiers are supported due to the need for user input on whether and how to
162 | tune the hyperparameters of the model. However, it is straightforward to add a
163 | scikit-learn model with a hyperparameter search wrapped around it as a
164 | supported model.
165 |
166 | ## Collecting results and charting
167 |
168 | The
169 | [`utils/chart_data.py`](utils/chart_data.py)
170 | script handles processing of data and charting for a specified dataset and
171 | source directory.
172 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | numpy>=1.13
3 | scipy>=0.19
4 | pandas>=0.20
5 | scikit-learn>=0.19
6 | matplotlib>=2.0.2
7 | tensorflow>=1.3
8 | keras>=2.0.8
9 |
--------------------------------------------------------------------------------
/run_experiment.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Run active learner on classification tasks.
16 |
17 | Supported datasets include mnist, letter, cifar10, newsgroup20, rcv1,
18 | wikipedia attack, and select classification datasets from mldata.
19 | See utils/create_data.py for all available datasets.
20 |
21 | For binary classification, mnist_4_9 indicates mnist filtered down to just 4 and
22 | 9.
23 | By default uses logistic regression but can also train using kernel SVM.
24 | 2-fold CV is used to tune the regularization parameter over an exponential grid.
25 |
26 | """
27 |
28 | from __future__ import absolute_import
29 | from __future__ import division
30 | from __future__ import print_function
31 |
32 | import os
33 | import pickle
34 | import sys
35 | from time import gmtime
36 | from time import strftime
37 |
38 | import numpy as np
39 | from sklearn.preprocessing import normalize
40 | from sklearn.preprocessing import StandardScaler
41 |
42 | from absl import app
43 | from absl import flags
44 | from tensorflow import gfile
45 |
46 | from sampling_methods.constants import AL_MAPPING
47 | from sampling_methods.constants import get_AL_sampler
48 | from sampling_methods.constants import get_wrapper_AL_mapping
49 | from utils import utils
50 |
51 | flags.DEFINE_string("dataset", "letter", "Dataset name")
52 | flags.DEFINE_string("sampling_method", "margin",
53 | ("Name of sampling method to use, can be any defined in "
54 | "AL_MAPPING in sampling_methods.constants"))
55 | flags.DEFINE_float(
56 | "warmstart_size", 0.02,
57 | ("Can be float or integer. Float indicates percentage of training data "
58 | "to use in the initial warmstart model")
59 | )
60 | flags.DEFINE_float(
61 | "batch_size", 0.02,
62 | ("Can be float or integer. Float indicates batch size as a percentage "
63 | "of training data size.")
64 | )
65 | flags.DEFINE_integer("trials", 1,
66 | "Number of curves to create using different seeds")
67 | flags.DEFINE_integer("seed", 1, "Seed to use for rng and random state")
68 | # TODO(lisha): add feature noise to simulate data outliers
69 | flags.DEFINE_string("confusions", "0.", "Percentage of labels to randomize")
70 | flags.DEFINE_string("active_sampling_percentage", "1.0",
71 | "Mixture weights on active sampling.")
72 | flags.DEFINE_string(
73 | "score_method", "logistic",
74 | "Method to use to calculate accuracy.")
75 | flags.DEFINE_string(
76 | "select_method", "None",
77 | "Method to use for selecting points.")
78 | flags.DEFINE_string("normalize_data", "False", "Whether to normalize the data.")
79 | flags.DEFINE_string("standardize_data", "True",
80 | "Whether to standardize the data.")
81 | flags.DEFINE_string("save_dir", "/tmp/toy_experiments",
82 | "Where to save outputs")
83 | flags.DEFINE_string("data_dir", "/tmp/data",
84 | "Directory with predownloaded and saved datasets.")
85 | flags.DEFINE_string("max_dataset_size", "15000",
86 | ("maximum number of datapoints to include in data "
87 | "zero indicates no limit"))
88 | flags.DEFINE_float("train_horizon", "1.0",
89 | "how far to extend learning curve as a percent of train")
90 | flags.DEFINE_string("do_save", "True",
91 | "whether to save log and results")
92 | FLAGS = flags.FLAGS
93 |
94 |
95 | get_wrapper_AL_mapping()
96 |
97 |
98 | def generate_one_curve(X,
99 | y,
100 | sampler,
101 | score_model,
102 | seed,
103 | warmstart_size,
104 | batch_size,
105 | select_model=None,
106 | confusion=0.,
107 | active_p=1.0,
108 | max_points=None,
109 | standardize_data=False,
110 | norm_data=False,
111 | train_horizon=0.5):
112 | """Creates one learning curve for both active and passive learning.
113 |
114 | Will calculate accuracy on validation set as the number of training data
115 | points increases for both PL and AL.
116 |   Caveats: the training method used is sensitive to the ordering of the data,
117 |   so we re-sort all intermediate datasets.
118 |
119 | Args:
120 | X: training data
121 | y: training labels
122 | sampler: sampling class from sampling_methods, assumes reference
123 | passed in and sampler not yet instantiated.
124 | score_model: model used to score the samplers. Expects fit and predict
125 | methods to be implemented.
126 | seed: seed used for data shuffle and other sources of randomness in sampler
127 | or model training
128 | warmstart_size: float or int. float indicates percentage of train data
129 | to use for initial model
130 | batch_size: float or int. float indicates batch size as a percent of
131 | training data
132 | select_model: defaults to None, in which case the score model will be
133 | used to select new datapoints to label. Model must implement fit, predict
134 | and depending on AL method may also need decision_function.
135 | confusion: percentage of labels of one class to flip to the other
136 | active_p: percent of batch to allocate to active learning
137 |     max_points: limit dataset size for preliminary experiments
138 |     standardize_data: whether to standardize the data to 0 mean and unit variance
139 | norm_data: whether to normalize the data. Default is False for logistic
140 | regression.
141 | train_horizon: how long to draw the curve for. Percent of training data.
142 |
143 | Returns:
144 | results: dictionary of results for all samplers
145 | sampler_states: dictionary of sampler objects for debugging
146 | """
147 | # TODO(lishal): add option to find best hyperparameter setting first on
148 | # full dataset and fix the hyperparameter for the rest of the routine
149 | # This will save computation and also lead to more stable behavior for the
150 | # test accuracy
151 |
152 | # TODO(lishal): remove mixture parameter and have the mixture be specified as
153 | # a mixture of samplers strategy
154 | def select_batch(sampler, uniform_sampler, mixture, N, already_selected,
155 | **kwargs):
156 | n_active = int(mixture * N)
157 | n_passive = N - n_active
158 | kwargs["N"] = n_active
159 | kwargs["already_selected"] = already_selected
160 | batch_AL = sampler.select_batch(**kwargs)
161 | already_selected = already_selected + batch_AL
162 | kwargs["N"] = n_passive
163 | kwargs["already_selected"] = already_selected
164 | batch_PL = uniform_sampler.select_batch(**kwargs)
165 | return batch_AL + batch_PL
166 |
167 | np.random.seed(seed)
168 | data_splits = [2./3, 1./6, 1./6]
169 |
170 | # 2/3 of data for training
171 | if max_points is None:
172 | max_points = len(y)
173 | train_size = int(min(max_points, len(y)) * data_splits[0])
174 | if batch_size < 1:
175 | batch_size = int(batch_size * train_size)
176 | else:
177 | batch_size = int(batch_size)
178 | if warmstart_size < 1:
179 | # Set seed batch to provide enough samples to get at least 4 per class
180 | # TODO(lishal): switch to sklearn stratified sampler
181 | seed_batch = int(warmstart_size * train_size)
182 | else:
183 | seed_batch = int(warmstart_size)
184 | seed_batch = max(seed_batch, 6 * len(np.unique(y)))
185 |
186 | indices, X_train, y_train, X_val, y_val, X_test, y_test, y_noise = (
187 | utils.get_train_val_test_splits(X,y,max_points,seed,confusion,
188 | seed_batch, split=data_splits))
189 |
190 | # Preprocess data
191 | if norm_data:
192 | print("Normalizing data")
193 | X_train = normalize(X_train)
194 | X_val = normalize(X_val)
195 | X_test = normalize(X_test)
196 | if standardize_data:
197 | print("Standardizing data")
198 | scaler = StandardScaler().fit(X_train)
199 | X_train = scaler.transform(X_train)
200 | X_val = scaler.transform(X_val)
201 | X_test = scaler.transform(X_test)
202 | print("active percentage: " + str(active_p) + " warmstart batch: " +
203 | str(seed_batch) + " batch size: " + str(batch_size) + " confusion: " +
204 | str(confusion) + " seed: " + str(seed))
205 |
206 | # Initialize samplers
207 | uniform_sampler = AL_MAPPING["uniform"](X_train, y_train, seed)
208 | sampler = sampler(X_train, y_train, seed)
209 |
210 | results = {}
211 | data_sizes = []
212 | accuracy = []
213 |   selected_inds = list(range(seed_batch))
214 |
215 | # If select model is None, use score_model
216 | same_score_select = False
217 | if select_model is None:
218 | select_model = score_model
219 | same_score_select = True
220 |
221 | n_batches = int(np.ceil((train_horizon * train_size - seed_batch) *
222 | 1.0 / batch_size)) + 1
223 | for b in range(n_batches):
224 | n_train = seed_batch + min(train_size - seed_batch, b * batch_size)
225 | print("Training model on " + str(n_train) + " datapoints")
226 |
227 | assert n_train == len(selected_inds)
228 | data_sizes.append(n_train)
229 |
230 |     # Sort active_ind so that the end result matches that of uniform sampling
231 | partial_X = X_train[sorted(selected_inds)]
232 | partial_y = y_train[sorted(selected_inds)]
233 | score_model.fit(partial_X, partial_y)
234 | if not same_score_select:
235 | select_model.fit(partial_X, partial_y)
236 | acc = score_model.score(X_test, y_test)
237 | accuracy.append(acc)
238 | print("Sampler: %s, Accuracy: %.2f%%" % (sampler.name, accuracy[-1]*100))
239 |
240 | n_sample = min(batch_size, train_size - len(selected_inds))
241 | select_batch_inputs = {
242 | "model": select_model,
243 | "labeled": dict(zip(selected_inds, y_train[selected_inds])),
244 | "eval_acc": accuracy[-1],
245 | "X_test": X_val,
246 | "y_test": y_val,
247 | "y": y_train
248 | }
249 | new_batch = select_batch(sampler, uniform_sampler, active_p, n_sample,
250 | selected_inds, **select_batch_inputs)
251 | selected_inds.extend(new_batch)
252 | print('Requested: %d, Selected: %d' % (n_sample, len(new_batch)))
253 | assert len(new_batch) == n_sample
254 | assert len(list(set(selected_inds))) == len(selected_inds)
255 |
256 |   # Check that the returned indices are correct and will allow mapping to
257 |   # the training set from the original data
258 | assert all(y_noise[indices[selected_inds]] == y_train[selected_inds])
259 | results["accuracy"] = accuracy
260 | results["selected_inds"] = selected_inds
261 | results["data_sizes"] = data_sizes
262 | results["indices"] = indices
263 | results["noisy_targets"] = y_noise
264 | return results, sampler
265 |
266 |
267 | def main(argv):
268 | del argv
269 |
270 | if not gfile.Exists(FLAGS.save_dir):
271 | try:
272 | gfile.MkDir(FLAGS.save_dir)
273 | except:
274 | print(('WARNING: error creating save directory, '
275 | 'directory most likely already created.'))
276 |
277 | save_dir = os.path.join(
278 | FLAGS.save_dir,
279 | FLAGS.dataset + "_" + FLAGS.sampling_method)
280 | do_save = FLAGS.do_save == "True"
281 |
282 | if do_save:
283 | if not gfile.Exists(save_dir):
284 | try:
285 | gfile.MkDir(save_dir)
286 | except:
287 | print(('WARNING: error creating save directory, '
288 | 'directory most likely already created.'))
289 | # Set up logging
290 | filename = os.path.join(
291 | save_dir, "log-" + strftime("%Y-%m-%d-%H-%M-%S", gmtime()) + ".txt")
292 | sys.stdout = utils.Logger(filename)
293 |
294 | confusions = [float(t) for t in FLAGS.confusions.split(" ")]
295 | mixtures = [float(t) for t in FLAGS.active_sampling_percentage.split(" ")]
296 | all_results = {}
297 | max_dataset_size = None if FLAGS.max_dataset_size == "0" else int(
298 | FLAGS.max_dataset_size)
299 | normalize_data = FLAGS.normalize_data == "True"
300 | standardize_data = FLAGS.standardize_data == "True"
301 | X, y = utils.get_mldata(FLAGS.data_dir, FLAGS.dataset)
302 | starting_seed = FLAGS.seed
303 |
304 | for c in confusions:
305 | for m in mixtures:
306 | for seed in range(starting_seed, starting_seed + FLAGS.trials):
307 | sampler = get_AL_sampler(FLAGS.sampling_method)
308 | score_model = utils.get_model(FLAGS.score_method, seed)
309 | if (FLAGS.select_method == "None" or
310 | FLAGS.select_method == FLAGS.score_method):
311 | select_model = None
312 | else:
313 | select_model = utils.get_model(FLAGS.select_method, seed)
314 | results, sampler_state = generate_one_curve(
315 | X, y, sampler, score_model, seed, FLAGS.warmstart_size,
316 | FLAGS.batch_size, select_model, c, m, max_dataset_size,
317 | standardize_data, normalize_data, FLAGS.train_horizon)
318 | key = (FLAGS.dataset, FLAGS.sampling_method, FLAGS.score_method,
319 | FLAGS.select_method, m, FLAGS.warmstart_size, FLAGS.batch_size,
320 | c, standardize_data, normalize_data, seed)
321 | sampler_output = sampler_state.to_dict()
322 | results["sampler_output"] = sampler_output
323 | all_results[key] = results
324 | fields = [
325 | "dataset", "sampler", "score_method", "select_method",
326 | "active percentage", "warmstart size", "batch size", "confusion",
327 | "standardize", "normalize", "seed"
328 | ]
329 | all_results["tuple_keys"] = fields
330 |
331 | if do_save:
332 | filename = ("results_score_" + FLAGS.score_method +
333 | "_select_" + FLAGS.select_method +
334 | "_norm_" + str(normalize_data) +
335 | "_stand_" + str(standardize_data))
336 | existing_files = gfile.Glob(os.path.join(save_dir, filename + "*.pkl"))
337 | filename = os.path.join(save_dir,
338 | filename + "_" + str(1000+len(existing_files))[1:] + ".pkl")
339 | pickle.dump(all_results, gfile.GFile(filename, "w"))
340 | sys.stdout.flush_file()
341 |
342 |
343 | if __name__ == "__main__":
344 | app.run(main)
345 |
--------------------------------------------------------------------------------
/sampling_methods/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/sampling_methods/bandit_discrete.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Bandit wrapper around base AL sampling methods.
16 |
17 | Assumes adversarial multi-armed bandit setting where arms correspond to
18 | mixtures of different AL methods.
19 |
20 | Uses EXP3 algorithm to decide which AL method to use to create the next batch.
21 | Similar to Hsu & Lin 2015, Active Learning by Learning.
22 | https://www.csie.ntu.edu.tw/~htlin/paper/doc/aaai15albl.pdf
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import numpy as np
30 |
31 | from sampling_methods.wrapper_sampler_def import AL_MAPPING, WrapperSamplingMethod
32 |
33 |
34 | class BanditDiscreteSampler(WrapperSamplingMethod):
35 | """Wraps EXP3 around mixtures of indicated methods.
36 |
37 |   Uses EXP3 multi-armed bandit algorithm to select sampler methods.
38 | """
39 | def __init__(self,
40 | X,
41 | y,
42 | seed,
43 | reward_function = lambda AL_acc: AL_acc[-1],
44 | gamma=0.5,
45 | samplers=[{'methods':('margin','uniform'),'weights':(0,1)},
46 | {'methods':('margin','uniform'),'weights':(1,0)}]):
47 | """Initializes sampler with indicated gamma and arms.
48 |
49 | Args:
50 | X: training data
51 | y: labels, may need to be input into base samplers
52 | seed: seed to use for random sampling
53 | reward_function: reward based on previously observed accuracies. Assumes
54 | that the input is a sequence of observed accuracies. Will ultimately be
55 | a class method and may need access to other class properties.
56 | gamma: weight on uniform mixture. Arm probability updates are a weighted
57 | mixture of uniform and an exponentially weighted distribution.
58 | Lower gamma more aggressively updates based on observed rewards.
59 | samplers: list of dicts with two fields
60 |         'methods': list of named samplers
61 | 'weights': percentage of batch to allocate to each sampler
62 | """
63 |
64 | self.name = 'bandit_discrete'
65 | np.random.seed(seed)
66 | self.X = X
67 | self.y = y
68 | self.seed = seed
69 | self.initialize_samplers(samplers)
70 |
71 | self.gamma = gamma
72 | self.n_arms = len(samplers)
73 | self.reward_function = reward_function
74 |
75 | self.pull_history = []
76 | self.acc_history = []
77 | self.w = np.ones(self.n_arms)
78 | self.x = np.zeros(self.n_arms)
79 | self.p = self.w / (1.0 * self.n_arms)
80 | self.probs = []
81 |
82 | def update_vars(self, arm_pulled):
83 | reward = self.reward_function(self.acc_history)
84 | self.x = np.zeros(self.n_arms)
85 | self.x[arm_pulled] = reward / self.p[arm_pulled]
86 | self.w = self.w * np.exp(self.gamma * self.x / self.n_arms)
87 | self.p = ((1.0 - self.gamma) * self.w / sum(self.w)
88 | + self.gamma / self.n_arms)
89 | print(self.p)
90 | self.probs.append(self.p)
91 |
92 | def select_batch_(self, already_selected, N, eval_acc, **kwargs):
93 | """Returns batch of datapoints sampled using mixture of AL_methods.
94 |
95 | Assumes that data has already been shuffled.
96 |
97 | Args:
98 | already_selected: index of datapoints already selected
99 | N: batch size
100 | eval_acc: accuracy of model trained after incorporating datapoints from
101 | last recommended batch
102 |
103 | Returns:
104 | indices of points selected to label
105 | """
106 | # Update observed reward and arm probabilities
107 | self.acc_history.append(eval_acc)
108 | if len(self.pull_history) > 0:
109 | self.update_vars(self.pull_history[-1])
110 | # Sample an arm
111 | arm = np.random.choice(range(self.n_arms), p=self.p)
112 | self.pull_history.append(arm)
113 | kwargs['N'] = N
114 | kwargs['already_selected'] = already_selected
115 | sample = self.samplers[arm].select_batch(**kwargs)
116 | return sample
117 |
118 | def to_dict(self):
119 | output = {}
120 | output['samplers'] = self.base_samplers
121 | output['arm_probs'] = self.probs
122 | output['pull_history'] = self.pull_history
123 | output['rewards'] = self.acc_history
124 | return output
125 |
126 |
--------------------------------------------------------------------------------
/sampling_methods/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Controls imports to fill up dictionary of different sampling methods.
16 | """
17 |
18 | from functools import partial
19 | AL_MAPPING = {}
20 |
21 |
22 | def get_base_AL_mapping():
23 | from sampling_methods.margin_AL import MarginAL
24 | from sampling_methods.informative_diverse import InformativeClusterDiverseSampler
25 | from sampling_methods.hierarchical_clustering_AL import HierarchicalClusterAL
26 | from sampling_methods.uniform_sampling import UniformSampling
27 | from sampling_methods.represent_cluster_centers import RepresentativeClusterMeanSampling
28 | from sampling_methods.graph_density import GraphDensitySampler
29 | from sampling_methods.kcenter_greedy import kCenterGreedy
30 | AL_MAPPING['margin'] = MarginAL
31 | AL_MAPPING['informative_diverse'] = InformativeClusterDiverseSampler
32 | AL_MAPPING['hierarchical'] = HierarchicalClusterAL
33 | AL_MAPPING['uniform'] = UniformSampling
34 | AL_MAPPING['margin_cluster_mean'] = RepresentativeClusterMeanSampling
35 | AL_MAPPING['graph_density'] = GraphDensitySampler
36 | AL_MAPPING['kcenter'] = kCenterGreedy
37 |
38 |
39 | def get_all_possible_arms():
40 | from sampling_methods.mixture_of_samplers import MixtureOfSamplers
41 | AL_MAPPING['mixture_of_samplers'] = MixtureOfSamplers
42 |
43 |
44 | def get_wrapper_AL_mapping():
45 | from sampling_methods.bandit_discrete import BanditDiscreteSampler
46 | from sampling_methods.simulate_batch import SimulateBatchSampler
47 | AL_MAPPING['bandit_mixture'] = partial(
48 | BanditDiscreteSampler,
49 | samplers=[{
50 | 'methods': ['margin', 'uniform'],
51 | 'weights': [0, 1]
52 | }, {
53 | 'methods': ['margin', 'uniform'],
54 | 'weights': [0.25, 0.75]
55 | }, {
56 | 'methods': ['margin', 'uniform'],
57 | 'weights': [0.5, 0.5]
58 | }, {
59 | 'methods': ['margin', 'uniform'],
60 | 'weights': [0.75, 0.25]
61 | }, {
62 | 'methods': ['margin', 'uniform'],
63 | 'weights': [1, 0]
64 | }])
65 | AL_MAPPING['bandit_discrete'] = partial(
66 | BanditDiscreteSampler,
67 | samplers=[{
68 | 'methods': ['margin', 'uniform'],
69 | 'weights': [0, 1]
70 | }, {
71 | 'methods': ['margin', 'uniform'],
72 | 'weights': [1, 0]
73 | }])
74 | AL_MAPPING['simulate_batch_mixture'] = partial(
75 | SimulateBatchSampler,
76 | samplers=({
77 | 'methods': ['margin', 'uniform'],
78 | 'weights': [1, 0]
79 | }, {
80 | 'methods': ['margin', 'uniform'],
81 | 'weights': [0.5, 0.5]
82 | }, {
83 | 'methods': ['margin', 'uniform'],
84 | 'weights': [0, 1]
85 | }),
86 | n_sims=5,
87 | train_per_sim=10,
88 | return_best_sim=False)
89 | AL_MAPPING['simulate_batch_best_sim'] = partial(
90 | SimulateBatchSampler,
91 | samplers=[{
92 | 'methods': ['margin', 'uniform'],
93 | 'weights': [1, 0]
94 | }],
95 | n_sims=10,
96 | train_per_sim=10,
97 | return_type='best_sim')
98 | AL_MAPPING['simulate_batch_frequency'] = partial(
99 | SimulateBatchSampler,
100 | samplers=[{
101 | 'methods': ['margin', 'uniform'],
102 | 'weights': [1, 0]
103 | }],
104 | n_sims=10,
105 | train_per_sim=10,
106 | return_type='frequency')
107 |
108 | def get_mixture_of_samplers(name):
109 | assert 'mixture_of_samplers' in name
110 | if 'mixture_of_samplers' not in AL_MAPPING:
111 | raise KeyError('Mixture of Samplers not yet loaded.')
112 | args = name.split('-')[1:]
113 | samplers = args[0::2]
114 | weights = args[1::2]
115 | weights = [float(w) for w in weights]
116 | assert sum(weights) == 1
117 | mixture = {'methods': samplers, 'weights': weights}
118 | print(mixture)
119 | return partial(AL_MAPPING['mixture_of_samplers'], mixture=mixture)
120 |
121 |
122 | def get_AL_sampler(name):
123 | if name in AL_MAPPING and name != 'mixture_of_samplers':
124 | return AL_MAPPING[name]
125 | if 'mixture_of_samplers' in name:
126 | return get_mixture_of_samplers(name)
127 | raise NotImplementedError('The specified sampler is not available.')
128 |
--------------------------------------------------------------------------------
/sampling_methods/graph_density.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Diversity promoting sampling method that uses graph density to determine
16 | most representative points.
17 |
18 | This is an implementation of the method described in
19 | https://www.mpi-inf.mpg.de/fileadmin/inf/d2/Research_projects_files/EbertCVPR2012.pdf
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import copy
27 |
28 | from sklearn.neighbors import kneighbors_graph
29 | from sklearn.metrics import pairwise_distances
30 | import numpy as np
31 | from sampling_methods.sampling_def import SamplingMethod
32 |
33 |
34 | class GraphDensitySampler(SamplingMethod):
35 | """Diversity promoting sampling method that uses graph density to determine
36 | most representative points.
37 | """
38 |
39 | def __init__(self, X, y, seed):
40 | self.name = 'graph_density'
41 | self.X = X
42 | self.flat_X = self.flatten_X()
43 | # Set gamma for gaussian kernel to be equal to 1/n_features
44 | self.gamma = 1. / self.X.shape[1]
45 | self.compute_graph_density()
46 |
47 | def compute_graph_density(self, n_neighbor=10):
48 | # kneighbors graph is constructed using k=10
49 | connect = kneighbors_graph(self.flat_X, n_neighbor,p=1)
50 |     # Make the connectivity matrix symmetric: if a point is a k nearest
51 |     # neighbor of another point, add the reverse edge as well.
52 | neighbors = connect.nonzero()
53 | inds = zip(neighbors[0],neighbors[1])
54 | # Graph edges are weighted by applying gaussian kernel to manhattan dist.
55 | # By default, gamma for rbf kernel is equal to 1/n_features but may
56 | # get better results if gamma is tuned.
57 | for entry in inds:
58 | i = entry[0]
59 | j = entry[1]
60 | distance = pairwise_distances(self.flat_X[[i]],self.flat_X[[j]],metric='manhattan')
61 | distance = distance[0,0]
62 | weight = np.exp(-distance * self.gamma)
63 | connect[i,j] = weight
64 | connect[j,i] = weight
65 | self.connect = connect
66 | # Define graph density for an observation to be sum of weights for all
67 | # edges to the node representing the datapoint. Normalize sum weights
68 | # by total number of neighbors.
69 | self.graph_density = np.zeros(self.X.shape[0])
70 | for i in np.arange(self.X.shape[0]):
71 | self.graph_density[i] = connect[i,:].sum() / (connect[i,:]>0).sum()
72 | self.starting_density = copy.deepcopy(self.graph_density)
73 |
74 | def select_batch_(self, N, already_selected, **kwargs):
75 | # If a neighbor has already been sampled, reduce the graph density
76 | # for its direct neighbors to promote diversity.
77 | batch = set()
78 | self.graph_density[already_selected] = min(self.graph_density) - 1
79 | while len(batch) < N:
80 | selected = np.argmax(self.graph_density)
81 | neighbors = (self.connect[selected,:] > 0).nonzero()[1]
82 | self.graph_density[neighbors] = self.graph_density[neighbors] - self.graph_density[selected]
83 | batch.add(selected)
84 | self.graph_density[already_selected] = min(self.graph_density) - 1
85 | self.graph_density[list(batch)] = min(self.graph_density) - 1
86 | return list(batch)
87 |
88 | def to_dict(self):
89 | output = {}
90 | output['connectivity'] = self.connect
91 | output['graph_density'] = self.starting_density
92 | return output
--------------------------------------------------------------------------------
/sampling_methods/hierarchical_clustering_AL.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Hierarchical cluster AL method.
16 |
17 | Implements algorithm described in Dasgupta, S and Hsu, D,
18 | "Hierarchical Sampling for Active Learning, 2008
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | import numpy as np
26 | from sklearn.cluster import AgglomerativeClustering
27 | from sklearn.decomposition import PCA
28 | from sklearn.neighbors import kneighbors_graph
29 | from sampling_methods.sampling_def import SamplingMethod
30 | from sampling_methods.utils.tree import Tree
31 |
32 |
33 | class HierarchicalClusterAL(SamplingMethod):
34 | """Implements hierarchical cluster AL based method.
35 |
36 | All methods are internal. select_batch_ is called via abstract classes
37 | outward facing method select_batch.
38 |
39 |   Default affinity is euclidean and default linkage is ward, which links
40 |   clusters based on variance reduction. Hence, good results depend on
41 |   having normalized and standardized data.
42 | """
43 |
44 | def __init__(self, X, y, seed, beta=2, affinity='euclidean', linkage='ward',
45 | clustering=None, max_features=None):
46 | """Initializes AL method and fits hierarchical cluster to data.
47 |
48 | Args:
49 | X: data
50 |       y: labels for determining the number of clusters as an input to
51 | AgglomerativeClustering
52 | seed: random seed used for sampling datapoints for batch
53 |       beta: width of error used to decide admissible labels, a higher value of
54 |         beta corresponds to wider confidence and a less stringent definition of
55 |         admissibility
56 |         See the scikit-learn AgglomerativeClustering documentation for more info
57 | affinity: distance metric used for hierarchical clustering
58 | linkage: linkage method used to determine when to join clusters
59 | clustering: can provide an AgglomerativeClustering that is already fit
60 | max_features: limit number of features used to construct hierarchical
61 | cluster. If specified, PCA is used to perform feature reduction and
62 | the hierarchical clustering is performed using transformed features.
63 | """
64 | self.name = 'hierarchical'
65 | self.seed = seed
66 | np.random.seed(seed)
67 | # Variables for the hierarchical cluster
68 | self.already_clustered = False
69 | if clustering is not None:
70 | self.model = clustering
71 | self.already_clustered = True
72 | self.n_leaves = None
73 | self.n_components = None
74 | self.children_list = None
75 | self.node_dict = None
76 | self.root = None # Node name, all node instances access through self.tree
77 | self.tree = None
78 | # Variables for the AL algorithm
79 | self.initialized = False
80 | self.beta = beta
81 | self.labels = {}
82 | self.pruning = []
83 | self.admissible = {}
84 | self.selected_nodes = None
85 | # Data variables
86 | self.classes = None
87 | self.X = X
88 |
89 | classes = list(set(y))
90 | self.n_classes = len(classes)
91 | if max_features is not None:
92 | transformer = PCA(n_components=max_features)
93 | transformer.fit(X)
94 | self.transformed_X = transformer.fit_transform(X)
95 | #connectivity = kneighbors_graph(self.transformed_X,max_features)
96 | self.model = AgglomerativeClustering(
97 | affinity=affinity, linkage=linkage, n_clusters=len(classes))
98 | self.fit_cluster(self.transformed_X)
99 | else:
100 | self.model = AgglomerativeClustering(
101 | affinity=affinity, linkage=linkage, n_clusters=len(classes))
102 | self.fit_cluster(self.X)
103 | self.y = y
104 |
105 | self.y_labels = {}
106 | # Fit cluster and update cluster variables
107 |
108 | self.create_tree()
109 | print('Finished creating hierarchical cluster')
110 |
111 | def fit_cluster(self, X):
112 | if not self.already_clustered:
113 | self.model.fit(X)
114 | self.already_clustered = True
115 | self.n_leaves = self.model.n_leaves_
116 | self.n_components = self.model.n_components_
117 | self.children_list = self.model.children_
118 |
119 | def create_tree(self):
120 | node_dict = {}
121 | for i in range(self.n_leaves):
122 | node_dict[i] = [None, None]
123 | for i in range(len(self.children_list)):
124 | node_dict[self.n_leaves + i] = self.children_list[i]
125 | self.node_dict = node_dict
126 | # The sklearn hierarchical clustering algo numbers leaves which correspond
127 | # to actual datapoints 0 to n_points - 1 and all internal nodes have
128 | # ids greater than n_points - 1 with the root having the highest node id
129 | self.root = max(self.node_dict.keys())
130 | self.tree = Tree(self.root, self.node_dict)
131 | self.tree.create_child_leaves_mapping(range(self.n_leaves))
132 | for v in node_dict:
133 | self.admissible[v] = set()
134 |
135 | def get_child_leaves(self, node):
136 | return self.tree.get_child_leaves(node)
137 |
138 | def get_node_leaf_counts(self, node_list):
139 | node_counts = []
140 | for v in node_list:
141 | node_counts.append(len(self.get_child_leaves(v)))
142 | return np.array(node_counts)
143 |
144 | def get_class_counts(self, y):
145 | """Gets the count of all classes in a sample.
146 |
147 | Args:
148 | y: sample vector for which to perform the count
149 | Returns:
150 | count of classes for the sample vector y, the class order for count will
151 | be the same as that of self.classes
152 | """
153 | unique, counts = np.unique(y, return_counts=True)
154 | complete_counts = []
155 | for c in self.classes:
156 | if c not in unique:
157 | complete_counts.append(0)
158 | else:
159 | index = np.where(unique == c)[0][0]
160 | complete_counts.append(counts[index])
161 | return np.array(complete_counts)
162 |
163 | def observe_labels(self, labeled):
164 | for i in labeled:
165 | self.y_labels[i] = labeled[i]
166 | self.classes = np.array(
167 | sorted(list(set([self.y_labels[k] for k in self.y_labels]))))
168 | self.n_classes = len(self.classes)
169 |
170 | def initialize_algo(self):
171 | self.pruning = [self.root]
172 | self.labels[self.root] = np.random.choice(self.classes)
173 | node = self.tree.get_node(self.root)
174 | node.best_label = self.labels[self.root]
175 | self.selected_nodes = [self.root]
176 |
177 | def get_node_class_probabilities(self, node, y=None):
178 | children = self.get_child_leaves(node)
179 | if y is None:
180 | y_dict = self.y_labels
181 | else:
182 | y_dict = dict(zip(range(len(y)), y))
183 | labels = [y_dict[c] for c in children if c in y_dict]
184 | # If no labels have been observed, simply return uniform distribution
185 | if len(labels) == 0:
186 | return 0, np.ones(self.n_classes)/self.n_classes
187 | return len(labels), self.get_class_counts(labels) / (len(labels) * 1.0)
188 |
189 | def get_node_upper_lower_bounds(self, node):
190 | n_v, p_v = self.get_node_class_probabilities(node)
191 | # If no observations, return worst possible upper lower bounds
192 | if n_v == 0:
193 | return np.zeros(len(p_v)), np.ones(len(p_v))
194 | delta = 1. / n_v + np.sqrt(p_v * (1 - p_v) / (1. * n_v))
195 | return (np.maximum(p_v - delta, np.zeros(len(p_v))),
196 | np.minimum(p_v + delta, np.ones(len(p_v))))
197 |
198 | def get_node_admissibility(self, node):
199 | p_lb, p_up = self.get_node_upper_lower_bounds(node)
200 | all_other_min = np.vectorize(
201 | lambda i:min([1 - p_up[c] for c in range(len(self.classes)) if c != i]))
202 | lowest_alternative_error = self.beta * all_other_min(
203 | np.arange(len(self.classes)))
204 | return 1 - p_lb < lowest_alternative_error
205 |
206 | def get_adjusted_error(self, node):
207 | _, prob = self.get_node_class_probabilities(node)
208 | error = 1 - prob
209 | admissible = self.get_node_admissibility(node)
210 | not_admissible = np.where(admissible != True)[0]
211 | error[not_admissible] = 1.0
212 | return error
213 |
214 | def get_class_probability_pruning(self, method='lower'):
215 | prob_pruning = []
216 | for v in self.pruning:
217 | label = self.labels[v]
218 | label_ind = np.where(self.classes == label)[0][0]
219 | if method == 'empirical':
220 | _, v_prob = self.get_node_class_probabilities(v)
221 | else:
222 | lower, upper = self.get_node_upper_lower_bounds(v)
223 | if method == 'lower':
224 | v_prob = lower
225 | elif method == 'upper':
226 | v_prob = upper
227 | else:
228 | raise NotImplementedError
229 | prob = v_prob[label_ind]
230 | prob_pruning.append(prob)
231 | return np.array(prob_pruning)
232 |
233 | def get_pruning_impurity(self, y):
234 | impurity = []
235 | for v in self.pruning:
236 | _, prob = self.get_node_class_probabilities(v, y)
237 | impurity.append(1-max(prob))
238 | impurity = np.array(impurity)
239 | weights = self.get_node_leaf_counts(self.pruning)
240 | weights = weights / sum(weights)
241 | return sum(impurity*weights)
242 |
243 | def update_scores(self):
244 | node_list = set(range(self.n_leaves))
245 | # Loop through generations from bottom to top
246 | while len(node_list) > 0:
247 | parents = set()
248 | for v in node_list:
249 | node = self.tree.get_node(v)
250 | # Update admissible labels for node
251 | admissible = self.get_node_admissibility(v)
252 | admissible_indices = np.where(admissible)[0]
253 | for l in self.classes[admissible_indices]:
254 | self.admissible[v].add(l)
255 | # Calculate score
256 | v_error = self.get_adjusted_error(v)
257 | best_label_ind = np.argmin(v_error)
258 | if admissible[best_label_ind]:
259 | node.best_label = self.classes[best_label_ind]
260 | score = v_error[best_label_ind]
261 | node.split = False
262 |
263 | # Determine if node should be split
264 | if v >= self.n_leaves: # v is not a leaf
265 | if len(admissible_indices) > 0: # There exists an admissible label
266 | # Make sure label set for node so that we can flow to children
267 | # if necessary
268 | assert node.best_label is not None
269 | # Only split if all ancestors are admissible nodes
270 | # This is part of definition of admissible pruning
271 | admissible_ancestors = [len(self.admissible[a]) > 0 for a in
272 | self.tree.get_ancestor(node)]
273 | if all(admissible_ancestors):
274 | left = self.node_dict[v][0]
275 | left_node = self.tree.get_node(left)
276 | right = self.node_dict[v][1]
277 | right_node = self.tree.get_node(right)
278 | node_counts = self.get_node_leaf_counts([v, left, right])
279 | split_score = (node_counts[1] / node_counts[0] *
280 | left_node.score + node_counts[2] /
281 | node_counts[0] * right_node.score)
282 | if split_score < score:
283 | score = split_score
284 | node.split = True
285 | node.score = score
286 | if node.parent:
287 | parents.add(node.parent.name)
288 | node_list = parents
289 |
290 | def update_pruning_labels(self):
291 | for v in self.selected_nodes:
292 | node = self.tree.get_node(v)
293 | pruning = self.tree.get_pruning(node)
294 | self.pruning.remove(v)
295 | self.pruning.extend(pruning)
296 | # Check that the pruning covers all leaf nodes
297 | node_counts = self.get_node_leaf_counts(self.pruning)
298 | assert sum(node_counts) == self.n_leaves
299 | # Fill in labels
300 | for v in self.pruning:
301 | node = self.tree.get_node(v)
302 | if node.best_label is None:
303 | node.best_label = node.parent.best_label
304 | self.labels[v] = node.best_label
305 |
306 | def get_fake_labels(self):
307 | fake_y = np.zeros(self.X.shape[0])
308 | for p in self.pruning:
309 | indices = self.get_child_leaves(p)
310 | fake_y[indices] = self.labels[p]
311 | return fake_y
312 |
313 | def train_using_fake_labels(self, model, X_test, y_test):
314 | classes_labeled = set([self.labels[p] for p in self.pruning])
315 | if len(classes_labeled) == self.n_classes:
316 | fake_y = self.get_fake_labels()
317 | model.fit(self.X, fake_y)
318 | test_acc = model.score(X_test, y_test)
319 | return test_acc
320 | return 0
321 |
322 | def select_batch_(self, N, already_selected, labeled, y, **kwargs):
323 | # Observe labels for previously recommended batches
324 | self.observe_labels(labeled)
325 |
326 | if not self.initialized:
327 | self.initialize_algo()
328 | self.initialized = True
329 | print('Initialized algo')
330 |
331 | print('Updating scores and pruning for labels from last batch')
332 | self.update_scores()
333 | self.update_pruning_labels()
334 | print('Nodes in pruning: %d' % (len(self.pruning)))
335 | print('Actual impurity for pruning is: %.2f' %
336 | (self.get_pruning_impurity(y)))
337 |
338 | # TODO(lishal): implement multiple selection methods
339 | selected_nodes = set()
340 | weights = self.get_node_leaf_counts(self.pruning)
341 | probs = 1 - self.get_class_probability_pruning()
342 | weights = weights * probs
343 | weights = weights / sum(weights)
344 | batch = []
345 |
346 | print('Sampling batch')
347 | while len(batch) < N:
348 | node = np.random.choice(list(self.pruning), p=weights)
349 | children = self.get_child_leaves(node)
350 | children = [
351 | c for c in children if c not in self.y_labels and c not in batch
352 | ]
353 | if len(children) > 0:
354 | selected_nodes.add(node)
355 | batch.append(np.random.choice(children))
356 | self.selected_nodes = selected_nodes
357 | return batch
358 |
359 | def to_dict(self):
360 | output = {}
361 | output['node_dict'] = self.node_dict
362 | return output
363 |
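The node statistics above rest on a simple confidence interval around each node's empirical label distribution. The following standalone numpy sketch mirrors the bound in get_node_upper_lower_bounds and the admissibility test in get_node_admissibility; the label counts and the beta slack factor are made-up values for illustration, not part of the module.

    # Standalone numpy sketch of the per-node confidence bounds (illustrative values).
    import numpy as np

    counts = np.array([8, 1, 1])            # observed labels of the leaves under one node
    n_v = counts.sum()
    p_v = counts / float(n_v)
    delta = 1. / n_v + np.sqrt(p_v * (1 - p_v) / float(n_v))
    p_lb = np.maximum(p_v - delta, 0.0)      # per-class lower bounds
    p_up = np.minimum(p_v + delta, 1.0)      # per-class upper bounds

    beta = 2.0                               # slack factor, analogous to the sampler's beta
    lowest_alt = np.array([min(1 - p_up[c] for c in range(len(p_v)) if c != i)
                           for i in range(len(p_v))])
    admissible = 1 - p_lb < beta * lowest_alt
    print(p_lb, p_up, admissible)
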
--------------------------------------------------------------------------------
/sampling_methods/informative_diverse.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Informative and diverse batch sampler that samples points with small margin
16 | while maintaining same distribution over clusters as entire training data.
17 |
18 | Batch is created by sorting datapoints by increasing margin and then growing
19 | the batch greedily. A point is added to the batch if the resulting batch still
20 | respects the constraint that the cluster distribution of the batch will
21 | match the cluster distribution of the entire training set.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 | from sklearn.cluster import MiniBatchKMeans
29 | import numpy as np
30 | from sampling_methods.sampling_def import SamplingMethod
31 |
32 |
33 | class InformativeClusterDiverseSampler(SamplingMethod):
34 | """Selects batch based on informative and diverse criteria.
35 |
36 | Returns the highest-uncertainty (lowest-margin) points while maintaining the
37 | same distribution over clusters as the entire dataset.
38 | """
39 |
40 | def __init__(self, X, y, seed):
41 | self.name = 'informative_and_diverse'
42 | self.X = X
43 | self.flat_X = self.flatten_X()
44 | # y only used for determining how many clusters there should be
45 | # probably not practical to assume we know the # of classes beforehand
46 | # should also probably scale with dimensionality of data
47 | self.y = y
48 | self.n_clusters = len(list(set(y)))
49 | self.cluster_model = MiniBatchKMeans(n_clusters=self.n_clusters)
50 | self.cluster_data()
51 |
52 | def cluster_data(self):
53 | # Probably okay to always use MiniBatchKMeans
54 | # Should standardize data before clustering
55 | # Can cluster on standardized data but train on raw features if desired
56 | self.cluster_model.fit(self.flat_X)
57 | unique, counts = np.unique(self.cluster_model.labels_, return_counts=True)
58 | self.cluster_prob = counts/sum(counts)
59 | self.cluster_labels = self.cluster_model.labels_
60 |
61 | def select_batch_(self, model, already_selected, N, **kwargs):
62 | """Returns a batch of size N using informative and diverse selection.
63 |
64 | Args:
65 | model: scikit learn model with decision_function implemented
66 | already_selected: index of datapoints already selected
67 | N: batch size
68 |
69 | Returns:
70 | indices of points selected using the informative and diverse criteria
71 | """
72 | # TODO(lishal): have MarginSampler and this share margin function
73 | try:
74 | distances = model.decision_function(self.X)
75 | except:
76 | distances = model.predict_proba(self.X)
77 | if len(distances.shape) < 2:
78 | min_margin = abs(distances)
79 | else:
80 | sort_distances = np.sort(distances, 1)[:, -2:]
81 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
82 | rank_ind = np.argsort(min_margin)
83 | rank_ind = [i for i in rank_ind if i not in already_selected]
84 | new_batch_cluster_counts = [0 for _ in range(self.n_clusters)]
85 | new_batch = []
86 | for i in rank_ind:
87 | if len(new_batch) == N:
88 | break
89 | label = self.cluster_labels[i]
90 | if new_batch_cluster_counts[label] / N < self.cluster_prob[label]:
91 | new_batch.append(i)
92 | new_batch_cluster_counts[label] += 1
93 | n_slot_remaining = N - len(new_batch)
94 | batch_filler = list(set(rank_ind) - set(already_selected) - set(new_batch))
95 | new_batch.extend(batch_filler[0:n_slot_remaining])
96 | return new_batch
97 |
98 | def to_dict(self):
99 | output = {}
100 | output['cluster_membership'] = self.cluster_labels
101 | return output
102 |
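For orientation, here is a minimal, hypothetical usage sketch of the sampler above. It assumes this repository is importable, and the toy data, the scikit-learn classifier, and the batch size are illustrative choices rather than anything prescribed by the module.

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sampling_methods.informative_diverse import InformativeClusterDiverseSampler

    rng = np.random.RandomState(0)
    X = rng.rand(200, 5)
    y = rng.randint(0, 3, size=200)
    y[:3] = [0, 1, 2]                         # make sure every class is represented

    labeled = list(range(20))                 # pretend these points are already labeled
    model = LogisticRegression().fit(X[labeled], y[labeled])

    sampler = InformativeClusterDiverseSampler(X, y, seed=0)
    batch = sampler.select_batch(model=model, already_selected=labeled, N=10)
    print(batch)
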
--------------------------------------------------------------------------------
/sampling_methods/kcenter_greedy.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Returns points that minimizes the maximum distance of any point to a center.
16 |
17 | Implements the k-Center-Greedy method in
18 | Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for
19 | Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017
20 |
21 | The distance metric defaults to l2 distance. Features used to calculate the
22 | distance are either the raw features or, if the model has a transform method,
23 | the output of model.transform(X).
24 |
25 | Can be extended to a robust k-centers algorithm that ignores a certain number of
26 | outlier datapoints. The resulting centers are the solution to a mixed integer program.
27 | """
28 |
29 | from __future__ import absolute_import
30 | from __future__ import division
31 | from __future__ import print_function
32 |
33 | import numpy as np
34 | from sklearn.metrics import pairwise_distances
35 | from sampling_methods.sampling_def import SamplingMethod
36 |
37 |
38 | class kCenterGreedy(SamplingMethod):
39 |
40 | def __init__(self, X, y, seed, metric='euclidean'):
41 | self.X = X
42 | self.y = y
43 | self.flat_X = self.flatten_X()
44 | self.name = 'kcenter'
45 | self.features = self.flat_X
46 | self.metric = metric
47 | self.min_distances = None
48 | self.n_obs = self.X.shape[0]
49 | self.already_selected = []
50 |
51 | def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
52 | """Update min distances given cluster centers.
53 |
54 | Args:
55 | cluster_centers: indices of cluster centers
56 | only_new: only calculate distance for newly selected points and update
57 | min_distances.
58 | reset_dist: whether to reset min_distances.
59 | """
60 |
61 | if reset_dist:
62 | self.min_distances = None
63 | if only_new:
64 | cluster_centers = [d for d in cluster_centers
65 | if d not in self.already_selected]
66 | if cluster_centers:
67 | # Update min_distances for all examples given new cluster center.
68 | x = self.features[cluster_centers]
69 | dist = pairwise_distances(self.features, x, metric=self.metric)
70 |
71 | if self.min_distances is None:
72 | self.min_distances = np.min(dist, axis=1).reshape(-1,1)
73 | else:
74 | self.min_distances = np.minimum(self.min_distances, np.min(dist, axis=1).reshape(-1, 1))
75 |
76 | def select_batch_(self, model, already_selected, N, **kwargs):
77 | """
78 | Diversity promoting active learning method that greedily forms a batch
79 | to minimize the maximum distance to a cluster center among all unlabeled
80 | datapoints.
81 |
82 | Args:
83 | model: model with scikit-like API with decision_function implemented
84 | already_selected: index of datapoints already selected
85 | N: batch size
86 |
87 | Returns:
88 | indices of points selected to minimize distance to cluster centers
89 | """
90 |
91 | try:
92 | # Assumes that the transform function takes in original data and not
93 | # flattened data.
94 | print('Getting transformed features...')
95 | self.features = model.transform(self.X)
96 | print('Calculating distances...')
97 | self.update_distances(already_selected, only_new=False, reset_dist=True)
98 | except:
99 | print('Using flat_X as features.')
100 | self.update_distances(already_selected, only_new=True, reset_dist=False)
101 |
102 | new_batch = []
103 |
104 | for _ in range(N):
105 | if self.min_distances is None:
106 | # Initialize centers with a randomly selected datapoint
107 | ind = np.random.choice(np.arange(self.n_obs))
108 | else:
109 | ind = np.argmax(self.min_distances)
110 | # New examples should not be in already selected since those points
111 | # should have min_distance of zero to a cluster center.
112 | assert ind not in already_selected
113 |
114 | self.update_distances([ind], only_new=True, reset_dist=False)
115 | new_batch.append(ind)
116 | print('Maximum distance from cluster centers is %0.2f'
117 | % max(self.min_distances))
118 |
119 |
120 | self.already_selected = already_selected
121 |
122 | return new_batch
123 |
124 |
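The core of the method is a farthest-first traversal: repeatedly pick the point whose distance to its nearest chosen center is largest. The self-contained sketch below illustrates that rule on toy data, independently of the class above; the data size and number of centers are arbitrary choices.

    import numpy as np
    from sklearn.metrics import pairwise_distances

    X = np.random.RandomState(0).rand(100, 2)
    centers = [0]                                    # seed with an arbitrary point
    min_dist = pairwise_distances(X, X[centers]).min(axis=1)

    for _ in range(4):                               # greedily add 4 more centers
        ind = int(np.argmax(min_dist))               # farthest point from current centers
        centers.append(ind)
        min_dist = np.minimum(min_dist, pairwise_distances(X, X[[ind]]).ravel())

    print(centers)
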
--------------------------------------------------------------------------------
/sampling_methods/margin_AL.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Margin based AL method.
16 |
17 | Samples in batches based on margin scores.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import numpy as np
25 | from sampling_methods.sampling_def import SamplingMethod
26 |
27 |
28 | class MarginAL(SamplingMethod):
29 | def __init__(self, X, y, seed):
30 | self.X = X
31 | self.y = y
32 | self.name = 'margin'
33 |
34 | def select_batch_(self, model, already_selected, N, **kwargs):
35 | """Returns batch of datapoints with smallest margin/highest uncertainty.
36 |
37 | For binary classification, can just take the absolute distance to decision
38 | boundary for each point.
39 | For multiclass classification, must consider the margin between the distances
40 | for the top two most likely classes.
41 |
42 | Args:
43 | model: scikit learn model with decision_function implemented
44 | already_selected: index of datapoints already selected
45 | N: batch size
46 |
47 | Returns:
48 | indices of points selected to add using margin active learner
49 | """
50 |
51 | try:
52 | distances = model.decision_function(self.X)
53 | except:
54 | distances = model.predict_proba(self.X)
55 | if len(distances.shape) < 2:
56 | min_margin = abs(distances)
57 | else:
58 | sort_distances = np.sort(distances, 1)[:, -2:]
59 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
60 | rank_ind = np.argsort(min_margin)
61 | rank_ind = [i for i in rank_ind if i not in already_selected]
62 | active_samples = rank_ind[0:N]
63 | return active_samples
64 |
65 |
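The margin score itself is just the gap between the two largest class scores per point. A tiny numpy sketch with a made-up probability matrix illustrates the multiclass branch above.

    import numpy as np

    probs = np.array([[0.10, 0.50, 0.40],
                      [0.05, 0.90, 0.05],
                      [0.34, 0.33, 0.33]])
    top_two = np.sort(probs, axis=1)[:, -2:]
    margin = top_two[:, 1] - top_two[:, 0]     # small margin == high uncertainty
    print(np.argsort(margin))                  # most uncertain points come first
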
--------------------------------------------------------------------------------
/sampling_methods/mixture_of_samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Mixture of base sampling strategies
16 |
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import copy
24 |
25 | from sampling_methods.sampling_def import SamplingMethod
26 | from sampling_methods.constants import AL_MAPPING, get_base_AL_mapping
27 |
28 | get_base_AL_mapping()
29 |
30 |
31 | class MixtureOfSamplers(SamplingMethod):
32 | """Samples according to mixture of base sampling methods.
33 |
34 | If duplicate points are selected by the mixed strategies when forming the batch
35 | then the remaining slots are divided according to mixture weights and
36 | another partial batch is requested until the batch is full.
37 | """
38 | def __init__(self,
39 | X,
40 | y,
41 | seed,
42 | mixture={'methods': ('margin', 'uniform'),
43 | 'weights': (0.5, 0.5)},
44 | samplers=None):
45 | self.X = X
46 | self.y = y
47 | self.name = 'mixture_of_samplers'
48 | self.sampling_methods = mixture['methods']
49 | self.sampling_weights = dict(zip(mixture['methods'], mixture['weights']))
50 | self.seed = seed
51 | # A list of initialized samplers is allowed as an input because AL methods
52 | # that search over different mixtures may want the mixtures to share base
53 | # samplers, so that initialization is only performed once for computation-
54 | # intensive methods like HierarchicalClusteringAL and state is shared
55 | # between mixtures.
56 | # If initialized samplers are not provided, initialize them ourselves.
57 | if samplers is None:
58 | self.samplers = {}
59 | self.initialize(self.sampling_methods)
60 | else:
61 | self.samplers = samplers
62 | self.history = []
63 |
64 | def initialize(self, samplers):
65 | self.samplers = {}
66 | for s in samplers:
67 | self.samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed)
68 |
69 | def select_batch_(self, already_selected, N, **kwargs):
70 | """Returns batch of datapoints selected according to mixture weights.
71 |
72 | Args:
73 | already_selected: indices of datapoints already selected
74 | N: batch size
75 |
76 | Returns:
77 | indices of points selected according to the mixture weights
78 | """
79 | kwargs['already_selected'] = copy.copy(already_selected)
80 | inds = set()
81 | self.selected_by_sampler = {}
82 | for s in self.sampling_methods:
83 | self.selected_by_sampler[s] = []
84 | effective_N = 0
85 | while len(inds) < N:
86 | effective_N += N - len(inds)
87 | for s in self.sampling_methods:
88 | if len(inds) < N:
89 | batch_size = min(max(int(self.sampling_weights[s] * effective_N), 1), N)
90 | sampler = self.samplers[s]
91 | kwargs['N'] = batch_size
92 | s_inds = sampler.select_batch(**kwargs)
93 | for ind in s_inds:
94 | if ind not in self.selected_by_sampler[s]:
95 | self.selected_by_sampler[s].append(ind)
96 | s_inds = [d for d in s_inds if d not in inds]
97 | s_inds = s_inds[0 : min(len(s_inds), N-len(inds))]
98 | inds.update(s_inds)
99 | self.history.append(copy.deepcopy(self.selected_by_sampler))
100 | return list(inds)
101 |
102 | def to_dict(self):
103 | output = {}
104 | output['history'] = self.history
105 | output['samplers'] = self.sampling_methods
106 | output['mixture_weights'] = self.sampling_weights
107 | for s in self.samplers:
108 | s_output = self.samplers[s].to_dict()
109 | output[s] = s_output
110 | return output
111 |
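A hypothetical end-to-end call of the mixture sampler, assuming this repository is importable; the toy data, the logistic-regression model, and the 0.7/0.3 split are illustrative values only.

    import numpy as np
    from sklearn.linear_model import LogisticRegression
    from sampling_methods.mixture_of_samplers import MixtureOfSamplers

    rng = np.random.RandomState(0)
    X = rng.rand(100, 4)
    y = rng.randint(0, 2, size=100)
    y[:2] = [0, 1]                            # make sure both classes are represented
    labeled = list(range(10))
    model = LogisticRegression().fit(X[labeled], y[labeled])

    sampler = MixtureOfSamplers(
        X, y, seed=0,
        mixture={'methods': ('margin', 'uniform'), 'weights': (0.7, 0.3)})
    batch = sampler.select_batch(model=model, already_selected=labeled, N=10)
    print(sorted(batch))
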
--------------------------------------------------------------------------------
/sampling_methods/represent_cluster_centers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Another informative and diverse sampler that mirrors the algorithm described
16 | in Xu et al., Representative Sampling for Text Classification Using
17 | Support Vector Machines, 2003
18 |
19 | Batch is created by clustering points within the margin of the classifier and
20 | choosing points closest to the k centroids.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from sklearn.cluster import MiniBatchKMeans
28 | import numpy as np
29 | from sampling_methods.sampling_def import SamplingMethod
30 |
31 |
32 | class RepresentativeClusterMeanSampling(SamplingMethod):
33 | """Selects batch based on informative and diverse criteria.
34 |
35 | Returns points within the margin of the classifier that are closest to the
36 | k-means centers of those points.
37 | """
38 |
39 | def __init__(self, X, y, seed):
40 | self.name = 'cluster_mean'
41 | self.X = X
42 | self.flat_X = self.flatten_X()
43 | self.y = y
44 | self.seed = seed
45 |
46 | def select_batch_(self, model, N, already_selected, **kwargs):
47 | # Probably okay to always use MiniBatchKMeans
48 | # Should standardize data before clustering
49 | # Can cluster on standardized data but train on raw features if desired
50 | try:
51 | distances = model.decision_function(self.X)
52 | except:
53 | distances = model.predict_proba(self.X)
54 | if len(distances.shape) < 2:
55 | min_margin = abs(distances)
56 | else:
57 | sort_distances = np.sort(distances, 1)[:, -2:]
58 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
59 | rank_ind = np.argsort(min_margin)
60 | rank_ind = [i for i in rank_ind if i not in already_selected]
61 |
62 | distances = abs(model.decision_function(self.X))
63 | min_margin_by_class = np.min(abs(distances[already_selected]),axis=0)
64 | unlabeled_in_margin = np.array([i for i in range(len(self.y))
65 | if i not in already_selected and
66 | any(distances[i] < min_margin_by_class)])
--------------------------------------------------------------------------------
/sampling_methods/sampling_def.py:
--------------------------------------------------------------------------------
41 | if len(shape) > 2:
42 | flat_X = np.reshape(self.X, (shape[0],np.product(shape[1:])))
43 | return flat_X
44 |
45 |
46 | @abc.abstractmethod
47 | def select_batch_(self):
48 | return
49 |
50 | def select_batch(self, **kwargs):
51 | return self.select_batch_(**kwargs)
52 |
53 | def to_dict(self):
54 | return None
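The reshape above, used by the flatten_X helper that the samplers call, simply collapses any trailing axes so that image-shaped inputs become a 2-D design matrix. A one-off sketch with an assumed 8x8x3 input shape:

    import numpy as np

    X = np.zeros((32, 8, 8, 3))                # e.g. 32 image-shaped samples
    flat = np.reshape(X, (X.shape[0], np.product(X.shape[1:])))
    print(flat.shape)                          # (32, 192)
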
--------------------------------------------------------------------------------
/sampling_methods/simulate_batch.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ Select a new batch based on results of simulated trajectories."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import copy
22 | import math
23 |
24 | import numpy as np
25 |
26 | from sampling_methods.wrapper_sampler_def import AL_MAPPING
27 | from sampling_methods.wrapper_sampler_def import WrapperSamplingMethod
28 |
29 |
30 | class SimulateBatchSampler(WrapperSamplingMethod):
31 | """Creates batch based on trajectories simulated using smaller batch sizes.
32 |
33 | Currently supported use case: simulate batches smaller than the requested
34 | batch size in order to emulate which points would be selected in a
35 | smaller-batch setting. This method can do better than selecting a full
36 | batch outright if smaller batches perform better and the simulations are
37 | informative enough and not hurt too much by labeling noise.
38 | """
39 |
40 | def __init__(self,
41 | X,
42 | y,
43 | seed,
44 | samplers=[{'methods': ('margin', 'uniform'), 'weights': (1, 0)}],
45 | n_sims=10,
46 | train_per_sim=10,
47 | return_type='best_sim'):
48 | """ Initialize sampler with options.
49 |
50 | Args:
51 | X: training data
52 | y: labels may be used by base sampling methods
53 | seed: seed for np.random
54 | samplers: list of dicts with two fields
55 | 'methods': list of named base samplers
56 | 'weights': fraction of the batch to allocate to each sampler
57 | n_sims: number of total trajectories to simulate
58 | train_per_sim: number of minibatches to split the batch into
59 | return_type: two return types supported right now
60 | best_sim: return points selected by the best trajectory
61 | frequency: returns points selected the most over all trajectories
62 | """
63 | self.name = 'simulate_batch'
64 | self.X = X
65 | self.y = y
66 | self.seed = seed
67 | self.n_sims = n_sims
68 | self.train_per_sim = train_per_sim
69 | self.return_type = return_type
70 | self.samplers_list = samplers
71 | self.initialize_samplers(self.samplers_list)
72 | self.trace = []
73 | self.selected = []
74 | np.random.seed(seed)
75 |
76 | def simulate_batch(self, sampler, N, already_selected, y, model, X_test,
77 | y_test, **kwargs):
78 | """Simulates smaller batches by using hallucinated y to select next batch.
79 |
80 | Assumes that select_batch is only dependent on already_selected and not on
81 | any other states internal to the sampler. i.e. this would not work with
82 | BanditDiscreteSampler but will work with margin, hierarchical, and uniform.
83 |
84 | Args:
85 | sampler: sampler instance (e.g. a mixture of samplers built from a dict
86 | with 'methods' and 'weights' fields) used to select each simulated
87 | minibatch
88 | N: batch size
89 | already_selected: indices already labeled
90 | y: y to use for training
91 | model: model to use for margin calc
92 | X_test: validation data
93 | y_test: validation labels
94 |
95 | Returns:
96 | - mean accuracy
97 | - indices selected by best hallucinated trajectory
98 | - best accuracy achieved by one of the trajectories
99 | """
100 | minibatch = max(int(math.ceil(N / self.train_per_sim)), 1)
101 | results = []
102 | best_acc = 0
103 | best_inds = []
104 | self.selected = []
105 | n_minibatch = int(N/minibatch) + (N % minibatch > 0)
106 |
107 | for _ in range(self.n_sims):
108 | inds = []
109 | hallucinated_y = []
110 |
111 | # Copy these objects to make sure they are not modified while simulating
112 | # trajectories as they are used later by the main run_experiment script.
113 | kwargs['already_selected'] = copy.copy(already_selected)
114 | kwargs['y'] = copy.copy(y)
115 | # Assumes that the model has already been fit using all labeled data so
116 | # the probabilities can be used immediately to hallucinate labels
117 | kwargs['model'] = copy.deepcopy(model)
118 |
119 | for _ in range(n_minibatch):
120 | batch_size = min(minibatch, N-len(inds))
121 | if batch_size > 0:
122 | kwargs['N'] = batch_size
123 | new_inds = sampler.select_batch(**kwargs)
124 | inds.extend(new_inds)
125 |
126 | # All models need to have predict_proba method
127 | probs = kwargs['model'].predict_proba(self.X[new_inds])
128 | # Hallucinate labels for the selected datapoints by sampling from
129 | # the model's class probabilities
130 | try:
131 | classes = kwargs['model'].best_estimator_.classes_
132 | except:
133 | classes = kwargs['model'].classes_
134 | new_y = ([
135 | np.random.choice(classes, p=probs[i, :])
136 | for i in range(batch_size)
137 | ])
138 | hallucinated_y.extend(new_y)
139 | # already_selected is not saved back here; if it were, only the copy
140 | # passed to fit should be sorted, while the original ordering of indices
141 | # in already_selected is preserved
142 | kwargs['already_selected'] = sorted(kwargs['already_selected']
143 | + new_inds)
144 | kwargs['y'][new_inds] = new_y
145 | kwargs['model'].fit(self.X[kwargs['already_selected']],
146 | kwargs['y'][kwargs['already_selected']])
147 | acc_hallucinated = kwargs['model'].score(X_test, y_test)
148 | if acc_hallucinated > best_acc:
149 | best_acc = acc_hallucinated
150 | best_inds = inds
151 | kwargs['model'].fit(self.X[kwargs['already_selected']],
152 | y[kwargs['already_selected']])
153 | # Useful to know how accuracy compares for model trained on hallucinated
154 | # labels vs trained on true labels. But can remove this train to speed
155 | # up simulations. Won't speed up significantly since many more models
156 | # are being trained inside the loop above.
157 | acc_true = kwargs['model'].score(X_test, y_test)
158 | results.append([acc_hallucinated, acc_true])
159 | print('Hallucinated acc: %.3f, Actual Acc: %.3f' % (acc_hallucinated,
160 | acc_true))
161 |
162 | # Save trajectory for reference
163 | t = {}
164 | t['arm'] = sampler
165 | t['data_size'] = len(kwargs['already_selected'])
166 | t['inds'] = inds
167 | t['y_hal'] = hallucinated_y
168 | t['acc_hal'] = acc_hallucinated
169 | t['acc_true'] = acc_true
170 | self.trace.append(t)
171 | self.selected.extend(inds)
172 | # Delete created copies
173 | del kwargs['model']
174 | del kwargs['already_selected']
175 | results = np.array(results)
176 | return np.mean(results, axis=0), best_inds, best_acc
177 |
178 | def sampler_select_batch(self, sampler, N, already_selected, y, model, X_test, y_test, **kwargs):
179 | """Calculate the performance of the model if the batch had been selected using the base method without simulation.
180 |
181 | Args:
182 | sampler: sampler instance (e.g. a mixture of samplers built from a dict
183 | with 'methods' and 'weights' fields) used to select the batch without
184 | simulation
185 | N: batch size
186 | already_selected: indices already selected
187 | y: labels to use for training
188 | model: model to use for training
189 | X_test, y_test: validation set
190 |
191 | Returns:
192 | - indices selected by base method
193 | - validation accuracy of model trained on new batch
194 | """
195 | m = copy.deepcopy(model)
196 | kwargs['y'] = y
197 | kwargs['model'] = m
198 | kwargs['already_selected'] = copy.copy(already_selected)
199 | inds = []
200 | kwargs['N'] = N
201 | inds.extend(sampler.select_batch(**kwargs))
202 | kwargs['already_selected'] = sorted(kwargs['already_selected'] + inds)
203 |
204 | m.fit(self.X[kwargs['already_selected']], y[kwargs['already_selected']])
205 | acc = m.score(X_test, y_test)
206 | del m
207 | del kwargs['already_selected']
208 | return inds, acc
209 |
210 | def select_batch_(self, N, already_selected, y, model,
211 | X_test, y_test, **kwargs):
212 | """ Returns a batch of size N selected by using the best sampler in simulation
213 |
214 | Args:
215 | samplers: list of sampling methods represented by dict with two fields
216 | 'methods': list of named base samplers
217 | 'weights': percentage of batch to allocate to each sampler
218 | N: batch size
219 | already_selected: indices of datapoints already labeled
220 | y: actual labels, used to compare simulation with actual
221 | model: training model to use to evaluate different samplers. Model must
222 | have a predict_proba method with same signature as that in sklearn
223 | n_sims: the number of simulations to perform for each sampler
224 | minibatch: batch size to use for simulation
225 | """
226 |
227 | results = []
228 |
229 | # THE INPUTS CANNOT BE MODIFIED SO WE MAKE COPIES FOR THE CHECK LATER
230 | # Should check model but kernel_svm does not have coef_ so need better
231 | # handling here
232 | copy_selected = copy.copy(already_selected)
233 | copy_y = copy.copy(y)
234 |
235 | for s in self.samplers:
236 | sim_results, sim_inds, sim_acc = self.simulate_batch(
237 | s, N, already_selected, y, model, X_test, y_test, **kwargs)
238 | real_inds, acc = self.sampler_select_batch(
239 | s, N, already_selected, y, model, X_test, y_test, **kwargs)
240 | print('Best simulated acc: %.3f, Actual acc: %.3f' % (sim_acc, acc))
241 | results.append([sim_results, sim_inds, real_inds, acc])
242 | best_s = np.argmax([r[0][0] for r in results])
243 |
244 | # Make sure that model object fed in did not change during simulations
245 | assert all(y == copy_y)
246 | assert all([copy_selected[i] == already_selected[i]
247 | for i in range(len(already_selected))])
248 |
249 | # Return indices based on return type specified
250 | if self.return_type == 'best_sim':
251 | return results[best_s][1]
252 | elif self.return_type == 'frequency':
253 | unique, counts = np.unique(self.selected, return_counts=True)
254 | argcount = np.argsort(-counts)
255 | return list(unique[argcount[0:N]])
256 | return results[best_s][2]
257 |
258 | def to_dict(self):
259 | output = {}
260 | output['simulated_trajectories'] = self.trace
261 | return output
262 |
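The key step in each simulated trajectory is hallucinating labels for newly selected points by sampling from the current model's predicted class probabilities. The self-contained sketch below illustrates just that step; the toy data, model, and chosen indices are illustrative only.

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    rng = np.random.RandomState(0)
    X = rng.rand(50, 3)
    y = rng.randint(0, 2, size=50)
    y[:2] = [0, 1]                                   # make sure both classes appear
    model = LogisticRegression().fit(X, y)

    new_inds = [0, 1, 2]                             # points "selected" this minibatch
    probs = model.predict_proba(X[new_inds])
    hallucinated = [rng.choice(model.classes_, p=p) for p in probs]
    print(hallucinated)
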
--------------------------------------------------------------------------------
/sampling_methods/uniform_sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Uniform sampling method.
16 |
17 | Samples in batches.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import numpy as np
25 |
26 | from sampling_methods.sampling_def import SamplingMethod
27 |
28 |
29 | class UniformSampling(SamplingMethod):
30 |
31 | def __init__(self, X, y, seed):
32 | self.X = X
33 | self.y = y
34 | self.name = 'uniform'
35 | np.random.seed(seed)
36 |
37 | def select_batch_(self, already_selected, N, **kwargs):
38 | """Returns batch of randomly sampled datapoints.
39 |
40 | Assumes that data has already been shuffled.
41 |
42 | Args:
43 | already_selected: index of datapoints already selected
44 | N: batch size
45 |
46 | Returns:
47 | indices of points selected to label
48 | """
49 |
50 | # This is uniform given the remaining pool but biased wrt the entire pool.
51 | sample = [i for i in range(self.X.shape[0]) if i not in already_selected]
52 | return sample[0:N]
53 |
--------------------------------------------------------------------------------
/sampling_methods/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/sampling_methods/utils/tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Node and Tree class to support hierarchical clustering AL method.
16 |
17 | Assumed to be binary tree.
18 |
19 | Node class is used to represent each node in a hierarchical clustering.
20 | Each node has certain properties that are used in the AL method.
21 |
22 | Tree class is used to traverse a hierarchical clustering.
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import copy
30 |
31 |
32 | class Node(object):
33 | """Node class for hierarchical clustering.
34 |
35 | Initialized with name and left right children.
36 | """
37 |
38 | def __init__(self, name, left=None, right=None):
39 | self.name = name
40 | self.left = left
41 | self.right = right
42 | self.is_leaf = left is None and right is None
43 | self.parent = None
44 | # Fields for hierarchical clustering AL
45 | self.score = 1.0
46 | self.split = False
47 | self.best_label = None
48 | self.weight = None
49 |
50 | def set_parent(self, parent):
51 | self.parent = parent
52 |
53 |
54 | class Tree(object):
55 | """Tree object for traversing a binary tree.
56 |
57 | Most methods apply to trees in general with the exception of get_pruning
58 | which is specific to the hierarchical clustering AL method.
59 | """
60 |
61 | def __init__(self, root, node_dict):
62 | """Initializes tree and creates all nodes in node_dict.
63 |
64 | Args:
65 | root: id of the root node
66 | node_dict: dictionary with node_id as keys and entries indicating
67 | left and right child of node respectively.
68 | """
69 | self.node_dict = node_dict
70 | self.root = self.make_tree(root)
71 | self.nodes = {}
72 | self.leaves_mapping = {}
73 | self.fill_parents()
74 | self.n_leaves = None
75 |
76 | def print_tree(self, node, max_depth):
77 | """Helper function to print out tree for debugging."""
78 | node_list = [node]
79 | output = ""
80 | level = 0
81 | while level < max_depth and len(node_list):
82 | children = set()
83 | for n in node_list:
84 | node = self.get_node(n)
85 | output += ("\t"*level+"node %d: score %.2f, weight %.2f" %
86 | (node.name, node.score, node.weight)+"\n")
87 | if node.left:
88 | children.add(node.left.name)
89 | if node.right:
90 | children.add(node.right.name)
91 | level += 1
92 | node_list = children
93 | return print(output)
94 |
95 | def make_tree(self, node_id):
96 | if node_id is not None:
97 | return Node(node_id,
98 | self.make_tree(self.node_dict[node_id][0]),
99 | self.make_tree(self.node_dict[node_id][1]))
100 |
101 | def fill_parents(self):
102 | # Setting parent and storing nodes in dict for fast access
103 | def rec(pointer, parent):
104 | if pointer is not None:
105 | self.nodes[pointer.name] = pointer
106 | pointer.set_parent(parent)
107 | rec(pointer.left, pointer)
108 | rec(pointer.right, pointer)
109 | rec(self.root, None)
110 |
111 | def get_node(self, node_id):
112 | return self.nodes[node_id]
113 |
114 | def get_ancestor(self, node):
115 | ancestors = []
116 | if isinstance(node, int):
117 | node = self.get_node(node)
118 | while node.name != self.root.name:
119 | node = node.parent
120 | ancestors.append(node.name)
121 | return ancestors
122 |
123 | def fill_weights(self):
124 | for v in self.node_dict:
125 | node = self.get_node(v)
126 | node.weight = len(self.leaves_mapping[v]) / (1.0 * self.n_leaves)
127 |
128 | def create_child_leaves_mapping(self, leaves):
129 | """DP for creating child leaves mapping.
130 |
131 | Storing in dict to save recompute.
132 | """
133 | self.n_leaves = len(leaves)
134 | for v in leaves:
135 | self.leaves_mapping[v] = [v]
136 | node_list = set([self.get_node(v).parent for v in leaves])
137 | while node_list:
138 | to_fill = copy.copy(node_list)
139 | for v in node_list:
140 | if (v.left.name in self.leaves_mapping
141 | and v.right.name in self.leaves_mapping):
142 | to_fill.remove(v)
143 | self.leaves_mapping[v.name] = (self.leaves_mapping[v.left.name] +
144 | self.leaves_mapping[v.right.name])
145 | if v.parent is not None:
146 | to_fill.add(v.parent)
147 | node_list = to_fill
148 | self.fill_weights()
149 |
150 | def get_child_leaves(self, node):
151 | return self.leaves_mapping[node]
152 |
153 | def get_pruning(self, node):
154 | if node.split:
155 | return self.get_pruning(node.left) + self.get_pruning(node.right)
156 | else:
157 | return [node.name]
158 |
159 |
--------------------------------------------------------------------------------
/sampling_methods/utils/tree_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sampling_methods.utils.tree."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import unittest
22 | from sampling_methods.utils import tree
23 |
24 |
25 | class TreeTest(unittest.TestCase):
26 |
27 | def setUp(self):
28 | node_dict = {
29 | 1: (2, 3),
30 | 2: (4, 5),
31 | 3: (6, 7),
32 | 4: [None, None],
33 | 5: [None, None],
34 | 6: [None, None],
35 | 7: [None, None]
36 | }
37 | self.tree = tree.Tree(1, node_dict)
38 | self.tree.create_child_leaves_mapping([4, 5, 6, 7])
39 | node = self.tree.get_node(1)
40 | node.split = True
41 | node = self.tree.get_node(2)
42 | node.split = True
43 |
44 | def assertNode(self, node, name, left, right):
45 | self.assertEqual(node.name, name)
46 | self.assertEqual(node.left.name, left)
47 | self.assertEqual(node.right.name, right)
48 |
49 | def testTreeRootSetCorrectly(self):
50 | self.assertNode(self.tree.root, 1, 2, 3)
51 |
52 | def testGetNode(self):
53 | node = self.tree.get_node(1)
54 | assert isinstance(node, tree.Node)
55 | self.assertEqual(node.name, 1)
56 |
57 | def testFillParent(self):
58 | node = self.tree.get_node(3)
59 | self.assertEqual(node.parent.name, 1)
60 |
61 | def testGetAncestors(self):
62 | ancestors = self.tree.get_ancestor(5)
63 | self.assertTrue(all([a in ancestors for a in [1, 2]]))
64 |
65 | def testChildLeaves(self):
66 | leaves = self.tree.get_child_leaves(3)
67 | self.assertTrue(all([c in leaves for c in [6, 7]]))
68 |
69 | def testFillWeights(self):
70 | node = self.tree.get_node(3)
71 | self.assertEqual(node.weight, 0.5)
72 |
73 | def testGetPruning(self):
74 | node = self.tree.get_node(1)
75 | pruning = self.tree.get_pruning(node)
76 | self.assertTrue(all([n in pruning for n in [3, 4, 5]]))
77 |
78 | if __name__ == '__main__':
79 | unittest.main()
80 |
--------------------------------------------------------------------------------
/sampling_methods/wrapper_sampler_def.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Abstract class for wrapper sampling methods that call base sampling methods.
16 |
17 | Provides interface to sampling methods that allow same signature
18 | for select_batch. Each subclass implements select_batch_ with the desired
19 | signature for readability.
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import abc
27 |
28 | from sampling_methods.constants import AL_MAPPING
29 | from sampling_methods.constants import get_all_possible_arms
30 | from sampling_methods.sampling_def import SamplingMethod
31 |
32 | get_all_possible_arms()
33 |
34 |
35 | class WrapperSamplingMethod(SamplingMethod):
36 | __metaclass__ = abc.ABCMeta
37 |
38 | def initialize_samplers(self, mixtures):
39 | methods = []
40 | for m in mixtures:
41 | methods += m['methods']
42 | methods = set(methods)
43 | self.base_samplers = {}
44 | for s in methods:
45 | self.base_samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed)
46 | self.samplers = []
47 | for m in mixtures:
48 | self.samplers.append(
49 | AL_MAPPING['mixture_of_samplers'](self.X, self.y, self.seed, m,
50 | self.base_samplers))
51 |
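For reference, the mixtures argument consumed by initialize_samplers above is a list of dicts, each naming base methods and their batch-share weights. The entries below are illustrative values, not defaults shipped with the code.

    mixtures = [
        {'methods': ('margin', 'uniform'), 'weights': (0.5, 0.5)},
        {'methods': ('margin', 'informative_diverse'), 'weights': (0.8, 0.2)},
    ]
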
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/utils/allconv.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implements allconv model in keras using tensorflow backend."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import copy
21 |
22 | import keras
23 | import keras.backend as K
24 | from keras.layers import Activation
25 | from keras.layers import Conv2D
26 | from keras.layers import Dropout
27 | from keras.layers import GlobalAveragePooling2D
28 | from keras.models import Sequential
29 |
30 | import numpy as np
31 | import tensorflow as tf
32 |
33 |
34 | class AllConv(object):
35 | """allconv network that matches sklearn api."""
36 |
37 | def __init__(self,
38 | random_state=1,
39 | epochs=50,
40 | batch_size=32,
41 | solver='rmsprop',
42 | learning_rate=0.001,
43 | lr_decay=0.):
44 | # params
45 | self.solver = solver
46 | self.epochs = epochs
47 | self.batch_size = batch_size
48 | self.learning_rate = learning_rate
49 | self.lr_decay = lr_decay
50 | # data
51 | self.encode_map = None
52 | self.decode_map = None
53 | self.model = None
54 | self.random_state = random_state
55 | self.n_classes = None
56 |
57 | def build_model(self, X):
58 | # assumes that data axis order is same as the backend
59 | input_shape = X.shape[1:]
60 | np.random.seed(self.random_state)
61 | tf.set_random_seed(self.random_state)
62 |
63 | model = Sequential()
64 | model.add(Conv2D(96, (3, 3), padding='same',
65 | input_shape=input_shape, name='conv1'))
66 | model.add(Activation('relu'))
67 | model.add(Conv2D(96, (3, 3), name='conv2', padding='same'))
68 | model.add(Activation('relu'))
69 | model.add(Conv2D(96, (3, 3), strides=(2, 2), padding='same', name='conv3'))
70 | model.add(Activation('relu'))
71 | model.add(Dropout(0.5))
72 |
73 | model.add(Conv2D(192, (3, 3), name='conv4', padding='same'))
74 | model.add(Activation('relu'))
75 | model.add(Conv2D(192, (3, 3), name='conv5', padding='same'))
76 | model.add(Activation('relu'))
77 | model.add(Conv2D(192, (3, 3), strides=(2, 2), name='conv6', padding='same'))
78 | model.add(Activation('relu'))
79 | model.add(Dropout(0.5))
80 |
81 | model.add(Conv2D(192, (3, 3), name='conv7', padding='same'))
82 | model.add(Activation('relu'))
83 | model.add(Conv2D(192, (1, 1), name='conv8', padding='valid'))
84 | model.add(Activation('relu'))
85 | model.add(Conv2D(10, (1, 1), name='conv9', padding='valid'))
86 |
87 | model.add(GlobalAveragePooling2D())
88 | model.add(Activation('softmax', name='activation_top'))
89 | model.summary()
90 |
91 | try:
92 | optimizer = getattr(keras.optimizers, self.solver)
93 | except:
94 | raise NotImplementedError('optimizer not implemented in keras')
95 | # All optimizers with the exception of nadam take decay as named arg
96 | try:
97 | opt = optimizer(lr=self.learning_rate, decay=self.lr_decay)
98 | except:
99 | opt = optimizer(lr=self.learning_rate, schedule_decay=self.lr_decay)
100 |
101 | model.compile(loss='categorical_crossentropy',
102 | optimizer=opt,
103 | metrics=['accuracy'])
104 | # Save initial weights so that model can be retrained with same
105 | # initialization
106 | self.initial_weights = copy.deepcopy(model.get_weights())
107 |
108 | self.model = model
109 |
110 | def create_y_mat(self, y):
111 | y_encode = self.encode_y(y)
112 | y_encode = np.reshape(y_encode, (len(y_encode), 1))
113 | y_mat = keras.utils.to_categorical(y_encode, self.n_classes)
114 | return y_mat
115 |
116 | # Add handling for classes that do not start counting from 0
117 | def encode_y(self, y):
118 | if self.encode_map is None:
119 | self.classes_ = sorted(list(set(y)))
120 | self.n_classes = len(self.classes_)
121 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_))))
122 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_))
123 | mapper = lambda x: self.encode_map[x]
124 | transformed_y = np.array(list(map(mapper, y)))
125 | return transformed_y
126 |
127 | def decode_y(self, y):
128 | mapper = lambda x: self.decode_map[x]
129 | transformed_y = np.array(list(map(mapper, y)))
130 | return transformed_y
131 |
132 | def fit(self, X_train, y_train, sample_weight=None):
133 | y_mat = self.create_y_mat(y_train)
134 |
135 | if self.model is None:
136 | self.build_model(X_train)
137 |
138 | # We don't want incremental fit so reset learning rate and weights
139 | K.set_value(self.model.optimizer.lr, self.learning_rate)
140 | self.model.set_weights(self.initial_weights)
141 | self.model.fit(
142 | X_train,
143 | y_mat,
144 | batch_size=self.batch_size,
145 | epochs=self.epochs,
146 | shuffle=True,
147 | sample_weight=sample_weight,
148 | verbose=0)
149 |
150 | def predict(self, X_val):
151 | predicted = self.model.predict(X_val)
152 | return predicted
153 |
154 | def score(self, X_val, val_y):
155 | y_mat = self.create_y_mat(val_y)
156 | val_acc = self.model.evaluate(X_val, y_mat)[1]
157 | return val_acc
158 |
159 | def decision_function(self, X):
160 | return self.predict(X)
161 |
162 | def transform(self, X):
163 | model = self.model
164 | inp = [model.input]
165 | activations = []
166 |
167 | # Get activations of the last conv layer.
168 | output = [layer.output for layer in model.layers if
169 | layer.name == 'conv9'][0]
170 | func = K.function(inp + [K.learning_phase()], [output])
171 | for i in range(int(X.shape[0]/self.batch_size) + 1):
172 | minibatch = X[i * self.batch_size
173 | : min(X.shape[0], (i+1) * self.batch_size)]
174 | list_inputs = [minibatch, 0.]
175 | # Learning phase. 0 = Test mode (no dropout or batch normalization)
176 | layer_output = func(list_inputs)[0]
177 | activations.append(layer_output)
178 | output = np.vstack(tuple(activations))
179 | output = np.reshape(output, (output.shape[0],np.product(output.shape[1:])))
180 | return output
181 |
182 | def get_params(self, deep = False):
183 | params = {}
184 | params['solver'] = self.solver
185 | params['epochs'] = self.epochs
186 | params['batch_size'] = self.batch_size
187 | params['learning_rate'] = self.learning_rate
188 | params['lr_decay'] = self.lr_decay
189 | if deep:
190 | return copy.deepcopy(params)
191 | return copy.copy(params)
192 |
193 | def set_params(self, **parameters):
194 | for parameter, value in parameters.items():
195 | setattr(self, parameter, value)
196 | return self
197 |
--------------------------------------------------------------------------------
/utils/chart_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Experiment charting script.
16 | """
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | import pickle
24 |
25 | import numpy as np
26 | import matplotlib.pyplot as plt
27 | from matplotlib.backends.backend_pdf import PdfPages
28 |
29 | from absl import app
30 | from absl import flags
31 | from tensorflow import gfile
32 |
33 | flags.DEFINE_string('source_dir',
34 | '/tmp/toy_experiments',
35 | 'Directory with the output to analyze.')
36 | flags.DEFINE_string('save_dir', '/tmp/active_learning',
37 | 'Directory to save charts.')
38 | flags.DEFINE_string('dataset', 'letter', 'Dataset to analyze.')
39 | flags.DEFINE_string(
40 | 'sampling_methods',
41 | ('uniform,margin,informative_diverse,'
42 | 'pred_expert_advice_trip_agg,'
43 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34'),
44 | 'Comma separated string of sampling methods to include in chart.')
45 | flags.DEFINE_string('scoring_methods', 'logistic,kernel_ls',
46 | 'Comma separated string of scoring methods to chart.')
47 | flags.DEFINE_bool('normalize', False, 'Chart runs using normalized data.')
48 | flags.DEFINE_bool('standardize', True, 'Chart runs using standardized data.')
49 |
50 | FLAGS = flags.FLAGS
51 |
52 |
53 | def combine_results(files, diff=False):
54 | all_results = {}
55 | for f in files:
56 | data = pickle.load(gfile.FastGFile(f, 'r'))
57 | for k in data:
58 | if isinstance(k, tuple):
59 | data[k].pop('noisy_targets')
60 | data[k].pop('indices')
61 | data[k].pop('selected_inds')
62 | data[k].pop('sampler_output')
63 | key = list(k)
64 | seed = key[-1]
65 | key = key[0:10]
66 | key = tuple(key)
67 | if key in all_results:
68 | if seed not in all_results[key]['random_seeds']:
69 | all_results[key]['random_seeds'].append(seed)
70 | for field in [f for f in data[k] if f != 'n_points']:
71 | all_results[key][field] = np.vstack(
72 | (all_results[key][field], data[k][field]))
73 | else:
74 | all_results[key] = data[k]
75 | all_results[key]['random_seeds'] = [seed]
76 | else:
77 | all_results[k] = data[k]
78 | return all_results
79 |
80 |
81 | def plot_results(all_results, score_method, norm, stand, sampler_filter):
82 | colors = {
83 | 'margin':
84 | 'gold',
85 | 'uniform':
86 | 'k',
87 | 'informative_diverse':
88 | 'r',
89 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34':
90 | 'b',
91 | 'pred_expert_advice_trip_agg':
92 | 'g'
93 | }
94 | labels = {
95 | 'margin':
96 | 'margin',
97 | 'uniform':
98 | 'uniform',
99 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34':
100 | 'margin:0.33,informative_diverse:0.33, uniform:0.34',
101 | 'informative_diverse':
102 | 'informative and diverse',
103 | 'pred_expert_advice_trip_agg':
104 | 'expert: margin,informative_diverse,uniform'
105 | }
106 | markers = {
107 | 'margin':
108 | 'None',
109 | 'uniform':
110 | 'None',
111 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34':
112 | '>',
113 | 'informative_diverse':
114 | 'None',
115 | 'pred_expert_advice_trip_agg':
116 | 'p'
117 | }
118 | fields = all_results['tuple_keys']
119 | fields = dict(zip(fields, range(len(fields))))
120 |
121 | for k in sorted(all_results.keys()):
122 | sampler = k[fields['sampler']]
123 | if (isinstance(k, tuple) and
124 | k[fields['score_method']] == score_method and
125 | k[fields['standardize']] == stand and
126 | k[fields['normalize']] == norm and
127 | (sampler_filter is None or sampler in sampler_filter)):
128 | results = all_results[k]
129 | n_trials = results['accuracy'].shape[0]
130 | x = results['data_sizes'][0]
131 | mean_acc = np.mean(results['accuracy'], axis=0)
132 | CI_acc = np.std(results['accuracy'], axis=0) / np.sqrt(n_trials) * 2.96
133 | if sampler == 'uniform':
134 | plt.plot(
135 | x,
136 | mean_acc,
137 | linewidth=1,
138 | label=labels[sampler],
139 | color=colors[sampler],
140 | linestyle='--'
141 | )
142 | plt.fill_between(
143 | x,
144 | mean_acc - CI_acc,
145 | mean_acc + CI_acc,
146 | color=colors[sampler],
147 | alpha=0.2
148 | )
149 | else:
150 | plt.plot(
151 | x,
152 | mean_acc,
153 | linewidth=1,
154 | label=labels[sampler],
155 | color=colors[sampler],
156 | marker=markers[sampler],
157 | markeredgecolor=colors[sampler]
158 | )
159 | plt.fill_between(
160 | x,
161 | mean_acc - CI_acc,
162 | mean_acc + CI_acc,
163 | color=colors[sampler],
164 | alpha=0.2
165 | )
166 | plt.legend(loc=4)
167 |
168 |
169 | def get_between(filename, start, end):
170 | start_ind = filename.find(start) + len(start)
171 | end_ind = filename.rfind(end)
172 | return filename[start_ind:end_ind]
173 |
174 |
175 | def get_sampling_method(dataset, filename):
176 | return get_between(filename, dataset + '_', '/')
177 |
178 |
179 | def get_scoring_method(filename):
180 | return get_between(filename, 'results_score_', '_select_')
181 |
182 |
183 | def get_normalize(filename):
184 | return get_between(filename, '_norm_', '_stand_') == 'True'
185 |
186 |
187 | def get_standardize(filename):
188 | return get_between(
189 | filename, '_stand_', filename[filename.rfind('_'):]) == 'True'
190 |
191 |
192 | def main(argv):
193 | del argv # Unused.
194 | if not gfile.Exists(FLAGS.save_dir):
195 | gfile.MkDir(FLAGS.save_dir)
196 | charting_filepath = os.path.join(FLAGS.save_dir,
197 | FLAGS.dataset + '_charts.pdf')
198 | sampling_methods = FLAGS.sampling_methods.split(',')
199 | scoring_methods = FLAGS.scoring_methods.split(',')
200 | files = gfile.Glob(
201 | os.path.join(FLAGS.source_dir, FLAGS.dataset + '*/results*.pkl'))
202 | files = [
203 | f for f in files
204 | if (get_sampling_method(FLAGS.dataset, f) in sampling_methods and
205 | get_scoring_method(f) in scoring_methods and
206 | get_normalize(f) == FLAGS.normalize and
207 | get_standardize(f) == FLAGS.standardize)
208 | ]
209 |
210 | print('Reading in %d files...' % len(files))
211 | all_results = combine_results(files)
212 | pdf = PdfPages(charting_filepath)
213 |
214 | print('Plotting charts...')
215 | plt.style.use('ggplot')
216 | for m in scoring_methods:
217 | plot_results(
218 | all_results,
219 | m,
220 | FLAGS.normalize,
221 | FLAGS.standardize,
222 | sampler_filter=sampling_methods)
223 | plt.title('Dataset: %s, Score Method: %s' % (FLAGS.dataset, m))
224 | pdf.savefig()
225 | plt.close()
226 | pdf.close()
227 |
228 |
229 | if __name__ == '__main__':
230 | app.run(main)
231 |
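A brief sketch (editorial addition, not part of chart_data.py) of driving combine_results and plot_results directly rather than through main(). The source directory, dataset prefix, scoring method, and sampler names are assumptions that mirror the flag defaults above, and the results pickles are assumed to carry the 'tuple_keys' entry that plot_results reads.

    import os
    import matplotlib.pyplot as plt
    from tensorflow import gfile

    # Assumes this runs inside chart_data.py (or after importing its functions).
    files = gfile.Glob(os.path.join('/tmp/toy_experiments', 'letter*/results*.pkl'))
    all_results = combine_results(files)
    plot_results(all_results, 'logistic', norm=False, stand=True,
                 sampler_filter=['margin', 'uniform'])
    plt.show()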
--------------------------------------------------------------------------------
/utils/create_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Make datasets and save specified directory.
16 |
17 | Downloads datasets using scikit-learn and can also parse a csv file
18 | to save into pickle format.
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | from io import BytesIO
26 | import os
27 | import pickle
28 | import StringIO
29 | import tarfile
30 | import urllib2
31 |
32 | import keras.backend as K
33 | from keras.datasets import cifar10
34 | from keras.datasets import cifar100
35 | from keras.datasets import mnist
36 |
37 | import numpy as np
38 | import pandas as pd
39 | from sklearn.datasets import fetch_20newsgroups_vectorized
40 | from sklearn.datasets import fetch_mldata
41 | from sklearn.datasets import load_breast_cancer
42 | from sklearn.datasets import load_iris
43 | import sklearn.datasets.rcv1
44 | from sklearn.feature_extraction.text import CountVectorizer
45 | from sklearn.feature_extraction.text import TfidfTransformer
46 |
47 | from absl import app
48 | from absl import flags
49 | from tensorflow import gfile
50 |
51 | flags.DEFINE_string('save_dir', '/tmp/data',
52 | 'Where to save outputs')
53 | flags.DEFINE_string('datasets', '',
54 | 'Which datasets to download, comma separated.')
55 | FLAGS = flags.FLAGS
56 |
57 |
58 | class Dataset(object):
59 |
60 | def __init__(self, X, y):
61 | self.data = X
62 | self.target = y
63 |
64 |
65 | def get_csv_data(filename):
66 | """Parse csv and return Dataset object with data and targets.
67 |
68 |   Create pickle data from csv; assumes the first column contains the targets.
69 | Args:
70 | filename: complete path of the csv file
71 | Returns:
72 | Dataset object
73 | """
74 | f = gfile.GFile(filename, 'r')
75 | mat = []
76 | for l in f:
77 | row = l.strip()
78 | row = row.replace('"', '')
79 | row = row.split(',')
80 | row = [float(x) for x in row]
81 | mat.append(row)
82 | mat = np.array(mat)
83 | y = mat[:, 0]
84 | X = mat[:, 1:]
85 | data = Dataset(X, y)
86 | return data
87 |
88 |
89 | def get_wikipedia_talk_data():
90 | """Get wikipedia talk dataset.
91 |
92 | See here for more information about the dataset:
93 | https://figshare.com/articles/Wikipedia_Detox_Data/4054689
94 | Downloads annotated comments and annotations.
95 | """
96 |
97 | ANNOTATED_COMMENTS_URL = 'https://ndownloader.figshare.com/files/7554634'
98 | ANNOTATIONS_URL = 'https://ndownloader.figshare.com/files/7554637'
99 |
100 | def download_file(url):
101 | req = urllib2.Request(url)
102 | response = urllib2.urlopen(req)
103 | return response
104 |
105 | # Process comments
106 | comments = pd.read_table(
107 | download_file(ANNOTATED_COMMENTS_URL), index_col=0, sep='\t')
108 | # remove newline and tab tokens
109 | comments['comment'] = comments['comment'].apply(
110 | lambda x: x.replace('NEWLINE_TOKEN', ' '))
111 | comments['comment'] = comments['comment'].apply(
112 | lambda x: x.replace('TAB_TOKEN', ' '))
113 |
114 | # Process labels
115 | annotations = pd.read_table(download_file(ANNOTATIONS_URL), sep='\t')
116 |   # labels a comment as an attack if the majority of annotators did so
117 | labels = annotations.groupby('rev_id')['attack'].mean() > 0.5
118 |
119 | # Perform data preprocessing, should probably tune these hyperparameters
120 | vect = CountVectorizer(max_features=30000, ngram_range=(1, 2))
121 | tfidf = TfidfTransformer(norm='l2')
122 | X = tfidf.fit_transform(vect.fit_transform(comments['comment']))
123 | y = np.array(labels)
124 | data = Dataset(X, y)
125 | return data
126 |
127 |
128 | def get_keras_data(dataname):
129 | """Get datasets using keras API and return as a Dataset object."""
130 | if dataname == 'cifar10_keras':
131 | train, test = cifar10.load_data()
132 | elif dataname == 'cifar100_coarse_keras':
133 | train, test = cifar100.load_data('coarse')
134 | elif dataname == 'cifar100_keras':
135 | train, test = cifar100.load_data()
136 | elif dataname == 'mnist_keras':
137 | train, test = mnist.load_data()
138 | else:
139 | raise NotImplementedError('dataset not supported')
140 |
141 | X = np.concatenate((train[0], test[0]))
142 | y = np.concatenate((train[1], test[1]))
143 |
144 | if dataname == 'mnist_keras':
145 | # Add extra dimension for channel
146 | num_rows = X.shape[1]
147 | num_cols = X.shape[2]
148 | X = X.reshape(X.shape[0], 1, num_rows, num_cols)
149 | if K.image_data_format() == 'channels_last':
150 | X = X.transpose(0, 2, 3, 1)
151 |
152 | y = y.flatten()
153 | data = Dataset(X, y)
154 | return data
155 |
156 |
157 | # TODO(lishal): remove the regular cifar10 dataset and only use the dataset
158 | # downloaded from keras so that image dims are preserved for tf model tensors.
159 | # Requires adding handling in run_experiment.py for training methods that
160 | # require either 2d or tensor data.
161 | def get_cifar10():
162 | """Get CIFAR-10 dataset from source dir.
163 |
164 | Slightly redundant with keras function to get cifar10 but this returns
165 | in flat format instead of keras numpy image tensor.
166 | """
167 | url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
168 | def download_file(url):
169 | req = urllib2.Request(url)
170 | response = urllib2.urlopen(req)
171 | return response
172 | response = download_file(url)
173 | tmpfile = BytesIO()
174 | while True:
175 | # Download a piece of the file from the connection
176 | s = response.read(16384)
177 | # Once the entire file has been downloaded, tarfile returns b''
178 | # (the empty bytes) which is a falsey value
179 | if not s:
180 | break
181 | # Otherwise, write the piece of the file to the temporary file.
182 | tmpfile.write(s)
183 | response.close()
184 |
185 | tmpfile.seek(0)
186 | tar_dir = tarfile.open(mode='r:gz', fileobj=tmpfile)
187 | X = None
188 | y = None
189 | for member in tar_dir.getnames():
190 | if '_batch' in member:
191 | filestream = tar_dir.extractfile(member).read()
192 | batch = pickle.load(StringIO.StringIO(filestream))
193 | if X is None:
194 | X = np.array(batch['data'], dtype=np.uint8)
195 | y = np.array(batch['labels'])
196 | else:
197 | X = np.concatenate((X, np.array(batch['data'], dtype=np.uint8)))
198 | y = np.concatenate((y, np.array(batch['labels'])))
199 | data = Dataset(X, y)
200 | return data
201 |
202 |
203 | def get_mldata(dataset):
204 |   # Use scikit-learn to grab datasets and save them to save_dir.
205 | save_dir = FLAGS.save_dir
206 | filename = os.path.join(save_dir, dataset[1]+'.pkl')
207 |
208 | if not gfile.Exists(save_dir):
209 | gfile.MkDir(save_dir)
210 | if not gfile.Exists(filename):
211 | if dataset[0][-3:] == 'csv':
212 | data = get_csv_data(dataset[0])
213 | elif dataset[0] == 'breast_cancer':
214 | data = load_breast_cancer()
215 | elif dataset[0] == 'iris':
216 | data = load_iris()
217 | elif dataset[0] == 'newsgroup':
218 | # Removing header information to make sure that no newsgroup identifying
219 | # information is included in data
220 |       data = fetch_20newsgroups_vectorized(subset='all', remove=('headers',))
221 | tfidf = TfidfTransformer(norm='l2')
222 | X = tfidf.fit_transform(data.data)
223 | data.data = X
224 | elif dataset[0] == 'rcv1':
225 | sklearn.datasets.rcv1.URL = (
226 | 'http://www.ai.mit.edu/projects/jmlr/papers/'
227 | 'volume5/lewis04a/a13-vector-files/lyrl2004_vectors')
228 | sklearn.datasets.rcv1.URL_topics = (
229 | 'http://www.ai.mit.edu/projects/jmlr/papers/'
230 | 'volume5/lewis04a/a08-topic-qrels/rcv1-v2.topics.qrels.gz')
231 | data = sklearn.datasets.fetch_rcv1(
232 | data_home='/tmp')
233 | elif dataset[0] == 'wikipedia_attack':
234 | data = get_wikipedia_talk_data()
235 | elif dataset[0] == 'cifar10':
236 | data = get_cifar10()
237 | elif 'keras' in dataset[0]:
238 | data = get_keras_data(dataset[0])
239 | else:
240 | try:
241 | data = fetch_mldata(dataset[0])
242 | except:
243 | raise Exception('ERROR: failed to fetch data from mldata.org')
244 | X = data.data
245 | y = data.target
246 | if X.shape[0] != y.shape[0]:
247 | X = np.transpose(X)
248 | assert X.shape[0] == y.shape[0]
249 |
250 | data = {'data': X, 'target': y}
251 | pickle.dump(data, gfile.GFile(filename, 'w'))
252 |
253 |
254 | def main(argv):
255 | del argv # Unused.
256 | # First entry of tuple is mldata.org name, second is the name that we'll use
257 | # to reference the data.
258 | datasets = [('mnist (original)', 'mnist'), ('australian', 'australian'),
259 | ('heart', 'heart'), ('breast_cancer', 'breast_cancer'),
260 | ('iris', 'iris'), ('vehicle', 'vehicle'), ('wine', 'wine'),
261 | ('waveform ida', 'waveform'), ('german ida', 'german'),
262 | ('splice ida', 'splice'), ('ringnorm ida', 'ringnorm'),
263 | ('twonorm ida', 'twonorm'), ('diabetes_scale', 'diabetes'),
264 | ('mushrooms', 'mushrooms'), ('letter', 'letter'), ('dna', 'dna'),
265 |               ('banana-ida', 'banana'),
266 | ('newsgroup', 'newsgroup'), ('cifar10', 'cifar10'),
267 | ('cifar10_keras', 'cifar10_keras'),
268 | ('cifar100_keras', 'cifar100_keras'),
269 | ('cifar100_coarse_keras', 'cifar100_coarse_keras'),
270 | ('mnist_keras', 'mnist_keras'),
271 | ('wikipedia_attack', 'wikipedia_attack'),
272 | ('rcv1', 'rcv1')]
273 |
274 | if FLAGS.datasets:
275 | subset = FLAGS.datasets.split(',')
276 | datasets = [d for d in datasets if d[1] in subset]
277 |
278 | for d in datasets:
279 | print(d[1])
280 | get_mldata(d)
281 |
282 |
283 | if __name__ == '__main__':
284 | app.run(main)
285 |
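A small read-back sketch (editorial addition, not part of create_data.py). get_mldata() writes each dataset as a plain dict with 'data' and 'target' keys; the path below assumes the default --save_dir and the 'iris' dataset.

    import pickle
    from tensorflow import gfile

    # Load a dataset previously written by get_mldata().
    data = pickle.load(gfile.GFile('/tmp/data/iris.pkl', 'r'))
    X, y = data['data'], data['target']
    print(X.shape, y.shape)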
--------------------------------------------------------------------------------
/utils/kernel_block_solver.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Block kernel lsqr solver for multi-class classification."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import copy
21 | import math
22 |
23 | import numpy as np
24 | import scipy.linalg as linalg
25 | from scipy.sparse.linalg import spsolve
26 | from sklearn import metrics
27 |
28 |
29 | class BlockKernelSolver(object):
30 | """Inspired by algorithm from https://arxiv.org/pdf/1602.05310.pdf."""
31 |   # TODO(lishal): save the precomputed kernel matrix and reuse it if possible;
32 |   # perhaps not possible if we want to keep the scikit-learn signature.
33 |
34 | def __init__(self,
35 | random_state=1,
36 | C=0.1,
37 | block_size=4000,
38 | epochs=3,
39 | verbose=False,
40 | gamma=None):
41 | self.block_size = block_size
42 | self.epochs = epochs
43 | self.C = C
44 | self.kernel = 'rbf'
45 | self.coef_ = None
46 | self.verbose = verbose
47 | self.encode_map = None
48 | self.decode_map = None
49 | self.gamma = gamma
50 | self.X_train = None
51 | self.random_state = random_state
52 |
53 | def encode_y(self, y):
54 | # Handles classes that do not start counting from 0.
55 | if self.encode_map is None:
56 | self.classes_ = sorted(list(set(y)))
57 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_))))
58 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_))
59 | mapper = lambda x: self.encode_map[x]
60 | transformed_y = np.array(map(mapper, y))
61 | return transformed_y
62 |
63 | def decode_y(self, y):
64 | mapper = lambda x: self.decode_map[x]
65 | transformed_y = np.array(map(mapper, y))
66 | return transformed_y
67 |
68 | def fit(self, X_train, y_train, sample_weight=None):
69 | """Form K and solve (K + lambda * I)x = y in a block-wise fashion."""
70 | np.random.seed(self.random_state)
71 | self.X_train = X_train
72 | n_features = X_train.shape[1]
73 | y = self.encode_y(y_train)
74 | if self.gamma is None:
75 | self.gamma = 1./n_features
76 | K = metrics.pairwise.pairwise_kernels(
77 | X_train, metric=self.kernel, gamma=self.gamma)
78 | if self.verbose:
79 | print('Finished forming kernel matrix.')
80 |
81 | # compute some constants
82 | num_classes = len(list(set(y)))
83 | num_samples = K.shape[0]
84 | num_blocks = math.ceil(num_samples*1.0/self.block_size)
85 | x = np.zeros((K.shape[0], num_classes))
86 | y_hat = np.zeros((K.shape[0], num_classes))
87 | onehot = lambda x: np.eye(num_classes)[x]
88 | y_onehot = np.array(map(onehot, y))
89 | idxes = np.diag_indices(num_samples)
90 | if sample_weight is not None:
91 | weights = np.sqrt(sample_weight)
92 | weights = weights[:, np.newaxis]
93 | y_onehot = weights * y_onehot
94 | K *= np.outer(weights, weights)
95 | if num_blocks == 1:
96 | epochs = 1
97 | else:
98 | epochs = self.epochs
99 |
100 | for e in range(epochs):
101 | shuffled_coords = np.random.choice(
102 | num_samples, num_samples, replace=False)
103 | for b in range(int(num_blocks)):
104 | residuals = y_onehot - y_hat
105 |
106 | # Form a block of K.
107 | K[idxes] += (self.C * num_samples)
108 | block = shuffled_coords[b*self.block_size:
109 | min((b+1)*self.block_size, num_samples)]
110 | K_block = K[:, block]
111 | # Dim should be block size x block size
112 | KbTKb = K_block.T.dot(K_block)
113 |
114 | if self.verbose:
115 | print('solving block {0}'.format(b))
116 | # Try linalg solve then sparse solve for handling of sparse input.
117 | try:
118 | x_block = linalg.solve(KbTKb, K_block.T.dot(residuals))
119 | except:
120 | try:
121 | x_block = spsolve(KbTKb, K_block.T.dot(residuals))
122 | except:
123 | return None
124 |
125 | # update model
126 | x[block] = x[block] + x_block
127 | K[idxes] = K[idxes] - (self.C * num_samples)
128 | y_hat = K.dot(x)
129 |
130 | y_pred = np.argmax(y_hat, axis=1)
131 | train_acc = metrics.accuracy_score(y, y_pred)
132 | if self.verbose:
133 | print('Epoch: {0}, Block: {1}, Train Accuracy: {2}'
134 | .format(e, b, train_acc))
135 | self.coef_ = x
136 |
137 | def predict(self, X_val):
138 | val_K = metrics.pairwise.pairwise_kernels(
139 | X_val, self.X_train, metric=self.kernel, gamma=self.gamma)
140 | val_pred = np.argmax(val_K.dot(self.coef_), axis=1)
141 | return self.decode_y(val_pred)
142 |
143 | def score(self, X_val, val_y):
144 | val_pred = self.predict(X_val)
145 | val_acc = metrics.accuracy_score(val_y, val_pred)
146 | return val_acc
147 |
148 | def decision_function(self, X, type='predicted'):
149 | # Return the predicted value of the best class
150 | # Margin_AL will see that a vector is returned and not a matrix and
151 | # simply select the points that have the lowest predicted value to label
152 | K = metrics.pairwise.pairwise_kernels(
153 | X, self.X_train, metric=self.kernel, gamma=self.gamma)
154 | predicted = K.dot(self.coef_)
155 | if type == 'scores':
156 | val_best = np.max(K.dot(self.coef_), axis=1)
157 | return val_best
158 | elif type == 'predicted':
159 | return predicted
160 | else:
161 | raise NotImplementedError('Invalid return type for decision function.')
162 |
163 | def get_params(self, deep=False):
164 | params = {}
165 | params['C'] = self.C
166 | params['gamma'] = self.gamma
167 | if deep:
168 | return copy.deepcopy(params)
169 | return copy.copy(params)
170 |
171 | def set_params(self, **parameters):
172 | for parameter, value in parameters.items():
173 | setattr(self, parameter, value)
174 | return self
175 |
176 | def softmax_over_predicted(self, X):
177 | val_K = metrics.pairwise.pairwise_kernels(
178 | X, self.X_train, metric=self.kernel, gamma=self.gamma)
179 | val_pred = val_K.dot(self.coef_)
180 | row_min = np.min(val_pred, axis=1)
181 | val_pred = val_pred - row_min[:, None]
182 | val_pred = np.exp(val_pred)
183 | sum_exp = np.sum(val_pred, axis=1)
184 | val_pred = val_pred/sum_exp[:, None]
185 | return val_pred
186 |
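A minimal sketch (editorial addition, not part of kernel_block_solver.py) of the solver's sklearn-style API on synthetic data; the shapes, class count, and hyperparameters are illustrative rather than tuned.

    import numpy as np
    from utils.kernel_block_solver import BlockKernelSolver

    # Made-up stand-in data: 200 samples, 10 features, 3 classes.
    X = np.random.rand(200, 10)
    y = np.random.randint(0, 3, size=200)

    solver = BlockKernelSolver(random_state=1, C=0.01, block_size=100)
    solver.fit(X, y)                            # forms the RBF kernel and solves block-wise
    print('train accuracy:', solver.score(X, y))
    probs = solver.softmax_over_predicted(X)    # softmax-normalized scores, rows sum to 1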
--------------------------------------------------------------------------------
/utils/small_cnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implements Small CNN model in keras using tensorflow backend."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import copy
21 |
22 | import keras
23 | import keras.backend as K
24 | from keras.layers import Activation
25 | from keras.layers import Conv2D
26 | from keras.layers import Dense
27 | from keras.layers import Dropout
28 | from keras.layers import Flatten
29 | from keras.layers import MaxPooling2D
30 | from keras.models import Sequential
31 |
32 | import numpy as np
33 | import tensorflow as tf
34 |
35 |
36 | class SmallCNN(object):
37 | """Small convnet that matches sklearn api.
38 |
39 | Implements model from
40 | https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py
41 |   Adapts for inputs of variable size, expects data to be a 4d tensor, with
42 |   # of observations as the first dimension and the other dimensions
43 |   corresponding to length, width, and # of channels in the image.
44 | """
45 |
46 | def __init__(self,
47 | random_state=1,
48 | epochs=50,
49 | batch_size=32,
50 | solver='rmsprop',
51 | learning_rate=0.001,
52 | lr_decay=0.):
53 | # params
54 | self.solver = solver
55 | self.epochs = epochs
56 | self.batch_size = batch_size
57 | self.learning_rate = learning_rate
58 | self.lr_decay = lr_decay
59 | # data
60 | self.encode_map = None
61 | self.decode_map = None
62 | self.model = None
63 | self.random_state = random_state
64 | self.n_classes = None
65 |
66 | def build_model(self, X):
67 | # assumes that data axis order is same as the backend
68 | input_shape = X.shape[1:]
69 | np.random.seed(self.random_state)
70 | tf.set_random_seed(self.random_state)
71 |
72 | model = Sequential()
73 | model.add(Conv2D(32, (3, 3), padding='same',
74 | input_shape=input_shape, name='conv1'))
75 | model.add(Activation('relu'))
76 | model.add(Conv2D(32, (3, 3), name='conv2'))
77 | model.add(Activation('relu'))
78 | model.add(MaxPooling2D(pool_size=(2, 2)))
79 | model.add(Dropout(0.25))
80 |
81 | model.add(Conv2D(64, (3, 3), padding='same', name='conv3'))
82 | model.add(Activation('relu'))
83 | model.add(Conv2D(64, (3, 3), name='conv4'))
84 | model.add(Activation('relu'))
85 | model.add(MaxPooling2D(pool_size=(2, 2)))
86 | model.add(Dropout(0.25))
87 |
88 | model.add(Flatten())
89 | model.add(Dense(512, name='dense1'))
90 | model.add(Activation('relu'))
91 | model.add(Dropout(0.5))
92 | model.add(Dense(self.n_classes, name='dense2'))
93 | model.add(Activation('softmax'))
94 |
95 | try:
96 | optimizer = getattr(keras.optimizers, self.solver)
97 | except:
98 | raise NotImplementedError('optimizer not implemented in keras')
99 | # All optimizers with the exception of nadam take decay as named arg
100 | try:
101 | opt = optimizer(lr=self.learning_rate, decay=self.lr_decay)
102 | except:
103 | opt = optimizer(lr=self.learning_rate, schedule_decay=self.lr_decay)
104 |
105 | model.compile(loss='categorical_crossentropy',
106 | optimizer=opt,
107 | metrics=['accuracy'])
108 | # Save initial weights so that model can be retrained with same
109 | # initialization
110 | self.initial_weights = copy.deepcopy(model.get_weights())
111 |
112 | self.model = model
113 |
114 | def create_y_mat(self, y):
115 | y_encode = self.encode_y(y)
116 | y_encode = np.reshape(y_encode, (len(y_encode), 1))
117 | y_mat = keras.utils.to_categorical(y_encode, self.n_classes)
118 | return y_mat
119 |
120 | # Add handling for classes that do not start counting from 0
121 | def encode_y(self, y):
122 | if self.encode_map is None:
123 | self.classes_ = sorted(list(set(y)))
124 | self.n_classes = len(self.classes_)
125 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_))))
126 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_))
127 | mapper = lambda x: self.encode_map[x]
128 | transformed_y = np.array(map(mapper, y))
129 | return transformed_y
130 |
131 | def decode_y(self, y):
132 | mapper = lambda x: self.decode_map[x]
133 | transformed_y = np.array(map(mapper, y))
134 | return transformed_y
135 |
136 | def fit(self, X_train, y_train, sample_weight=None):
137 | y_mat = self.create_y_mat(y_train)
138 |
139 | if self.model is None:
140 | self.build_model(X_train)
141 |
142 | # We don't want incremental fit so reset learning rate and weights
143 | K.set_value(self.model.optimizer.lr, self.learning_rate)
144 | self.model.set_weights(self.initial_weights)
145 | self.model.fit(
146 | X_train,
147 | y_mat,
148 | batch_size=self.batch_size,
149 | epochs=self.epochs,
150 | shuffle=True,
151 | sample_weight=sample_weight,
152 | verbose=0)
153 |
154 | def predict(self, X_val):
155 | predicted = self.model.predict(X_val)
156 | return predicted
157 |
158 | def score(self, X_val, val_y):
159 | y_mat = self.create_y_mat(val_y)
160 | val_acc = self.model.evaluate(X_val, y_mat)[1]
161 | return val_acc
162 |
163 | def decision_function(self, X):
164 | return self.predict(X)
165 |
166 | def transform(self, X):
167 | model = self.model
168 | inp = [model.input]
169 | activations = []
170 |
171 | # Get activations of the first dense layer.
172 | output = [layer.output for layer in model.layers if
173 | layer.name == 'dense1'][0]
174 | func = K.function(inp + [K.learning_phase()], [output])
175 | for i in range(int(X.shape[0]/self.batch_size) + 1):
176 | minibatch = X[i * self.batch_size
177 | : min(X.shape[0], (i+1) * self.batch_size)]
178 | list_inputs = [minibatch, 0.]
179 | # Learning phase. 0 = Test mode (no dropout or batch normalization)
180 | layer_output = func(list_inputs)[0]
181 | activations.append(layer_output)
182 | output = np.vstack(tuple(activations))
183 | return output
184 |
185 | def get_params(self, deep = False):
186 | params = {}
187 | params['solver'] = self.solver
188 | params['epochs'] = self.epochs
189 | params['batch_size'] = self.batch_size
190 | params['learning_rate'] = self.learning_rate
191 | params['weight_decay'] = self.lr_decay
192 | if deep:
193 | return copy.deepcopy(params)
194 | return copy.copy(params)
195 |
196 | def set_params(self, **parameters):
197 | for parameter, value in parameters.items():
198 | setattr(self, parameter, value)
199 | return self
200 |
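A minimal usage sketch (editorial addition, not part of small_cnn.py). It assumes 4-d image data and integer labels as fit() expects; the toy shapes and small epoch count are illustrative only.

    import numpy as np
    from utils.small_cnn import SmallCNN

    # Made-up stand-in data: 60 single-channel 28x28 images with 4 classes.
    X = np.random.rand(60, 28, 28, 1).astype('float32')
    y = np.random.randint(0, 4, size=60)

    cnn = SmallCNN(random_state=1, epochs=2, batch_size=16)
    cnn.fit(X, y)
    print('train accuracy:', cnn.score(X, y))
    embeddings = cnn.transform(X)  # activations of the 'dense1' layer, shape (60, 512)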
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility functions for run_experiment.py."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import copy
21 | import os
22 | import pickle
23 | import sys
24 |
25 | import numpy as np
26 | import scipy
27 |
28 | from sklearn.linear_model import LogisticRegression
29 | from sklearn.model_selection import GridSearchCV
30 | from sklearn.svm import LinearSVC
31 | from sklearn.svm import SVC
32 |
33 | from tensorflow import gfile
34 |
35 |
36 | from utils.kernel_block_solver import BlockKernelSolver
37 | from utils.small_cnn import SmallCNN
38 | from utils.allconv import AllConv
39 |
40 |
41 | class Logger(object):
42 | """Logging object to write to file and stdout."""
43 |
44 | def __init__(self, filename):
45 | self.terminal = sys.stdout
46 | self.log = gfile.GFile(filename, "w")
47 |
48 | def write(self, message):
49 | self.terminal.write(message)
50 | self.log.write(message)
51 |
52 | def flush(self):
53 | self.terminal.flush()
54 |
55 | def flush_file(self):
56 | self.log.flush()
57 |
58 |
59 | def create_checker_unbalanced(split, n, grid_size):
60 | """Creates a dataset with two classes that occupy one color of checkboard.
61 |
62 | Args:
63 | split: splits to use for class imbalance.
64 | n: number of datapoints to sample.
65 | grid_size: checkerboard size.
66 | Returns:
67 | X: 2d features.
68 | y: binary class.
69 | """
70 | y = np.zeros(0)
71 | X = np.zeros((0, 2))
72 | for i in range(grid_size):
73 | for j in range(grid_size):
74 | label = 0
75 | n_0 = int(n/(grid_size*grid_size) * split[0] * 2)
76 | if (i-j) % 2 == 0:
77 | label = 1
78 | n_0 = int(n/(grid_size*grid_size) * split[1] * 2)
79 | x_1 = np.random.uniform(i, i+1, n_0)
80 | x_2 = np.random.uniform(j, j+1, n_0)
81 | x = np.vstack((x_1, x_2))
82 | x = x.T
83 | X = np.concatenate((X, x))
84 | y_0 = label * np.ones(n_0)
85 | y = np.concatenate((y, y_0))
86 | return X, y
87 |
88 |
89 | def flatten_X(X):
90 | shape = X.shape
91 | flat_X = X
92 | if len(shape) > 2:
93 | flat_X = np.reshape(X, (shape[0], np.product(shape[1:])))
94 | return flat_X
95 |
96 |
97 | def get_mldata(data_dir, name):
98 | """Loads data from data_dir.
99 |
100 | Looks for the file in data_dir.
101 | Assumes that data is in pickle format with dictionary fields data and target.
102 |
103 |
104 | Args:
105 | data_dir: directory to look in
106 |     name: dataset name, assumes data is saved in data_dir with the filename
107 |       name.pkl
108 | Returns:
109 | data and targets
110 | Raises:
111 | NameError: dataset not found in data folder.
112 | """
113 | dataname = name
114 | if dataname == "checkerboard":
115 | X, y = create_checker_unbalanced(split=[1./5, 4./5], n=10000, grid_size=4)
116 | else:
117 | filename = os.path.join(data_dir, dataname + ".pkl")
118 | if not gfile.Exists(filename):
119 | raise NameError("ERROR: dataset not available")
120 | data = pickle.load(gfile.GFile(filename, "r"))
121 | X = data["data"]
122 | y = data["target"]
123 | if "keras" in dataname:
124 | X = X / 255
125 | y = y.flatten()
126 | return X, y
127 |
128 |
129 | def filter_data(X, y, keep=None):
130 | """Filters data by class indicated in keep.
131 |
132 | Args:
133 | X: train data
134 | y: train targets
135 | keep: defaults to None which will keep everything, otherwise takes a list
136 | of classes to keep
137 |
138 | Returns:
139 | filtered data and targets
140 | """
141 | if keep is None:
142 | return X, y
143 | keep_ind = [i for i in range(len(y)) if y[i] in keep]
144 | return X[keep_ind], y[keep_ind]
145 |
146 |
147 | def get_class_counts(y_full, y):
148 | """Gets the count of all classes in a sample.
149 |
150 | Args:
151 | y_full: full target vector containing all classes
152 | y: sample vector for which to perform the count
153 | Returns:
154 | count of classes for the sample vector y, the class order for count will
155 | be the same as long as same y_full is fed in
156 | """
157 | classes = np.unique(y_full)
158 | classes = np.sort(classes)
159 | unique, counts = np.unique(y, return_counts=True)
160 | complete_counts = []
161 | for c in classes:
162 | if c not in unique:
163 | complete_counts.append(0)
164 | else:
165 | index = np.where(unique == c)[0][0]
166 | complete_counts.append(counts[index])
167 | return np.array(complete_counts)
168 |
169 |
170 | def flip_label(y, percent_random):
171 | """Flips a percentage of labels for one class to the other.
172 |
173 | Randomly sample a percent of points and randomly label the sampled points as
174 | one of the other classes.
175 | Does not introduce bias.
176 |
177 | Args:
178 | y: labels of all datapoints
179 | percent_random: percent of datapoints to corrupt the labels
180 |
181 | Returns:
182 | new labels with noisy labels for indicated percent of data
183 | """
184 | classes = np.unique(y)
185 | y_orig = copy.copy(y)
186 | indices = range(y_orig.shape[0])
187 | np.random.shuffle(indices)
188 | sample = indices[0:int(len(indices) * 1.0 * percent_random)]
189 | fake_labels = []
190 | for s in sample:
191 | label = y[s]
192 | class_ind = np.where(classes == label)[0][0]
193 | other_classes = np.delete(classes, class_ind)
194 | np.random.shuffle(other_classes)
195 | fake_label = other_classes[0]
196 | assert fake_label != label
197 | fake_labels.append(fake_label)
198 | y[sample] = np.array(fake_labels)
199 | assert all(y[indices[len(sample):]] == y_orig[indices[len(sample):]])
200 | return y
201 |
202 |
203 | def get_model(method, seed=13):
204 | """Construct sklearn model using either logistic regression or linear svm.
205 |
206 | Wraps grid search on regularization parameter over either logistic regression
207 | or svm, returns constructed model
208 |
209 | Args:
210 | method: string indicating scikit method to use, currently accepts logistic
211 | and linear svm.
212 | seed: int or rng to use for random state fed to scikit method
213 |
214 | Returns:
215 | scikit learn model
216 | """
217 | # TODO(lishal): extend to include any scikit model that implements
218 | # a decision function.
219 | # TODO(lishal): for kernel methods, currently using default value for gamma
220 | # but should probably tune.
221 | if method == "logistic":
222 | model = LogisticRegression(random_state=seed, multi_class="multinomial",
223 | solver="lbfgs", max_iter=200)
224 | params = {"C": [10.0**(i) for i in range(-4, 5)]}
225 | elif method == "logistic_ovr":
226 | model = LogisticRegression(random_state=seed)
227 | params = {"C": [10.0**(i) for i in range(-5, 4)]}
228 | elif method == "linear_svm":
229 | model = LinearSVC(random_state=seed)
230 | params = {"C": [10.0**(i) for i in range(-4, 5)]}
231 | elif method == "kernel_svm":
232 | model = SVC(random_state=seed)
233 | params = {"C": [10.0**(i) for i in range(-4, 5)]}
234 | elif method == "kernel_ls":
235 | model = BlockKernelSolver(random_state=seed)
236 | params = {"C": [10.0**(i) for i in range(-6, 1)]}
237 | elif method == "small_cnn":
238 | # Model does not work with weighted_expert or simulate_batch
239 | model = SmallCNN(random_state=seed)
240 | return model
241 | elif method == "allconv":
242 | # Model does not work with weighted_expert or simulate_batch
243 | model = AllConv(random_state=seed)
244 | return model
245 |
246 | else:
247 | raise NotImplementedError("ERROR: " + method + " not implemented")
248 |
249 | model = GridSearchCV(model, params, cv=3)
250 | return model
251 |
252 |
253 | def calculate_entropy(batch_size, y_s):
254 | """Calculates KL div between training targets and targets selected by AL.
255 |
256 | Args:
257 | batch_size: batch size of datapoints selected by AL
258 | y_s: vector of datapoints selected by AL. Assumes that the order of the
259 | data is the order in which points were labeled by AL. Also assumes
260 | that in the offline setting y_s will eventually overlap completely with
261 | original training targets.
262 | Returns:
263 | entropy between actual distribution of classes and distribution of
264 | samples selected by AL
265 | """
266 | n_batches = int(np.ceil(len(y_s) * 1.0 / batch_size))
267 | counts = get_class_counts(y_s, y_s)
268 | true_dist = counts / (len(y_s) * 1.0)
269 | entropy = []
270 | for b in range(n_batches):
271 | sample = y_s[b * batch_size:(b + 1) * batch_size]
272 | counts = get_class_counts(y_s, sample)
273 | sample_dist = counts / (1.0 * len(sample))
274 | entropy.append(scipy.stats.entropy(true_dist, sample_dist))
275 | return entropy
276 |
277 |
278 | def get_train_val_test_splits(X, y, max_points, seed, confusion, seed_batch,
279 | split=(2./3, 1./6, 1./6)):
280 | """Return training, validation, and test splits for X and y.
281 |
282 | Args:
283 | X: features
284 | y: targets
285 | max_points: # of points to use when creating splits.
286 | seed: seed for shuffling.
287 | confusion: labeling noise to introduce. 0.1 means randomize 10% of labels.
288 | seed_batch: # of initial datapoints to ensure sufficient class membership.
289 | split: percent splits for train, val, and test.
290 | Returns:
291 | indices: shuffled indices to recreate splits given original input data X.
292 | y_noise: y with noise injected, needed to reproduce results outside of
293 | run_experiments using original data.
294 | """
295 | np.random.seed(seed)
296 | X_copy = copy.copy(X)
297 | y_copy = copy.copy(y)
298 |
299 | # Introduce labeling noise
300 | y_noise = flip_label(y_copy, confusion)
301 |
302 | indices = np.arange(len(y))
303 |
304 | if max_points is None:
305 | max_points = len(y_noise)
306 | else:
307 | max_points = min(len(y_noise), max_points)
308 | train_split = int(max_points * split[0])
309 | val_split = train_split + int(max_points * split[1])
310 | assert seed_batch <= train_split
311 |
312 | # Do this to make sure that the initial batch has examples from all classes
313 | min_shuffle = 3
314 | n_shuffle = 0
315 | y_tmp = y_noise
316 |
317 |   # Need at least 4 obs of each class for CV to work in the grid search step
318 | while (any(get_class_counts(y_tmp, y_tmp[0:seed_batch]) < 4)
319 | or n_shuffle < min_shuffle):
320 | np.random.shuffle(indices)
321 | y_tmp = y_noise[indices]
322 | n_shuffle += 1
323 |
324 | X_train = X_copy[indices[0:train_split]]
325 | X_val = X_copy[indices[train_split:val_split]]
326 | X_test = X_copy[indices[val_split:max_points]]
327 | y_train = y_noise[indices[0:train_split]]
328 | y_val = y_noise[indices[train_split:val_split]]
329 | y_test = y_noise[indices[val_split:max_points]]
330 |   # Make sure that we have enough observations of each class for cross-validation
331 | assert all(get_class_counts(y_noise, y_train[0:seed_batch]) >= 4)
332 | # Make sure that returned shuffled indices are correct
333 | assert all(y_noise[indices[0:max_points]] ==
334 | np.concatenate((y_train, y_val, y_test), axis=0))
335 | return (indices[0:max_points], X_train, y_train,
336 | X_val, y_val, X_test, y_test, y_noise)
337 |
--------------------------------------------------------------------------------