├── README.md
├── data
│   ├── citeseer
│   │   ├── bestconfig.txt
│   │   ├── citeseer.edgelist
│   │   ├── citeseer.x.pkl
│   │   ├── citeseer.y.pkl
│   │   ├── processed
│   │   │   └── data.pt
│   │   └── raw
│   │       ├── ind.citeseer.allx
│   │       ├── ind.citeseer.ally
│   │       ├── ind.citeseer.graph
│   │       ├── ind.citeseer.test.index
│   │       ├── ind.citeseer.tx
│   │       ├── ind.citeseer.ty
│   │       ├── ind.citeseer.x
│   │       └── ind.citeseer.y
│   ├── cora
│   │   ├── bestconfig.txt
│   │   ├── cora.edgelist
│   │   ├── cora.x.pkl
│   │   ├── cora.y.pkl
│   │   ├── processed
│   │   │   └── data.pt
│   │   └── raw
│   │       ├── ind.cora.allx
│   │       ├── ind.cora.ally
│   │       ├── ind.cora.graph
│   │       ├── ind.cora.test.index
│   │       ├── ind.cora.tx
│   │       ├── ind.cora.ty
│   │       ├── ind.cora.x
│   │       └── ind.cora.y
│   ├── pubmed
│   │   ├── bestconfig.txt
│   │   ├── processed
│   │   │   └── data.pt
│   │   ├── pubmed.edgelist
│   │   ├── pubmed.x.pkl
│   │   ├── pubmed.y.pkl
│   │   └── raw
│   │       ├── ind.pubmed.allx
│   │       ├── ind.pubmed.ally
│   │       ├── ind.pubmed.graph
│   │       ├── ind.pubmed.test.index
│   │       ├── ind.pubmed.tx
│   │       ├── ind.pubmed.ty
│   │       ├── ind.pubmed.x
│   │       └── ind.pubmed.y
│   └── reddit1401
│       ├── bestconfig.txt
│       ├── graph1
│       │   ├── graph1.edgelist
│       │   ├── graph1.x.pkl
│       │   └── graph1.y.pkl
│       ├── graph2
│       │   ├── graph2.edgelist
│       │   ├── graph2.x.pkl
│       │   └── graph2.y.pkl
│       ├── graph3
│       │   ├── graph3.edgelist
│       │   ├── graph3.x.pkl
│       │   └── graph3.y.pkl
│       ├── graph4
│       │   ├── graph4.edgelist
│       │   ├── graph4.x.pkl
│       │   └── graph4.y.pkl
│       ├── graph5
│       │   ├── graph5.edgelist
│       │   ├── graph5.x.pkl
│       │   └── graph5.y.pkl
│       └── graphforvis
│           ├── graphforvis.edgelist
│           ├── graphforvis.x.pkl
│           └── graphforvis.y.pkl
├── mainfig_website.png
├── requirements.txt
├── saved_models
│   ├── pretrain_citeseer+pubmed.pkl
│   ├── pretrain_cora+citeseer.pkl
│   ├── pretrain_cora+pubmed.pkl
│   └── pretrain_reddit1+2.pkl
└── src
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── train.cpython-36.pyc
    ├── baselines
    │   ├── __init__.py
    │   ├── active-learning
    │   │   ├── CONTRIBUTING.md
    │   │   ├── LICENSE
    │   │   ├── README.md
    │   │   ├── __init__.py
    │   │   ├── requirements.txt
    │   │   ├── run_experiment.py
    │   │   ├── sampling_methods
    │   │   │   ├── __init__.py
    │   │   │   ├── bandit_discrete.py
    │   │   │   ├── constants.py
    │   │   │   ├── graph_density.py
    │   │   │   ├── hierarchical_clustering_AL.py
    │   │   │   ├── informative_diverse.py
    │   │   │   ├── kcenter_greedy.py
    │   │   │   ├── margin_AL.py
    │   │   │   ├── mixture_of_samplers.py
    │   │   │   ├── represent_cluster_centers.py
    │   │   │   ├── sampling_def.py
    │   │   │   ├── simulate_batch.py
    │   │   │   ├── uniform_sampling.py
    │   │   │   ├── utils
    │   │   │   │   ├── __init__.py
    │   │   │   │   ├── tree.py
    │   │   │   │   └── tree_test.py
    │   │   │   └── wrapper_sampler_def.py
    │   │   └── utils
    │   │       ├── __init__.py
    │   │       ├── allconv.py
    │   │       ├── chart_data.py
    │   │       ├── create_data.py
    │   │       ├── kernel_block_solver.py
    │   │       ├── small_cnn.py
    │   │       └── utils.py
    │   ├── age.py
    │   ├── anrmab.py
    │   ├── coreset.py
    │   ├── coreset
    │   │   ├── compute_distance_mat.py
    │   │   ├── configure.sh
    │   │   ├── full_solver_gurobi.py
    │   │   └── gurobi_solution_parser.py
    │   └── sampling_methods
    │       ├── __init__.py
    │       ├── bandit_discrete.py
    │       ├── constants.py
    │       ├── graph_density.py
    │       ├── hierarchical_clustering_AL.py
    │       ├── informative_diverse.py
    │       ├── kcenter_greedy.py
    │       ├── margin_AL.py
    │       ├── mixture_of_samplers.py
    │       ├── represent_cluster_centers.py
    │       ├── sampling_def.py
    │       ├── simulate_batch.py
    │       ├── uniform_sampling.py
    │       ├── utils
    │       │   ├── __init__.py
    │       │   ├── tree.py
    │       │   └── tree_test.py
    │       └── wrapper_sampler_def.py
    ├── datasetcollecting
    │   ├── biggraph.py
    │   ├── collecting.py
    │   ├── loadreddit.py
    │   └── mergeedgelist.py
    ├── test.py
    ├── train.py
    └── utils
        ├── __init__.py
        ├── __pycache__
        │   ├── __init__.cpython-36.pyc
        │   ├── classificationnet.cpython-36.pyc
        │   ├── common.cpython-36.pyc
        │   ├── const.cpython-36.pyc
        │   ├── dataloader.cpython-36.pyc
        │   ├── env.cpython-36.pyc
        │   ├── player.cpython-36.pyc
        │   ├── policynet.cpython-36.pyc
        │   ├── rewardshaper.cpython-36.pyc
        │   └── utils.cpython-36.pyc
        ├── classificationnet.py
        ├── common.py
        ├── const.py
        ├── dataloader.py
        ├── env.py
        ├── player.py
        ├── policynet.py
        ├── query.py
        ├── rewardshaper.py
        └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Graph Policy Network for Transferable Active Learning on Graphs
2 | This is the code for the paper **G**raph **P**olicy network for transferable **A**ctive learning on graphs (GPA).
3 |
4 | ## Dependencies
5 | matplotlib==2.2.3
6 | networkx==2.4
7 | scikit-learn==0.21.2
8 | numpy==1.16.3
9 | scipy==1.2.1
10 | torch==1.3.1
11 |
12 | ## Data
13 | We provide Cora, Pubmed, Citeseer, and Reddit1401 in a processed format that can be directly consumed by our code. Reddit1401 is collected from the Reddit data source (1401 stands for January 2014) and was preprocessed by ourselves. If you use these graphs in your work, please cite our paper. For the Coauthor_CS and Coauthor_Phy datasets, we do not provide the processed data because the files are too large for a GitHub repository. If you are interested, please email shengdinghu@gmail.com for the processed data.
14 |
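For reference, below is a minimal sketch of how one of the provided graphs could be read directly from its ```.edgelist```/```.x.pkl```/```.y.pkl``` files. This is only an illustration: the exact pickle contents and the edge-list header are assumptions inferred from the files in this repository, and the actual loading logic lives in ```src/utils/dataloader.py```.
```python
import pickle

def load_graph(prefix):
    # Hypothetical stand-alone loader; the repository's own loader is in
    # src/utils/dataloader.py and may differ in detail.
    with open(prefix + ".x.pkl", "rb") as f:
        x = pickle.load(f)  # node features (assumed array-like)
    with open(prefix + ".y.pkl", "rb") as f:
        y = pickle.load(f)  # node labels (assumed array-like)
    with open(prefix + ".edgelist") as f:
        lines = [line.strip() for line in f if line.strip()]
    # In graphforvis.edgelist the first line holds a single number (apparently
    # the node count); each remaining line gives one edge "u v".
    edges = []
    for line in lines[1:]:
        u, v = map(int, line.split())
        edges.append((u, v))
    return edges, x, y

edges, x, y = load_graph("data/reddit1401/graphforvis/graphforvis")
```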
15 | ## Train
16 |
17 | Use ```train.py``` to train the active learning policy on multiple labeled training graphs. Suppose we have two labeled training graphs ```A``` and ```B``` with query budgets of ```x``` and ```y``` respectively, and we want to save the trained model as ```temp.pkl```; then use the following command:
18 | ```
19 | python -m src.train --datasets A+B --budgets x+y --save 1 --savename temp
20 | ```
21 | Please refer to the source code to see how to set the other arguments.
22 |
23 | ## Test
24 | Use ```test.py``` to test the learned active learning policy on unlabeled test graphs. Suppose we have an unlabeled test graph ```G``` with a query budget of ```z```, and we want to test the policy stored in ```temp.pkl```; then use the following command:
25 | ```
26 | python -m src.test --method 3 --modelname temp --datasets G --budgets z
27 | ```
28 | Please refer to the source code to see how to set the other arguments.
29 |
30 | ## Pre-trained Models and Results
31 | We provide several pre-trained models with their test results on the unlabeled test graphs.
32 | For transferable active learning on graphs from the **same** domain, we train on Reddit {1, 2} and test on Reddit {3, 4, 5}. The pre-trained model is saved in ```saved_models/pretrain_reddit1+2.pkl```. The test results are
33 |
34 | | Metric | Reddit 3 | Reddit 4 | Reddit 5 |
35 | | :---: | :---:| :---:| :---: |
36 | | Micro-F1 | 92.51 | 91.49 | 90.71 |
37 | | Macro-F1 | 92.22 | 89.57 | 90.32 |
38 |
39 | For transferable active learning on graphs across **different** domains, we provide three pre-trained models trained on different training graphs as follows:
40 |
41 | 1. Train on Cora + Citeseer, and test on the remaining graphs. The pre-trained model is saved in ```saved_models/pretrain_cora+citeseer.pkl```. The test results are
42 |
43 | | Metric | Pubmed | Reddit 1 | Reddit 2 | Reddit 3 | Reddit 4 | Reddit 5 | Physics | CS |
44 | | :---: | :---:| :---:| :---: | :---: | :---:| :---:| :---: | :---: |
45 | | Micro-F1 | 77.44 | 88.16 | 95.25 | 92.09 | 91.37 | 90.71 | 87.91 | 87.64 |
46 | | Macro-F1 | 75.28 | 87.84 | 95.04 | 91.77 | 89.50 | 90.30 | 82.57 | 84.45 |
47 |
48 | 2. Train on Cora + Pubmed, and test on the remaining graphs. The pre-trained model is saved in ```saved_models/pretrain_cora+pubmed.pkl```. The test results are
49 |
50 | | Metric | Citeseer | Reddit 1 | Reddit 2 | Reddit 3 | Reddit 4 | Reddit 5 | Physics | CS |
51 | | :---: | :---:| :---:| :---: | :---: | :---:| :---:| :---: | :---: |
52 | | Micro-F1 | 65.76 | 88.14 | 95.14 | 92.08 | 91.05 | 90.38 | 87.14 | 88.15 |
53 | | Macro-F1 | 57.52 | 87.86 | 94.93 | 91.78 | 89.08 | 89.92 | 81.04 | 85.24 |
54 |
55 | 3. Train on Citeseer + Pubmed, and test on the remaining graphs. The pre-trained model is saved in ```saved_models/pretrain_citeseer+pubmed.pkl```. The test results are
56 |
57 | | Metric | Cora | Reddit 1 | Reddit 2 | Reddit 3 | Reddit 4 | Reddit 5 | Physics | CS |
58 | | :---: | :---:| :---:| :---: | :---: | :---:| :---:| :---: | :---: |
59 | | Micro-F1 | 73.40 | 87.57 | 95.08 | 92.07 | 90.99 | 90.53 | 87.06 | 87.00 |
60 | | Macro-F1 | 71.22 | 87.11 | 94.87 | 91.74 | 88.97 | 90.14 | 81.20 | 83.90 |
61 |
--------------------------------------------------------------------------------
/data/citeseer/bestconfig.txt:
--------------------------------------------------------------------------------
1 | feature_normalize 1
2 |
3 |
--------------------------------------------------------------------------------
/data/citeseer/citeseer.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/citeseer.x.pkl
--------------------------------------------------------------------------------
/data/citeseer/citeseer.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/citeseer.y.pkl
--------------------------------------------------------------------------------
/data/citeseer/processed/data.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/processed/data.pt
--------------------------------------------------------------------------------
/data/citeseer/raw/ind.citeseer.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.allx
--------------------------------------------------------------------------------
/data/citeseer/raw/ind.citeseer.ally:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.ally
--------------------------------------------------------------------------------
/data/citeseer/raw/ind.citeseer.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.graph
--------------------------------------------------------------------------------
/data/citeseer/raw/ind.citeseer.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.tx
--------------------------------------------------------------------------------
/data/citeseer/raw/ind.citeseer.ty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.ty
--------------------------------------------------------------------------------
/data/citeseer/raw/ind.citeseer.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.x
--------------------------------------------------------------------------------
/data/citeseer/raw/ind.citeseer.y:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.y
--------------------------------------------------------------------------------
/data/cora/bestconfig.txt:
--------------------------------------------------------------------------------
1 | feature_normalize 0
2 |
3 |
--------------------------------------------------------------------------------
/data/cora/cora.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/cora.x.pkl
--------------------------------------------------------------------------------
/data/cora/cora.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/cora.y.pkl
--------------------------------------------------------------------------------
/data/cora/processed/data.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/processed/data.pt
--------------------------------------------------------------------------------
/data/cora/raw/ind.cora.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.allx
--------------------------------------------------------------------------------
/data/cora/raw/ind.cora.ally:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.ally
--------------------------------------------------------------------------------
/data/cora/raw/ind.cora.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.graph
--------------------------------------------------------------------------------
/data/cora/raw/ind.cora.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.tx
--------------------------------------------------------------------------------
/data/cora/raw/ind.cora.ty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.ty
--------------------------------------------------------------------------------
/data/cora/raw/ind.cora.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.x
--------------------------------------------------------------------------------
/data/cora/raw/ind.cora.y:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.y
--------------------------------------------------------------------------------
/data/pubmed/bestconfig.txt:
--------------------------------------------------------------------------------
1 | feature_normalize 1
2 |
3 |
--------------------------------------------------------------------------------
/data/pubmed/processed/data.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/processed/data.pt
--------------------------------------------------------------------------------
/data/pubmed/pubmed.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/pubmed.x.pkl
--------------------------------------------------------------------------------
/data/pubmed/pubmed.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/pubmed.y.pkl
--------------------------------------------------------------------------------
/data/pubmed/raw/ind.pubmed.allx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.allx
--------------------------------------------------------------------------------
/data/pubmed/raw/ind.pubmed.ally:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.ally
--------------------------------------------------------------------------------
/data/pubmed/raw/ind.pubmed.graph:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.graph
--------------------------------------------------------------------------------
/data/pubmed/raw/ind.pubmed.tx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.tx
--------------------------------------------------------------------------------
/data/pubmed/raw/ind.pubmed.ty:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.ty
--------------------------------------------------------------------------------
/data/pubmed/raw/ind.pubmed.x:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.x
--------------------------------------------------------------------------------
/data/pubmed/raw/ind.pubmed.y:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.y
--------------------------------------------------------------------------------
/data/reddit1401/bestconfig.txt:
--------------------------------------------------------------------------------
1 | feature_normalize 0
2 |
--------------------------------------------------------------------------------
/data/reddit1401/graph1/graph1.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph1/graph1.x.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph1/graph1.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph1/graph1.y.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph2/graph2.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph2/graph2.x.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph2/graph2.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph2/graph2.y.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph3/graph3.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph3/graph3.x.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph3/graph3.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph3/graph3.y.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph4/graph4.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph4/graph4.x.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph4/graph4.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph4/graph4.y.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph5/graph5.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph5/graph5.x.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graph5/graph5.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph5/graph5.y.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graphforvis/graphforvis.edgelist:
--------------------------------------------------------------------------------
1 | 85
2 | 16 32
3 | 72 31
4 | 1 31
5 | 69 35
6 | 78 35
7 | 69 74
8 | 35 29
9 | 15 70
10 | 32 70
11 | 32 7
12 | 68 7
13 | 56 30
14 | 32 30
15 | 74 48
16 | 72 9
17 | 34 9
18 | 79 9
19 | 54 42
20 | 78 61
21 | 35 61
22 | 61 17
23 | 72 55
24 | 31 55
25 | 1 67
26 | 9 67
27 | 55 67
28 | 15 84
29 | 2 84
30 | 70 84
31 | 29 84
32 | 4 84
33 | 32 84
34 | 30 84
35 | 48 25
36 | 31 25
37 | 17 25
38 | 35 25
39 | 74 25
40 | 27 25
41 | 61 25
42 | 78 25
43 | 25 53
44 | 2 40
45 | 32 40
46 | 25 40
47 | 62 40
48 | 78 0
49 | 72 10
50 | 37 10
51 | 2 46
52 | 1 12
53 | 35 77
54 | 25 77
55 | 75 77
56 | 60 77
57 | 51 77
58 | 40 77
59 | 31 77
60 | 74 77
61 | 10 77
62 | 73 77
63 | 76 77
64 | 69 77
65 | 23 77
66 | 22 77
67 | 44 77
68 | 4 77
69 | 84 77
70 | 17 77
71 | 27 77
72 | 77 18
73 | 60 18
74 | 77 36
75 | 51 36
76 | 18 36
77 | 77 39
78 | 77 58
79 | 77 63
80 | 36 63
81 | 18 63
82 | 77 83
83 | 1 6
84 | 54 6
85 | 42 6
86 | 67 6
87 | 37 6
88 | 74 26
89 | 35 26
90 | 25 26
91 | 27 26
92 | 69 26
93 | 35 43
94 | 72 65
95 | 9 65
96 | 37 65
97 | 10 65
98 | 32 14
99 | 35 71
100 | 77 21
101 | 74 49
102 | 77 66
103 | 77 80
104 | 26 80
105 | 35 80
106 | 25 80
107 | 74 80
108 | 49 80
109 | 61 8
110 | 25 8
111 | 80 8
112 | 74 8
113 | 26 8
114 | 73 11
115 | 84 33
116 | 60 59
117 | 17 57
118 | 25 57
119 | 35 57
120 | 69 57
121 | 84 41
122 | 14 64
123 | 77 82
124 | 32 5
125 | 2 5
126 | 50 5
127 | 25 28
128 | 77 47
129 | 63 47
130 | 77 81
131 | 18 81
132 | 65 3
133 | 67 3
134 | 72 3
135 | 2 13
136 | 5 13
137 | 44 24
138 | 81 24
139 | 76 24
140 | 77 24
141 | 40 45
142 | 25 19
143 | 77 19
144 | 49 19
145 | 35 19
146 | 57 19
147 | 80 19
148 | 77 38
149 | 18 38
150 | 35 52
151 | 77 52
152 | 18 52
153 | 57 52
154 | 19 52
155 | 61 52
156 | 40 20
157 | 52 20
158 |
--------------------------------------------------------------------------------
/data/reddit1401/graphforvis/graphforvis.x.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graphforvis/graphforvis.x.pkl
--------------------------------------------------------------------------------
/data/reddit1401/graphforvis/graphforvis.y.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graphforvis/graphforvis.y.pkl
--------------------------------------------------------------------------------
/mainfig_website.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/mainfig_website.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==2.2.3
2 | networkx==2.4
3 | scikit-learn==0.21.2
4 | numpy==1.16.3
5 | scipy==1.2.1
6 | torch==1.3.1
7 |
8 |
--------------------------------------------------------------------------------
/saved_models/pretrain_citeseer+pubmed.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/saved_models/pretrain_citeseer+pubmed.pkl
--------------------------------------------------------------------------------
/saved_models/pretrain_cora+citeseer.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/saved_models/pretrain_cora+citeseer.pkl
--------------------------------------------------------------------------------
/saved_models/pretrain_cora+pubmed.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/saved_models/pretrain_cora+pubmed.pkl
--------------------------------------------------------------------------------
/saved_models/pretrain_reddit1+2.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/saved_models/pretrain_reddit1+2.pkl
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/src/__init__.py
--------------------------------------------------------------------------------
/src/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/src/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/src/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/src/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/src/baselines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/src/baselines/__init__.py
--------------------------------------------------------------------------------
/src/baselines/active-learning/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution,
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/README.md:
--------------------------------------------------------------------------------
1 | # Active Learning Playground
2 |
3 | ## Introduction
4 |
5 | This is a python module for experimenting with different active learning
6 | algorithms. There are a few key components to running active learning
7 | experiments:
8 |
9 | * Main experiment script is
10 | [`run_experiment.py`](run_experiment.py)
11 | with many flags for different run options.
12 |
13 | * Supported datasets can be downloaded to a specified directory by running
14 | [`utils/create_data.py`](utils/create_data.py).
15 |
16 | * Supported active learning methods are in
17 | [`sampling_methods`](sampling_methods/).
18 |
19 | Below I will go into each component in more detail.
20 |
21 | DISCLAIMER: This is not an official Google product.
22 |
23 | ## Setup
24 | The dependencies are in [`requirements.txt`](requirements.txt). Please make sure these packages are
25 | installed before running experiments. If GPU capable `tensorflow` is desired, please follow
26 | instructions [here](https://www.tensorflow.org/install/).
27 |
28 | It is highly suggested that you install all dependencies into a separate `virtualenv` for
29 | easy package management.
30 |
31 | ## Getting benchmark datasets
32 |
33 | By default the datasets are saved to `/tmp/data`. You can specify another directory via the
34 | `--save_dir` flag.
35 |
36 | Redownloading all the datasets will be very time consuming so please be patient.
37 | You can specify a subset of the data to download by passing in a comma separated
38 | string of datasets via the `--datasets` flag.
39 |
40 | ## Running experiments
41 |
42 | There are a few key flags for
43 | [`run_experiment.py`](run_experiment.py):
44 |
45 | * `dataset`: name of the dataset, must match the save name used in
46 | `create_data.py`. Must also exist in the data_dir.
47 |
48 | * `sampling_method`: active learning method to use. Must be specified in
49 | [`sampling_methods/constants.py`](sampling_methods/constants.py).
50 |
51 | * `warmstart_size`: initial batch of uniformly sampled examples to use as seed
52 | data. Float indicates percentage of total training data and integer
53 | indicates raw size.
54 |
55 | * `batch_size`: number of datapoints to request in each batch. Float indicates
56 | percentage of total training data and integer indicates raw size.
57 |
58 | * `score_method`: model to use to evaluate the performance of the sampling
59 | method. Must be in `get_model` method of
60 | [`utils/utils.py`](utils/utils.py).
61 |
62 | * `data_dir`: directory with saved datasets.
63 |
64 | * `save_dir`: directory to save results.
65 |
66 | This is just a subset of all the flags. There are also options for
67 | preprocessing, introducing labeling noise, dataset subsampling, and using a
68 | different model to select than to score/evaluate.
69 |
70 | ## Available active learning methods
71 |
72 | All named active learning methods are in
73 | [`sampling_methods/constants.py`](sampling_methods/constants.py).
74 |
75 | You can also specify a mixture of active learning methods by following the
76 | pattern of `[sampling_method]-[mixture_weight]` separated by dashes; i.e.
77 | `mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34`.
78 |
79 | Some supported sampling methods include:
80 |
81 | * Uniform: samples are selected via uniform sampling.
82 |
83 | * Margin: uncertainty based sampling method.
84 |
85 | * Informative and diverse: margin and cluster based sampling method.
86 |
87 | * k-center greedy: representative strategy that greedily forms a batch of
88 | points to minimize maximum distance from a labeled point.
89 |
90 | * Graph density: representative strategy that selects points in dense regions
91 | of pool.
92 |
93 | * Exp3 bandit: meta-active learning method that tries to learn the optimal
94 | sampling method using a popular multi-armed bandit algorithm.
95 |
96 | ### Adding new active learning methods
97 |
98 | Implement either a base sampler that inherits from
99 | [`SamplingMethod`](sampling_methods/sampling_def.py)
100 | or a meta-sampler that calls base samplers which inherits from
101 | [`WrapperSamplingMethod`](sampling_methods/wrapper_sampler_def.py).
102 |
103 | The only method that must be implemented by any sampler is `select_batch_`,
104 | which can have arbitrary named arguments. The only restriction is that the name
105 | for the same input must be consistent across all the samplers (i.e. the indices
106 | for already selected examples all have the same name across samplers). Adding a
107 | new named argument that hasn't been used in other sampling methods will require
108 | feeding that into the `select_batch` call in
109 | [`run_experiment.py`](run_experiment.py).
110 |
111 | After implementing your sampler, be sure to add it to
112 | [`constants.py`](sampling_methods/constants.py)
113 | so that it can be called from
114 | [`run_experiment.py`](run_experiment.py).
115 |
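As a purely illustrative sketch (not a file in this repository), a minimal base sampler could look like the following; the file and class names are hypothetical, and only `select_batch_` is strictly required:

```python
# hypothetical file: sampling_methods/random_batch_example.py
import numpy as np

from sampling_methods.sampling_def import SamplingMethod


class RandomBatchSampler(SamplingMethod):
  """Toy sampler that ignores the model and picks a uniformly random batch."""

  def __init__(self, X, y, seed):
    self.name = 'random_batch_example'
    self.X = X
    self.y = y
    np.random.seed(seed)

  def select_batch_(self, already_selected, N, **kwargs):
    # Argument names (already_selected, N) must match those used by the other
    # samplers so run_experiment.py can feed them in uniformly.
    candidates = [i for i in range(self.X.shape[0]) if i not in already_selected]
    return list(np.random.choice(candidates, N, replace=False))
```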
116 | ## Available models
117 |
118 | All available models are in the `get_model` method of
119 | [`utils/utils.py`](utils/utils.py).
120 |
121 | Supported methods:
122 |
123 | * Linear SVM: scikit method with grid search wrapper for regularization
124 | parameter.
125 |
126 | * Kernel SVM: scikit method with grid search wrapper for regularization
127 | parameter.
128 |
129 | * Logistic Regression: scikit method with grid search wrapper for
130 | regularization parameter.
131 |
132 | * Small CNN: 4 layer CNN optimized using rmsprop implemented in Keras with
133 | tensorflow backend.
134 |
135 | * Kernel Least Squares Classification: block gradient descent solver that can
136 | use multiple cores so is often faster than scikit Kernel SVM.
137 |
138 | ### Adding new models
139 |
140 | New models must follow the scikit-learn API and implement the following methods:
141 |
142 | * `fit(X, y[, sample_weight])`: fit the model to the input features and
143 | target.
144 |
145 | * `predict(X)`: predict the value of the input features.
146 |
147 | * `score(X, y)`: returns target metric given test features and test targets.
148 |
149 | * `decision_function(X)` (optional): return class probabilities, distance to
150 | decision boundaries, or other metric that can be used by margin sampler as a
151 | measure of uncertainty.
152 |
153 | See
154 | [`small_cnn.py`](utils/small_cnn.py)
155 | for an example.
156 |
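As an illustration only (this class does not exist in the repository), a minimal wrapper meeting that interface might look like:

```python
from sklearn.linear_model import LogisticRegression


class TinyLogisticModel(object):
  """Hypothetical scikit-style wrapper exposing the required methods."""

  def __init__(self, C=1.0):
    self.model = LogisticRegression(C=C)

  def fit(self, X, y, sample_weight=None):
    self.model.fit(X, y, sample_weight=sample_weight)
    return self

  def predict(self, X):
    return self.model.predict(X)

  def score(self, X, y):
    # Target metric given test features and targets; accuracy here.
    return self.model.score(X, y)

  def decision_function(self, X):
    # Distance to the decision boundary, usable by the margin sampler.
    return self.model.decision_function(X)
```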
157 | After implementing your new model, be sure to add it to the `get_model` method of
158 | [`utils/utils.py`](utils/utils.py).
159 |
160 | Currently models must be added on a one-off basis and not all scikit-learn
161 | classifiers are supported due to the need for user input on whether and how to
162 | tune the hyperparameters of the model. However, it is very easy to add a
163 | scikit-learn model with hyperparameter search wrapped around as a supported
164 | model.
165 |
166 | ## Collecting results and charting
167 |
168 | The
169 | [`utils/chart_data.py`](utils/chart_data.py)
170 | script handles processing of data and charting for a specified dataset and
171 | source directory.
172 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | numpy>=1.13
3 | scipy>=0.19
4 | pandas>=0.20
5 | scikit-learn>=0.19
6 | matplotlib>=2.0.2
7 |
8 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/bandit_discrete.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Bandit wrapper around base AL sampling methods.
16 |
17 | Assumes adversarial multi-armed bandit setting where arms correspond to
18 | mixtures of different AL methods.
19 |
20 | Uses EXP3 algorithm to decide which AL method to use to create the next batch.
21 | Similar to Hsu & Lin 2015, Active Learning by Learning.
22 | https://www.csie.ntu.edu.tw/~htlin/paper/doc/aaai15albl.pdf
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import numpy as np
30 |
31 | from sampling_methods.wrapper_sampler_def import AL_MAPPING, WrapperSamplingMethod
32 |
33 |
34 | class BanditDiscreteSampler(WrapperSamplingMethod):
35 | """Wraps EXP3 around mixtures of indicated methods.
36 |
37 | Uses EXP3 multi-armed bandit algorithm to select sampler methods.
38 | """
39 | def __init__(self,
40 | X,
41 | y,
42 | seed,
43 | reward_function = lambda AL_acc: AL_acc[-1],
44 | gamma=0.5,
45 | samplers=[{'methods':('margin','uniform'),'weights':(0,1)},
46 | {'methods':('margin','uniform'),'weights':(1,0)}]):
47 | """Initializes sampler with indicated gamma and arms.
48 |
49 | Args:
50 | X: training data
51 | y: labels, may need to be input into base samplers
52 | seed: seed to use for random sampling
53 | reward_function: reward based on previously observed accuracies. Assumes
54 | that the input is a sequence of observed accuracies. Will ultimately be
55 | a class method and may need access to other class properties.
56 | gamma: weight on uniform mixture. Arm probability updates are a weighted
57 | mixture of uniform and an exponentially weighted distribution.
58 | Lower gamma more aggressively updates based on observed rewards.
59 | samplers: list of dicts with two fields
60 | 'methods': list of named samplers
61 | 'weights': percentage of batch to allocate to each sampler
62 | """
63 |
64 | self.name = 'bandit_discrete'
65 | np.random.seed(seed)
66 | self.X = X
67 | self.y = y
68 | self.seed = seed
69 | self.initialize_samplers(samplers)
70 |
71 | self.gamma = gamma
72 | self.n_arms = len(samplers)
73 | self.reward_function = reward_function
74 |
75 | self.pull_history = []
76 | self.acc_history = []
77 | self.w = np.ones(self.n_arms)
78 | self.x = np.zeros(self.n_arms)
79 | self.p = self.w / (1.0 * self.n_arms)
80 | self.probs = []
81 |
82 | def update_vars(self, arm_pulled):
83 | reward = self.reward_function(self.acc_history)
84 | self.x = np.zeros(self.n_arms)
85 | self.x[arm_pulled] = reward / self.p[arm_pulled]
86 | self.w = self.w * np.exp(self.gamma * self.x / self.n_arms)
87 | self.p = ((1.0 - self.gamma) * self.w / sum(self.w)
88 | + self.gamma / self.n_arms)
89 | print(self.p)
90 | self.probs.append(self.p)
91 |
92 | def select_batch_(self, already_selected, N, eval_acc, **kwargs):
93 | """Returns batch of datapoints sampled using mixture of AL_methods.
94 |
95 | Assumes that data has already been shuffled.
96 |
97 | Args:
98 | already_selected: index of datapoints already selected
99 | N: batch size
100 | eval_acc: accuracy of model trained after incorporating datapoints from
101 | last recommended batch
102 |
103 | Returns:
104 | indices of points selected to label
105 | """
106 | # Update observed reward and arm probabilities
107 | self.acc_history.append(eval_acc)
108 | if len(self.pull_history) > 0:
109 | self.update_vars(self.pull_history[-1])
110 | # Sample an arm
111 | arm = np.random.choice(range(self.n_arms), p=self.p)
112 | self.pull_history.append(arm)
113 | kwargs['N'] = N
114 | kwargs['already_selected'] = already_selected
115 | sample = self.samplers[arm].select_batch(**kwargs)
116 | return sample
117 |
118 | def to_dict(self):
119 | output = {}
120 | output['samplers'] = self.base_samplers
121 | output['arm_probs'] = self.probs
122 | output['pull_history'] = self.pull_history
123 | output['rewards'] = self.acc_history
124 | return output
125 |
126 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Controls imports to fill up dictionary of different sampling methods.
16 | """
17 |
18 | from functools import partial
19 | AL_MAPPING = {}
20 |
21 |
22 | def get_base_AL_mapping():
23 | from sampling_methods.margin_AL import MarginAL
24 | from sampling_methods.informative_diverse import InformativeClusterDiverseSampler
25 | from sampling_methods.hierarchical_clustering_AL import HierarchicalClusterAL
26 | from sampling_methods.uniform_sampling import UniformSampling
27 | from sampling_methods.represent_cluster_centers import RepresentativeClusterMeanSampling
28 | from sampling_methods.graph_density import GraphDensitySampler
29 | from sampling_methods.kcenter_greedy import kCenterGreedy
30 | AL_MAPPING['margin'] = MarginAL
31 | AL_MAPPING['informative_diverse'] = InformativeClusterDiverseSampler
32 | AL_MAPPING['hierarchical'] = HierarchicalClusterAL
33 | AL_MAPPING['uniform'] = UniformSampling
34 | AL_MAPPING['margin_cluster_mean'] = RepresentativeClusterMeanSampling
35 | AL_MAPPING['graph_density'] = GraphDensitySampler
36 | AL_MAPPING['kcenter'] = kCenterGreedy
37 |
38 |
39 | def get_all_possible_arms():
40 | from sampling_methods.mixture_of_samplers import MixtureOfSamplers
41 | AL_MAPPING['mixture_of_samplers'] = MixtureOfSamplers
42 |
43 |
44 | def get_wrapper_AL_mapping():
45 | from sampling_methods.bandit_discrete import BanditDiscreteSampler
46 | from sampling_methods.simulate_batch import SimulateBatchSampler
47 | AL_MAPPING['bandit_mixture'] = partial(
48 | BanditDiscreteSampler,
49 | samplers=[{
50 | 'methods': ['margin', 'uniform'],
51 | 'weights': [0, 1]
52 | }, {
53 | 'methods': ['margin', 'uniform'],
54 | 'weights': [0.25, 0.75]
55 | }, {
56 | 'methods': ['margin', 'uniform'],
57 | 'weights': [0.5, 0.5]
58 | }, {
59 | 'methods': ['margin', 'uniform'],
60 | 'weights': [0.75, 0.25]
61 | }, {
62 | 'methods': ['margin', 'uniform'],
63 | 'weights': [1, 0]
64 | }])
65 | AL_MAPPING['bandit_discrete'] = partial(
66 | BanditDiscreteSampler,
67 | samplers=[{
68 | 'methods': ['margin', 'uniform'],
69 | 'weights': [0, 1]
70 | }, {
71 | 'methods': ['margin', 'uniform'],
72 | 'weights': [1, 0]
73 | }])
74 | AL_MAPPING['simulate_batch_mixture'] = partial(
75 | SimulateBatchSampler,
76 | samplers=({
77 | 'methods': ['margin', 'uniform'],
78 | 'weights': [1, 0]
79 | }, {
80 | 'methods': ['margin', 'uniform'],
81 | 'weights': [0.5, 0.5]
82 | }, {
83 | 'methods': ['margin', 'uniform'],
84 | 'weights': [0, 1]
85 | }),
86 | n_sims=5,
87 | train_per_sim=10,
88 | return_best_sim=False)
89 | AL_MAPPING['simulate_batch_best_sim'] = partial(
90 | SimulateBatchSampler,
91 | samplers=[{
92 | 'methods': ['margin', 'uniform'],
93 | 'weights': [1, 0]
94 | }],
95 | n_sims=10,
96 | train_per_sim=10,
97 | return_type='best_sim')
98 | AL_MAPPING['simulate_batch_frequency'] = partial(
99 | SimulateBatchSampler,
100 | samplers=[{
101 | 'methods': ['margin', 'uniform'],
102 | 'weights': [1, 0]
103 | }],
104 | n_sims=10,
105 | train_per_sim=10,
106 | return_type='frequency')
107 |
108 | def get_mixture_of_samplers(name):
109 | assert 'mixture_of_samplers' in name
110 | if 'mixture_of_samplers' not in AL_MAPPING:
111 | raise KeyError('Mixture of Samplers not yet loaded.')
112 | args = name.split('-')[1:]
113 | samplers = args[0::2]
114 | weights = args[1::2]
115 | weights = [float(w) for w in weights]
116 | assert sum(weights) == 1
117 | mixture = {'methods': samplers, 'weights': weights}
118 | print(mixture)
119 | return partial(AL_MAPPING['mixture_of_samplers'], mixture=mixture)
120 |
121 |
122 | def get_AL_sampler(name):
123 | if name in AL_MAPPING and name != 'mixture_of_samplers':
124 | return AL_MAPPING[name]
125 | if 'mixture_of_samplers' in name:
126 | return get_mixture_of_samplers(name)
127 | raise NotImplementedError('The specified sampler is not available.')
128 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/graph_density.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Diversity promoting sampling method that uses graph density to determine
16 | most representative points.
17 |
18 | This is an implementation of the method described in
19 | https://www.mpi-inf.mpg.de/fileadmin/inf/d2/Research_projects_files/EbertCVPR2012.pdf
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import copy
27 |
28 | from sklearn.neighbors import kneighbors_graph
29 | from sklearn.metrics import pairwise_distances
30 | import numpy as np
31 | from sampling_methods.sampling_def import SamplingMethod
32 |
33 |
34 | class GraphDensitySampler(SamplingMethod):
35 | """Diversity promoting sampling method that uses graph density to determine
36 | most representative points.
37 | """
38 |
39 | def __init__(self, X, y, seed):
40 | self.name = 'graph_density'
41 | self.X = X
42 | self.flat_X = self.flatten_X()
43 | # Set gamma for gaussian kernel to be equal to 1/n_features
44 | self.gamma = 1. / self.X.shape[1]
45 | self.compute_graph_density()
46 |
47 | def compute_graph_density(self, n_neighbor=10):
48 | # kneighbors graph is constructed using k=10
49 | connect = kneighbors_graph(self.flat_X, n_neighbor,p=1)
50 | # Make connectivity matrix symmetric, if a point is a k nearest neighbor of
51 | # another point, make it vice versa
52 | neighbors = connect.nonzero()
53 | inds = zip(neighbors[0],neighbors[1])
54 | # Graph edges are weighted by applying gaussian kernel to manhattan dist.
55 | # By default, gamma for rbf kernel is equal to 1/n_features but may
56 | # get better results if gamma is tuned.
57 | for entry in inds:
58 | i = entry[0]
59 | j = entry[1]
60 | distance = pairwise_distances(self.flat_X[[i]],self.flat_X[[j]],metric='manhattan')
61 | distance = distance[0,0]
62 | weight = np.exp(-distance * self.gamma)
63 | connect[i,j] = weight
64 | connect[j,i] = weight
65 | self.connect = connect
66 | # Define graph density for an observation to be sum of weights for all
67 | # edges to the node representing the datapoint. Normalize sum weights
68 | # by total number of neighbors.
69 | self.graph_density = np.zeros(self.X.shape[0])
70 | for i in np.arange(self.X.shape[0]):
71 | self.graph_density[i] = connect[i,:].sum() / (connect[i,:]>0).sum()
72 | self.starting_density = copy.deepcopy(self.graph_density)
73 |
74 | def select_batch_(self, N, already_selected, **kwargs):
75 | # If a neighbor has already been sampled, reduce the graph density
76 | # for its direct neighbors to promote diversity.
77 | batch = set()
78 | self.graph_density[already_selected] = min(self.graph_density) - 1
79 | while len(batch) < N:
80 | selected = np.argmax(self.graph_density)
81 | neighbors = (self.connect[selected,:] > 0).nonzero()[1]
82 | self.graph_density[neighbors] = self.graph_density[neighbors] - self.graph_density[selected]
83 | batch.add(selected)
84 | self.graph_density[already_selected] = min(self.graph_density) - 1
85 | self.graph_density[list(batch)] = min(self.graph_density) - 1
86 | return list(batch)
87 |
88 | def to_dict(self):
89 | output = {}
90 | output['connectivity'] = self.connect
91 | output['graph_density'] = self.starting_density
92 | return output
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/informative_diverse.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Informative and diverse batch sampler that samples points with small margin
16 | while maintaining same distribution over clusters as entire training data.
17 |
18 | Batch is created by sorting datapoints by increasing margin and then growing
19 | the batch greedily. A point is added to the batch if the result batch still
20 | respects the constraint that the cluster distribution of the batch will
21 | match the cluster distribution of the entire training set.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 | from sklearn.cluster import MiniBatchKMeans
29 | import numpy as np
30 | from sampling_methods.sampling_def import SamplingMethod
31 |
32 |
33 | class InformativeClusterDiverseSampler(SamplingMethod):
34 | """Selects batch based on informative and diverse criteria.
35 |
36 | Returns highest uncertainty lowest margin points while maintaining
37 | same distribution over clusters as entire dataset.
38 | """
39 |
40 | def __init__(self, X, y, seed):
41 | self.name = 'informative_and_diverse'
42 | self.X = X
43 | self.flat_X = self.flatten_X()
44 | # y only used for determining how many clusters there should be
45 | # probably not practical to assume we know # of classes before hand
46 | # should also probably scale with dimensionality of data
47 | self.y = y
48 | self.n_clusters = len(list(set(y)))
49 | self.cluster_model = MiniBatchKMeans(n_clusters=self.n_clusters)
50 | self.cluster_data()
51 |
52 | def cluster_data(self):
53 | # Probably okay to always use MiniBatchKMeans
54 | # Should standardize data before clustering
55 | # Can cluster on standardized data but train on raw features if desired
56 | self.cluster_model.fit(self.flat_X)
57 | unique, counts = np.unique(self.cluster_model.labels_, return_counts=True)
58 | self.cluster_prob = counts/sum(counts)
59 | self.cluster_labels = self.cluster_model.labels_
60 |
61 | def select_batch_(self, model, already_selected, N, **kwargs):
62 | """Returns a batch of size N using informative and diverse selection.
63 |
64 | Args:
65 | model: scikit learn model with decision_function implemented
66 | already_selected: index of datapoints already selected
67 | N: batch size
68 |
69 | Returns:
70 | indices of points selected to add using margin active learner
71 | """
72 | # TODO(lishal): have MarginSampler and this share margin function
73 | try:
74 | distances = model.decision_function(self.X)
75 | except:
76 | distances = model.predict_proba(self.X)
77 | if len(distances.shape) < 2:
78 | min_margin = abs(distances)
79 | else:
80 | sort_distances = np.sort(distances, 1)[:, -2:]
81 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
82 | rank_ind = np.argsort(min_margin)
83 | rank_ind = [i for i in rank_ind if i not in already_selected]
84 | new_batch_cluster_counts = [0 for _ in range(self.n_clusters)]
85 | new_batch = []
86 | for i in rank_ind:
87 | if len(new_batch) == N:
88 | break
89 | label = self.cluster_labels[i]
90 | if new_batch_cluster_counts[label] / N < self.cluster_prob[label]:
91 | new_batch.append(i)
92 | new_batch_cluster_counts[label] += 1
93 | n_slot_remaining = N - len(new_batch)
94 | batch_filler = list(set(rank_ind) - set(already_selected) - set(new_batch))
95 | new_batch.extend(batch_filler[0:n_slot_remaining])
96 | return new_batch
97 |
98 | def to_dict(self):
99 | output = {}
100 | output['cluster_membership'] = self.cluster_labels
101 | return output
102 |
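A minimal usage sketch (not part of the repo) for the sampler above, assuming the active-learning directory is on the Python path (the import convention the repo's own modules use) and a scikit-learn classifier on synthetic data; names and sizes are purely illustrative.

import numpy as np
from sklearn.linear_model import LogisticRegression

from sampling_methods.informative_diverse import InformativeClusterDiverseSampler

rng = np.random.RandomState(0)
X = rng.rand(200, 8)                       # 200 points, 8 features
y = rng.randint(0, 3, size=200)            # 3 classes -> 3 clusters

sampler = InformativeClusterDiverseSampler(X, y, seed=0)

labeled = list(range(20))                  # indices already labeled
model = LogisticRegression(max_iter=500).fit(X[labeled], y[labeled])

# Ask for 10 new indices: small-margin points, balanced across clusters.
batch = sampler.select_batch(model=model, already_selected=labeled, N=10)
print(sorted(batch))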
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/kcenter_greedy.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Returns points that minimizes the maximum distance of any point to a center.
16 |
17 | Implements the k-Center-Greedy method in
18 | Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for
19 | Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017
20 |
21 | Distance metric defaults to l2 distance. Features used to calculate distance
22 | are either the raw features or, if the model has a transform method, the output
23 | of model.transform(X).
24 |
25 | Can be extended to a robust k-centers algorithm that ignores a certain number of
26 | outlier datapoints. Resulting centers are the solution to a mixed integer program.
27 | """
28 |
29 | from __future__ import absolute_import
30 | from __future__ import division
31 | from __future__ import print_function
32 |
33 | import numpy as np
34 | from sklearn.metrics import pairwise_distances
35 | from sampling_methods.sampling_def import SamplingMethod
36 |
37 |
38 | class kCenterGreedy(SamplingMethod):
39 |
40 | def __init__(self, X, y, seed, metric='euclidean'):
41 | self.X = X
42 | self.y = y
43 | self.flat_X = self.flatten_X()
44 | self.name = 'kcenter'
45 | self.features = self.flat_X
46 | self.metric = metric
47 | self.min_distances = None
48 | self.n_obs = self.X.shape[0]
49 | self.already_selected = []
50 |
51 | def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
52 | """Update min distances given cluster centers.
53 |
54 | Args:
55 | cluster_centers: indices of cluster centers
56 | only_new: only calculate distance for newly selected points and update
57 | min_distances.
58 | reset_dist: whether to reset min_distances.
59 | """
60 |
61 | if reset_dist:
62 | self.min_distances = None
63 | if only_new:
64 | cluster_centers = [d for d in cluster_centers
65 | if d not in self.already_selected]
66 | if cluster_centers:
67 | # Update min_distances for all examples given new cluster center.
68 | x = self.features[cluster_centers]
69 | dist = pairwise_distances(self.features, x, metric=self.metric)
70 |
71 | if self.min_distances is None:
72 | self.min_distances = np.min(dist, axis=1).reshape(-1,1)
73 | else:
74 | self.min_distances = np.minimum(self.min_distances, dist)
75 |
76 | def select_batch_(self, model, already_selected, N, **kwargs):
77 | """
78 | Diversity promoting active learning method that greedily forms a batch
79 | to minimize the maximum distance to a cluster center among all unlabeled
80 | datapoints.
81 |
82 | Args:
83 | model: model with scikit-like API with decision_function implemented
84 | already_selected: index of datapoints already selected
85 | N: batch size
86 |
87 | Returns:
88 | indices of points selected to minimize distance to cluster centers
89 | """
90 |
91 | try:
92 | # Assumes that the transform function takes in original data and not
93 | # flattened data.
94 | print('Getting transformed features...')
95 | self.features = model.transform(self.X)
96 | print('Calculating distances...')
97 | self.update_distances(already_selected, only_new=False, reset_dist=True)
98 | except:
99 | print('Using flat_X as features.')
100 | self.update_distances(already_selected, only_new=True, reset_dist=False)
101 |
102 | new_batch = []
103 |
104 | for _ in range(N):
105 | if self.already_selected is None:
106 | # Initialize centers with a randomly selected datapoint
107 | ind = np.random.choice(np.arange(self.n_obs))
108 | else:
109 | ind = np.argmax(self.min_distances)
110 | # New examples should not be in already selected since those points
111 | # should have min_distance of zero to a cluster center.
112 | assert ind not in already_selected
113 |
114 | self.update_distances([ind], only_new=True, reset_dist=False)
115 | new_batch.append(ind)
116 | print('Maximum distance from cluster centers is %0.2f'
117 | % max(self.min_distances))
118 |
119 |
120 | self.already_selected = already_selected
121 |
122 | return new_batch
123 |
124 |
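A minimal usage sketch (not part of the repo), under the same import-path assumption. Passing an object without a transform method (here None) exercises the fallback branch that uses the raw flattened features, so no trained model is needed; sizes are illustrative.

import numpy as np

from sampling_methods.kcenter_greedy import kCenterGreedy

rng = np.random.RandomState(0)
X = rng.rand(300, 16)
y = np.zeros(300)                          # labels are unused by this sampler

sampler = kCenterGreedy(X, y, seed=0)
seed_centers = [0, 1, 2]                   # pretend these are already labeled

# Greedily pick 5 points that are far from the current centers.
batch = sampler.select_batch(model=None, already_selected=seed_centers, N=5)
print(batch)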
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/margin_AL.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Margin based AL method.
16 |
17 | Samples in batches based on margin scores.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import numpy as np
25 | from sampling_methods.sampling_def import SamplingMethod
26 |
27 |
28 | class MarginAL(SamplingMethod):
29 | def __init__(self, X, y, seed):
30 | self.X = X
31 | self.y = y
32 | self.name = 'margin'
33 |
34 | def select_batch_(self, model, already_selected, N, **kwargs):
35 | """Returns batch of datapoints with smallest margin/highest uncertainty.
36 |
37 | For binary classification, can just take the absolute distance to decision
38 | boundary for each point.
39 | For multiclass classification, must consider the margin between the distances
40 | for the top two most likely classes.
41 |
42 | Args:
43 | model: scikit learn model with decision_function implemented
44 | already_selected: index of datapoints already selected
45 | N: batch size
46 |
47 | Returns:
48 | indices of points selected to add using margin active learner
49 | """
50 |
51 | try:
52 | distances = model.decision_function(self.X)
53 | except:
54 | distances = model.predict_proba(self.X)
55 | if len(distances.shape) < 2:
56 | min_margin = abs(distances)
57 | else:
58 | sort_distances = np.sort(distances, 1)[:, -2:]
59 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
60 | rank_ind = np.argsort(min_margin)
61 | rank_ind = [i for i in rank_ind if i not in already_selected]
62 | active_samples = rank_ind[0:N]
63 | return active_samples
64 |
65 |
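A small worked sketch (illustrative only) of the margin score computed above: for each row, the margin is the gap between the two largest class scores, and the most uncertain points are those with the smallest gap.

import numpy as np

scores = np.array([[0.10, 0.20, 0.70],   # confident      -> large margin
                   [0.40, 0.31, 0.29],   # ambiguous      -> small margin
                   [0.05, 0.50, 0.45]])  # very ambiguous -> smallest margin
top_two = np.sort(scores, 1)[:, -2:]
min_margin = top_two[:, 1] - top_two[:, 0]
print(min_margin)              # approximately [0.5  0.09 0.05]
print(np.argsort(min_margin))  # [2 1 0]: most uncertain points first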
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/mixture_of_samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Mixture of base sampling strategies
16 |
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import copy
24 |
25 | from sampling_methods.sampling_def import SamplingMethod
26 | from sampling_methods.constants import AL_MAPPING, get_base_AL_mapping
27 |
28 | get_base_AL_mapping()
29 |
30 |
31 | class MixtureOfSamplers(SamplingMethod):
32 | """Samples according to mixture of base sampling methods.
33 |
34 | If duplicate points are selected by the mixed strategies when forming the batch
35 | then the remaining slots are divided according to mixture weights and
36 | another partial batch is requested until the batch is full.
37 | """
38 | def __init__(self,
39 | X,
40 | y,
41 | seed,
42 | mixture={'methods': ('margin', 'uniform'),
43 | 'weights': (0.5, 0.5)},
44 | samplers=None):
45 | self.X = X
46 | self.y = y
47 | self.name = 'mixture_of_samplers'
48 | self.sampling_methods = mixture['methods']
49 | self.sampling_weights = dict(zip(mixture['methods'], mixture['weights']))
50 | self.seed = seed
51 | # A list of initialized samplers is allowed as an input because
52 | # for AL_methods that search over different mixtures, may want mixtures to
53 | # have shared AL_methods so that initialization is only performed once for
54 | # computation intensive methods like HierarchicalClusteringAL and
55 | # states are shared between mixtures.
56 | # If initialized samplers are not provided, initialize them ourselves.
57 | if samplers is None:
58 | self.samplers = {}
59 | self.initialize(self.sampling_methods)
60 | else:
61 | self.samplers = samplers
62 | self.history = []
63 |
64 | def initialize(self, samplers):
65 | self.samplers = {}
66 | for s in samplers:
67 | self.samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed)
68 |
69 | def select_batch_(self, already_selected, N, **kwargs):
70 | """Returns batch of datapoints selected according to mixture weights.
71 |
72 | Args:
73 | already_selected: index of datapoints already selected
74 | N: batch size
75 |
76 | Returns:
77 | indices of points selected according to the mixture of samplers
78 | """
79 | kwargs['already_selected'] = copy.copy(already_selected)
80 | inds = set()
81 | self.selected_by_sampler = {}
82 | for s in self.sampling_methods:
83 | self.selected_by_sampler[s] = []
84 | effective_N = 0
85 | while len(inds) < N:
86 | effective_N += N - len(inds)
87 | for s in self.sampling_methods:
88 | if len(inds) < N:
89 | batch_size = min(max(int(self.sampling_weights[s] * effective_N), 1), N)
90 | sampler = self.samplers[s]
91 | kwargs['N'] = batch_size
92 | s_inds = sampler.select_batch(**kwargs)
93 | for ind in s_inds:
94 | if ind not in self.selected_by_sampler[s]:
95 | self.selected_by_sampler[s].append(ind)
96 | s_inds = [d for d in s_inds if d not in inds]
97 | s_inds = s_inds[0 : min(len(s_inds), N-len(inds))]
98 | inds.update(s_inds)
99 | self.history.append(copy.deepcopy(self.selected_by_sampler))
100 | return list(inds)
101 |
102 | def to_dict(self):
103 | output = {}
104 | output['history'] = self.history
105 | output['samplers'] = self.sampling_methods
106 | output['mixture_weights'] = self.sampling_weights
107 | for s in self.samplers:
108 | s_output = self.samplers[s].to_dict()
109 | output[s] = s_output
110 | return output
111 |
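A minimal usage sketch (not part of the repo): a 50/50 mixture of the margin and uniform samplers on synthetic data. The fitted model travels through **kwargs because the underlying margin sampler needs it; sizes and names are illustrative.

import numpy as np
from sklearn.linear_model import LogisticRegression

from sampling_methods.mixture_of_samplers import MixtureOfSamplers

rng = np.random.RandomState(0)
X = rng.rand(150, 5)
y = rng.randint(0, 2, size=150)

labeled = list(range(15))
model = LogisticRegression(max_iter=500).fit(X[labeled], y[labeled])

mixture = {'methods': ('margin', 'uniform'), 'weights': (0.5, 0.5)}
sampler = MixtureOfSamplers(X, y, seed=0, mixture=mixture)
batch = sampler.select_batch(model=model, already_selected=labeled, N=10)
print(len(batch))   # 10 indices, drawn from both strategies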
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/represent_cluster_centers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Another informative and diverse sampler that mirrors the algorithm described
16 | in Xu et al., Representative Sampling for Text Classification Using
17 | Support Vector Machines, 2003
18 |
19 | Batch is created by clustering points within the margin of the classifier and
20 | choosing points closest to the k centroids.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from sklearn.cluster import MiniBatchKMeans
28 | import numpy as np
29 | from sampling_methods.sampling_def import SamplingMethod
30 |
31 |
32 | class RepresentativeClusterMeanSampling(SamplingMethod):
33 | """Selects batch based on informative and diverse criteria.
34 |
35 | Returns points within the margin of the classifier that are closest to the
36 | k-means centers of those points.
37 | """
38 |
39 | def __init__(self, X, y, seed):
40 | self.name = 'cluster_mean'
41 | self.X = X
42 | self.flat_X = self.flatten_X()
43 | self.y = y
44 | self.seed = seed
45 |
46 | def select_batch_(self, model, N, already_selected, **kwargs):
47 | # Probably okay to always use MiniBatchKMeans
48 | # Should standardize data before clustering
49 | # Can cluster on standardized data but train on raw features if desired
50 | try:
51 | distances = model.decision_function(self.X)
52 | except:
53 | distances = model.predict_proba(self.X)
54 | if len(distances.shape) < 2:
55 | min_margin = abs(distances)
56 | else:
57 | sort_distances = np.sort(distances, 1)[:, -2:]
58 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
59 | rank_ind = np.argsort(min_margin)
60 | rank_ind = [i for i in rank_ind if i not in already_selected]
61 |
62 | distances = abs(model.decision_function(self.X))
63 | min_margin_by_class = np.min(abs(distances[already_selected]),axis=0)
64 | unlabeled_in_margin = np.array([i for i in range(len(self.y))
65 | if i not in already_selected and
66 | any(distances[i] < min_margin_by_class)])
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/sampling_def.py:
--------------------------------------------------------------------------------
38 | def flatten_X(self):
39 | shape = self.X.shape
40 | flat_X = self.X
41 | if len(shape) > 2:
42 | flat_X = np.reshape(self.X, (shape[0],np.product(shape[1:])))
43 | return flat_X
44 |
45 |
46 | @abc.abstractmethod
47 | def select_batch_(self):
48 | return
49 |
50 | def select_batch(self, **kwargs):
51 | return self.select_batch_(**kwargs)
52 |
53 | def to_dict(self):
54 | return None
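A sketch (not part of the repo) of the subclass contract implied by this base class: a new sampler only needs __init__ and select_batch_, and inherits flatten_X / select_batch. The TailSampler class below is a hypothetical toy example.

import numpy as np

from sampling_methods.sampling_def import SamplingMethod


class TailSampler(SamplingMethod):
  """Toy sampler: returns the last N unlabeled indices."""

  def __init__(self, X, y, seed):
    self.X = X
    self.y = y
    self.name = 'tail'

  def select_batch_(self, already_selected, N, **kwargs):
    pool = [i for i in range(self.X.shape[0]) if i not in already_selected]
    return pool[-N:]


sampler = TailSampler(np.eye(10), np.arange(10), seed=0)
print(sampler.select_batch(already_selected=[9], N=3))   # [6, 7, 8]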
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/uniform_sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Uniform sampling method.
16 |
17 | Samples in batches.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import numpy as np
25 |
26 | from sampling_methods.sampling_def import SamplingMethod
27 |
28 |
29 | class UniformSampling(SamplingMethod):
30 |
31 | def __init__(self, X, y, seed):
32 | self.X = X
33 | self.y = y
34 | self.name = 'uniform'
35 | np.random.seed(seed)
36 |
37 | def select_batch_(self, already_selected, N, **kwargs):
38 | """Returns batch of randomly sampled datapoints.
39 |
40 | Assumes that data has already been shuffled.
41 |
42 | Args:
43 | already_selected: index of datapoints already selected
44 | N: batch size
45 |
46 | Returns:
47 | indices of points selected to label
48 | """
49 |
50 | # This is uniform given the remaining pool but biased wrt the entire pool.
51 | sample = [i for i in range(self.X.shape[0]) if i not in already_selected]
52 | return sample[0:N]
53 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/utils/tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Node and Tree class to support hierarchical clustering AL method.
16 |
17 | Assumed to be a binary tree.
18 |
19 | Node class is used to represent each node in a hierarchical clustering.
20 | Each node has certain properties that are used in the AL method.
21 |
22 | Tree class is used to traverse a hierarchical clustering.
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import copy
30 |
31 |
32 | class Node(object):
33 | """Node class for hierarchical clustering.
34 |
35 | Initialized with name and left and right children.
36 | """
37 |
38 | def __init__(self, name, left=None, right=None):
39 | self.name = name
40 | self.left = left
41 | self.right = right
42 | self.is_leaf = left is None and right is None
43 | self.parent = None
44 | # Fields for hierarchical clustering AL
45 | self.score = 1.0
46 | self.split = False
47 | self.best_label = None
48 | self.weight = None
49 |
50 | def set_parent(self, parent):
51 | self.parent = parent
52 |
53 |
54 | class Tree(object):
55 | """Tree object for traversing a binary tree.
56 |
57 | Most methods apply to trees in general with the exception of get_pruning
58 | which is specific to the hierarchical clustering AL method.
59 | """
60 |
61 | def __init__(self, root, node_dict):
62 | """Initializes tree and creates all nodes in node_dict.
63 |
64 | Args:
65 | root: id of the root node
66 | node_dict: dictionary with node_id as keys and entries indicating
67 | left and right child of node respectively.
68 | """
69 | self.node_dict = node_dict
70 | self.root = self.make_tree(root)
71 | self.nodes = {}
72 | self.leaves_mapping = {}
73 | self.fill_parents()
74 | self.n_leaves = None
75 |
76 | def print_tree(self, node, max_depth):
77 | """Helper function to print out tree for debugging."""
78 | node_list = [node]
79 | output = ""
80 | level = 0
81 | while level < max_depth and len(node_list):
82 | children = set()
83 | for n in node_list:
84 | node = self.get_node(n)
85 | output += ("\t"*level+"node %d: score %.2f, weight %.2f" %
86 | (node.name, node.score, node.weight)+"\n")
87 | if node.left:
88 | children.add(node.left.name)
89 | if node.right:
90 | children.add(node.right.name)
91 | level += 1
92 | node_list = children
93 | return print(output)
94 |
95 | def make_tree(self, node_id):
96 | if node_id is not None:
97 | return Node(node_id,
98 | self.make_tree(self.node_dict[node_id][0]),
99 | self.make_tree(self.node_dict[node_id][1]))
100 |
101 | def fill_parents(self):
102 | # Setting parent and storing nodes in dict for fast access
103 | def rec(pointer, parent):
104 | if pointer is not None:
105 | self.nodes[pointer.name] = pointer
106 | pointer.set_parent(parent)
107 | rec(pointer.left, pointer)
108 | rec(pointer.right, pointer)
109 | rec(self.root, None)
110 |
111 | def get_node(self, node_id):
112 | return self.nodes[node_id]
113 |
114 | def get_ancestor(self, node):
115 | ancestors = []
116 | if isinstance(node, int):
117 | node = self.get_node(node)
118 | while node.name != self.root.name:
119 | node = node.parent
120 | ancestors.append(node.name)
121 | return ancestors
122 |
123 | def fill_weights(self):
124 | for v in self.node_dict:
125 | node = self.get_node(v)
126 | node.weight = len(self.leaves_mapping[v]) / (1.0 * self.n_leaves)
127 |
128 | def create_child_leaves_mapping(self, leaves):
129 | """DP for creating child leaves mapping.
130 |
131 | Storing in dict to save recompute.
132 | """
133 | self.n_leaves = len(leaves)
134 | for v in leaves:
135 | self.leaves_mapping[v] = [v]
136 | node_list = set([self.get_node(v).parent for v in leaves])
137 | while node_list:
138 | to_fill = copy.copy(node_list)
139 | for v in node_list:
140 | if (v.left.name in self.leaves_mapping
141 | and v.right.name in self.leaves_mapping):
142 | to_fill.remove(v)
143 | self.leaves_mapping[v.name] = (self.leaves_mapping[v.left.name] +
144 | self.leaves_mapping[v.right.name])
145 | if v.parent is not None:
146 | to_fill.add(v.parent)
147 | node_list = to_fill
148 | self.fill_weights()
149 |
150 | def get_child_leaves(self, node):
151 | return self.leaves_mapping[node]
152 |
153 | def get_pruning(self, node):
154 | if node.split:
155 | return self.get_pruning(node.left) + self.get_pruning(node.right)
156 | else:
157 | return [node.name]
158 |
159 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/utils/tree_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sampling_methods.utils.tree."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import unittest
22 | from sampling_methods.utils import tree
23 |
24 |
25 | class TreeTest(unittest.TestCase):
26 |
27 | def setUp(self):
28 | node_dict = {
29 | 1: (2, 3),
30 | 2: (4, 5),
31 | 3: (6, 7),
32 | 4: [None, None],
33 | 5: [None, None],
34 | 6: [None, None],
35 | 7: [None, None]
36 | }
37 | self.tree = tree.Tree(1, node_dict)
38 | self.tree.create_child_leaves_mapping([4, 5, 6, 7])
39 | node = self.tree.get_node(1)
40 | node.split = True
41 | node = self.tree.get_node(2)
42 | node.split = True
43 |
44 | def assertNode(self, node, name, left, right):
45 | self.assertEqual(node.name, name)
46 | self.assertEqual(node.left.name, left)
47 | self.assertEqual(node.right.name, right)
48 |
49 | def testTreeRootSetCorrectly(self):
50 | self.assertNode(self.tree.root, 1, 2, 3)
51 |
52 | def testGetNode(self):
53 | node = self.tree.get_node(1)
54 | assert isinstance(node, tree.Node)
55 | self.assertEqual(node.name, 1)
56 |
57 | def testFillParent(self):
58 | node = self.tree.get_node(3)
59 | self.assertEqual(node.parent.name, 1)
60 |
61 | def testGetAncestors(self):
62 | ancestors = self.tree.get_ancestor(5)
63 | self.assertTrue(all([a in ancestors for a in [1, 2]]))
64 |
65 | def testChildLeaves(self):
66 | leaves = self.tree.get_child_leaves(3)
67 | self.assertTrue(all([c in leaves for c in [6, 7]]))
68 |
69 | def testFillWeights(self):
70 | node = self.tree.get_node(3)
71 | self.assertEqual(node.weight, 0.5)
72 |
73 | def testGetPruning(self):
74 | node = self.tree.get_node(1)
75 | pruning = self.tree.get_pruning(node)
76 | self.assertTrue(all([n in pruning for n in [3, 4, 5]]))
77 |
78 | if __name__ == '__main__':
79 | unittest.main()
80 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/sampling_methods/wrapper_sampler_def.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Abstract class for wrapper sampling methods that call base sampling methods.
16 |
17 | Provides interface to sampling methods that allow same signature
18 | for select_batch. Each subclass implements select_batch_ with the desired
19 | signature for readability.
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import abc
27 |
28 | from sampling_methods.constants import AL_MAPPING
29 | from sampling_methods.constants import get_all_possible_arms
30 | from sampling_methods.sampling_def import SamplingMethod
31 |
32 | get_all_possible_arms()
33 |
34 |
35 | class WrapperSamplingMethod(SamplingMethod):
36 | __metaclass__ = abc.ABCMeta
37 |
38 | def initialize_samplers(self, mixtures):
39 | methods = []
40 | for m in mixtures:
41 | methods += m['methods']
42 | methods = set(methods)
43 | self.base_samplers = {}
44 | for s in methods:
45 | self.base_samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed)
46 | self.samplers = []
47 | for m in mixtures:
48 | self.samplers.append(
49 | AL_MAPPING['mixture_of_samplers'](self.X, self.y, self.seed, m,
50 | self.base_samplers))
51 |
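A sketch (not part of the repo) of how a wrapper sampler can use initialize_samplers: each mixture becomes one mixture-of-samplers arm, and the wrapper delegates batch selection to whichever arm it prefers. The class, the two mixtures and the "always pick the first arm" rule are purely illustrative; it assumes the 'margin' and 'uniform' base methods are registered in AL_MAPPING.

import numpy as np

from sampling_methods.wrapper_sampler_def import WrapperSamplingMethod


class FirstArmWrapper(WrapperSamplingMethod):
  """Toy wrapper: always queries the first configured mixture."""

  def __init__(self, X, y, seed):
    self.X = X
    self.y = y
    self.seed = seed
    self.name = 'first_arm_wrapper'
    mixtures = [{'methods': ('margin', 'uniform'), 'weights': (0.5, 0.5)},
                {'methods': ('margin', 'uniform'), 'weights': (0.9, 0.1)}]
    self.initialize_samplers(mixtures)

  def select_batch_(self, model, already_selected, N, **kwargs):
    kwargs.update(model=model, already_selected=already_selected, N=N)
    return self.samplers[0].select_batch(**kwargs)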
--------------------------------------------------------------------------------
/src/baselines/active-learning/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/utils/allconv.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implements allconv model in keras using tensorflow backend."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import copy
21 |
22 | import keras
23 | import keras.backend as K
24 | from keras.layers import Activation
25 | from keras.layers import Conv2D
26 | from keras.layers import Dropout
27 | from keras.layers import GlobalAveragePooling2D
28 | from keras.models import Sequential
29 |
30 | import numpy as np
31 | import tensorflow as tf
32 |
33 |
34 | class AllConv(object):
35 | """allconv network that matches sklearn api."""
36 |
37 | def __init__(self,
38 | random_state=1,
39 | epochs=50,
40 | batch_size=32,
41 | solver='rmsprop',
42 | learning_rate=0.001,
43 | lr_decay=0.):
44 | # params
45 | self.solver = solver
46 | self.epochs = epochs
47 | self.batch_size = batch_size
48 | self.learning_rate = learning_rate
49 | self.lr_decay = lr_decay
50 | # data
51 | self.encode_map = None
52 | self.decode_map = None
53 | self.model = None
54 | self.random_state = random_state
55 | self.n_classes = None
56 |
57 | def build_model(self, X):
58 | # assumes that data axis order is same as the backend
59 | input_shape = X.shape[1:]
60 | np.random.seed(self.random_state)
61 | tf.set_random_seed(self.random_state)
62 |
63 | model = Sequential()
64 | model.add(Conv2D(96, (3, 3), padding='same',
65 | input_shape=input_shape, name='conv1'))
66 | model.add(Activation('relu'))
67 | model.add(Conv2D(96, (3, 3), name='conv2', padding='same'))
68 | model.add(Activation('relu'))
69 | model.add(Conv2D(96, (3, 3), strides=(2, 2), padding='same', name='conv3'))
70 | model.add(Activation('relu'))
71 | model.add(Dropout(0.5))
72 |
73 | model.add(Conv2D(192, (3, 3), name='conv4', padding='same'))
74 | model.add(Activation('relu'))
75 | model.add(Conv2D(192, (3, 3), name='conv5', padding='same'))
76 | model.add(Activation('relu'))
77 | model.add(Conv2D(192, (3, 3), strides=(2, 2), name='conv6', padding='same'))
78 | model.add(Activation('relu'))
79 | model.add(Dropout(0.5))
80 |
81 | model.add(Conv2D(192, (3, 3), name='conv7', padding='same'))
82 | model.add(Activation('relu'))
83 | model.add(Conv2D(192, (1, 1), name='conv8', padding='valid'))
84 | model.add(Activation('relu'))
85 | model.add(Conv2D(10, (1, 1), name='conv9', padding='valid'))
86 |
87 | model.add(GlobalAveragePooling2D())
88 | model.add(Activation('softmax', name='activation_top'))
89 | model.summary()
90 |
91 | try:
92 | optimizer = getattr(keras.optimizers, self.solver)
93 | except:
94 | raise NotImplementedError('optimizer not implemented in keras')
95 | # All optimizers with the exception of nadam take decay as named arg
96 | try:
97 | opt = optimizer(lr=self.learning_rate, decay=self.lr_decay)
98 | except:
99 | opt = optimizer(lr=self.learning_rate, schedule_decay=self.lr_decay)
100 |
101 | model.compile(loss='categorical_crossentropy',
102 | optimizer=opt,
103 | metrics=['accuracy'])
104 | # Save initial weights so that model can be retrained with same
105 | # initialization
106 | self.initial_weights = copy.deepcopy(model.get_weights())
107 |
108 | self.model = model
109 |
110 | def create_y_mat(self, y):
111 | y_encode = self.encode_y(y)
112 | y_encode = np.reshape(y_encode, (len(y_encode), 1))
113 | y_mat = keras.utils.to_categorical(y_encode, self.n_classes)
114 | return y_mat
115 |
116 | # Add handling for classes that do not start counting from 0
117 | def encode_y(self, y):
118 | if self.encode_map is None:
119 | self.classes_ = sorted(list(set(y)))
120 | self.n_classes = len(self.classes_)
121 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_))))
122 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_))
123 | mapper = lambda x: self.encode_map[x]
124 | transformed_y = np.array(list(map(mapper, y)))
125 | return transformed_y
126 |
127 | def decode_y(self, y):
128 | mapper = lambda x: self.decode_map[x]
129 | transformed_y = np.array(list(map(mapper, y)))
130 | return transformed_y
131 |
132 | def fit(self, X_train, y_train, sample_weight=None):
133 | y_mat = self.create_y_mat(y_train)
134 |
135 | if self.model is None:
136 | self.build_model(X_train)
137 |
138 | # We don't want incremental fit so reset learning rate and weights
139 | K.set_value(self.model.optimizer.lr, self.learning_rate)
140 | self.model.set_weights(self.initial_weights)
141 | self.model.fit(
142 | X_train,
143 | y_mat,
144 | batch_size=self.batch_size,
145 | epochs=self.epochs,
146 | shuffle=True,
147 | sample_weight=sample_weight,
148 | verbose=0)
149 |
150 | def predict(self, X_val):
151 | predicted = self.model.predict(X_val)
152 | return predicted
153 |
154 | def score(self, X_val, val_y):
155 | y_mat = self.create_y_mat(val_y)
156 | val_acc = self.model.evaluate(X_val, y_mat)[1]
157 | return val_acc
158 |
159 | def decision_function(self, X):
160 | return self.predict(X)
161 |
162 | def transform(self, X):
163 | model = self.model
164 | inp = [model.input]
165 | activations = []
166 |
167 | # Get activations of the last conv layer.
168 | output = [layer.output for layer in model.layers if
169 | layer.name == 'conv9'][0]
170 | func = K.function(inp + [K.learning_phase()], [output])
171 | for i in range(int(X.shape[0]/self.batch_size) + 1):
172 | minibatch = X[i * self.batch_size
173 | : min(X.shape[0], (i+1) * self.batch_size)]
174 | list_inputs = [minibatch, 0.]
175 | # Learning phase. 0 = Test mode (no dropout or batch normalization)
176 | layer_output = func(list_inputs)[0]
177 | activations.append(layer_output)
178 | output = np.vstack(tuple(activations))
179 | output = np.reshape(output, (output.shape[0],np.product(output.shape[1:])))
180 | return output
181 |
182 | def get_params(self, deep = False):
183 | params = {}
184 | params['solver'] = self.solver
185 | params['epochs'] = self.epochs
186 | params['batch_size'] = self.batch_size
187 | params['learning_rate'] = self.learning_rate
188 | params['weight_decay'] = self.lr_decay
189 | if deep:
190 | return copy.deepcopy(params)
191 | return copy.copy(params)
192 |
193 | def set_params(self, **parameters):
194 | for parameter, value in parameters.items():
195 | setattr(self, parameter, value)
196 | return self
197 |
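A minimal usage sketch (not part of the repo), assuming the legacy Keras 2.x / TensorFlow 1.x stack this file targets (keras.* layers, tf.set_random_seed) and the repo's import layout; shapes and hyperparameters are illustrative and kept tiny so the example merely exercises the sklearn-like API.

import numpy as np

from utils.allconv import AllConv

X = np.random.RandomState(0).rand(50, 32, 32, 3).astype('float32')  # NHWC images
y = np.arange(50) % 10                      # all 10 classes present

net = AllConv(epochs=1, batch_size=16)
net.fit(X, y)                               # builds the model on first fit
print(net.score(X, y))                      # training accuracy
feats = net.transform(X)                    # flattened conv9 activations
print(net.predict(X).shape, feats.shape)    # (50, 10) (50, 640)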
--------------------------------------------------------------------------------
/src/baselines/active-learning/utils/chart_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Experiment charting script.
16 | """
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | import os
23 | import pickle
24 |
25 | import numpy as np
26 | import matplotlib.pyplot as plt
27 | from matplotlib.backends.backend_pdf import PdfPages
28 |
29 | from absl import app
30 | from absl import flags
31 | from tensorflow import gfile
32 |
33 | flags.DEFINE_string('source_dir',
34 | '/tmp/toy_experiments',
35 | 'Directory with the output to analyze.')
36 | flags.DEFINE_string('save_dir', '/tmp/active_learning',
37 | 'Directory to save charts.')
38 | flags.DEFINE_string('dataset', 'letter', 'Dataset to analyze.')
39 | flags.DEFINE_string(
40 | 'sampling_methods',
41 | ('uniform,margin,informative_diverse,'
42 | 'pred_expert_advice_trip_agg,'
43 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34'),
44 | 'Comma separated string of sampling methods to include in chart.')
45 | flags.DEFINE_string('scoring_methods', 'logistic,kernel_ls',
46 | 'Comma separated string of scoring methods to chart.')
47 | flags.DEFINE_bool('normalize', False, 'Chart runs using normalized data.')
48 | flags.DEFINE_bool('standardize', True, 'Chart runs using standardized data.')
49 |
50 | FLAGS = flags.FLAGS
51 |
52 |
53 | def combine_results(files, diff=False):
54 | all_results = {}
55 | for f in files:
56 | data = pickle.load(gfile.FastGFile(f, 'rb'))
57 | for k in data:
58 | if isinstance(k, tuple):
59 | data[k].pop('noisy_targets')
60 | data[k].pop('indices')
61 | data[k].pop('selected_inds')
62 | data[k].pop('sampler_output')
63 | key = list(k)
64 | seed = key[-1]
65 | key = key[0:10]
66 | key = tuple(key)
67 | if key in all_results:
68 | if seed not in all_results[key]['random_seeds']:
69 | all_results[key]['random_seeds'].append(seed)
70 | for field in [f for f in data[k] if f != 'n_points']:
71 | all_results[key][field] = np.vstack(
72 | (all_results[key][field], data[k][field]))
73 | else:
74 | all_results[key] = data[k]
75 | all_results[key]['random_seeds'] = [seed]
76 | else:
77 | all_results[k] = data[k]
78 | return all_results
79 |
80 |
81 | def plot_results(all_results, score_method, norm, stand, sampler_filter):
82 | colors = {
83 | 'margin':
84 | 'gold',
85 | 'uniform':
86 | 'k',
87 | 'informative_diverse':
88 | 'r',
89 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34':
90 | 'b',
91 | 'pred_expert_advice_trip_agg':
92 | 'g'
93 | }
94 | labels = {
95 | 'margin':
96 | 'margin',
97 | 'uniform':
98 | 'uniform',
99 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34':
100 | 'margin:0.33,informative_diverse:0.33, uniform:0.34',
101 | 'informative_diverse':
102 | 'informative and diverse',
103 | 'pred_expert_advice_trip_agg':
104 | 'expert: margin,informative_diverse,uniform'
105 | }
106 | markers = {
107 | 'margin':
108 | 'None',
109 | 'uniform':
110 | 'None',
111 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34':
112 | '>',
113 | 'informative_diverse':
114 | 'None',
115 | 'pred_expert_advice_trip_agg':
116 | 'p'
117 | }
118 | fields = all_results['tuple_keys']
119 | fields = dict(zip(fields, range(len(fields))))
120 |
121 | for k in sorted(all_results.keys()):
122 | sampler = k[fields['sampler']]
123 | if (isinstance(k, tuple) and
124 | k[fields['score_method']] == score_method and
125 | k[fields['standardize']] == stand and
126 | k[fields['normalize']] == norm and
127 | (sampler_filter is None or sampler in sampler_filter)):
128 | results = all_results[k]
129 | n_trials = results['accuracy'].shape[0]
130 | x = results['data_sizes'][0]
131 | mean_acc = np.mean(results['accuracy'], axis=0)
132 | CI_acc = np.std(results['accuracy'], axis=0) / np.sqrt(n_trials) * 2.96
133 | if sampler == 'uniform':
134 | plt.plot(
135 | x,
136 | mean_acc,
137 | linewidth=1,
138 | label=labels[sampler],
139 | color=colors[sampler],
140 | linestyle='--'
141 | )
142 | plt.fill_between(
143 | x,
144 | mean_acc - CI_acc,
145 | mean_acc + CI_acc,
146 | color=colors[sampler],
147 | alpha=0.2
148 | )
149 | else:
150 | plt.plot(
151 | x,
152 | mean_acc,
153 | linewidth=1,
154 | label=labels[sampler],
155 | color=colors[sampler],
156 | marker=markers[sampler],
157 | markeredgecolor=colors[sampler]
158 | )
159 | plt.fill_between(
160 | x,
161 | mean_acc - CI_acc,
162 | mean_acc + CI_acc,
163 | color=colors[sampler],
164 | alpha=0.2
165 | )
166 | plt.legend(loc=4)
167 |
168 |
169 | def get_between(filename, start, end):
170 | start_ind = filename.find(start) + len(start)
171 | end_ind = filename.rfind(end)
172 | return filename[start_ind:end_ind]
173 |
174 |
175 | def get_sampling_method(dataset, filename):
176 | return get_between(filename, dataset + '_', '/')
177 |
178 |
179 | def get_scoring_method(filename):
180 | return get_between(filename, 'results_score_', '_select_')
181 |
182 |
183 | def get_normalize(filename):
184 | return get_between(filename, '_norm_', '_stand_') == 'True'
185 |
186 |
187 | def get_standardize(filename):
188 | return get_between(
189 | filename, '_stand_', filename[filename.rfind('_'):]) == 'True'
190 |
191 |
192 | def main(argv):
193 | del argv # Unused.
194 | if not gfile.Exists(FLAGS.save_dir):
195 | gfile.MkDir(FLAGS.save_dir)
196 | charting_filepath = os.path.join(FLAGS.save_dir,
197 | FLAGS.dataset + '_charts.pdf')
198 | sampling_methods = FLAGS.sampling_methods.split(',')
199 | scoring_methods = FLAGS.scoring_methods.split(',')
200 | files = gfile.Glob(
201 | os.path.join(FLAGS.source_dir, FLAGS.dataset + '*/results*.pkl'))
202 | files = [
203 | f for f in files
204 | if (get_sampling_method(FLAGS.dataset, f) in sampling_methods and
205 | get_scoring_method(f) in scoring_methods and
206 | get_normalize(f) == FLAGS.normalize and
207 | get_standardize(f) == FLAGS.standardize)
208 | ]
209 |
210 | print('Reading in %d files...' % len(files))
211 | all_results = combine_results(files)
212 | pdf = PdfPages(charting_filepath)
213 |
214 | print('Plotting charts...')
215 | plt.style.use('ggplot')
216 | for m in scoring_methods:
217 | plot_results(
218 | all_results,
219 | m,
220 | FLAGS.normalize,
221 | FLAGS.standardize,
222 | sampler_filter=sampling_methods)
223 | plt.title('Dataset: %s, Score Method: %s' % (FLAGS.dataset, m))
224 | pdf.savefig()
225 | plt.close()
226 | pdf.close()
227 |
228 |
229 | if __name__ == '__main__':
230 | app.run(main)
231 |
--------------------------------------------------------------------------------
/src/baselines/active-learning/utils/kernel_block_solver.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Block kernel lsqr solver for multi-class classification."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import copy
21 | import math
22 |
23 | import numpy as np
24 | import scipy.linalg as linalg
25 | from scipy.sparse.linalg import spsolve
26 | from sklearn import metrics
27 |
28 |
29 | class BlockKernelSolver(object):
30 | """Inspired by algorithm from https://arxiv.org/pdf/1602.05310.pdf."""
31 | # TODO(lishal): save preformed kernel matrix and reuse if possible
32 | # perhaps not possible if want to keep scikitlearn signature
33 |
34 | def __init__(self,
35 | random_state=1,
36 | C=0.1,
37 | block_size=4000,
38 | epochs=3,
39 | verbose=False,
40 | gamma=None):
41 | self.block_size = block_size
42 | self.epochs = epochs
43 | self.C = C
44 | self.kernel = 'rbf'
45 | self.coef_ = None
46 | self.verbose = verbose
47 | self.encode_map = None
48 | self.decode_map = None
49 | self.gamma = gamma
50 | self.X_train = None
51 | self.random_state = random_state
52 |
53 | def encode_y(self, y):
54 | # Handles classes that do not start counting from 0.
55 | if self.encode_map is None:
56 | self.classes_ = sorted(list(set(y)))
57 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_))))
58 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_))
59 | mapper = lambda x: self.encode_map[x]
60 | transformed_y = np.array(list(map(mapper, y)))
61 | return transformed_y
62 |
63 | def decode_y(self, y):
64 | mapper = lambda x: self.decode_map[x]
65 | transformed_y = np.array(list(map(mapper, y)))
66 | return transformed_y
67 |
68 | def fit(self, X_train, y_train, sample_weight=None):
69 | """Form K and solve (K + lambda * I)x = y in a block-wise fashion."""
70 | np.random.seed(self.random_state)
71 | self.X_train = X_train
72 | n_features = X_train.shape[1]
73 | y = self.encode_y(y_train)
74 | if self.gamma is None:
75 | self.gamma = 1./n_features
76 | K = metrics.pairwise.pairwise_kernels(
77 | X_train, metric=self.kernel, gamma=self.gamma)
78 | if self.verbose:
79 | print('Finished forming kernel matrix.')
80 |
81 | # compute some constants
82 | num_classes = len(list(set(y)))
83 | num_samples = K.shape[0]
84 | num_blocks = math.ceil(num_samples*1.0/self.block_size)
85 | x = np.zeros((K.shape[0], num_classes))
86 | y_hat = np.zeros((K.shape[0], num_classes))
87 | onehot = lambda x: np.eye(num_classes)[x]
88 | y_onehot = np.array(list(map(onehot, y)))
89 | idxes = np.diag_indices(num_samples)
90 | if sample_weight is not None:
91 | weights = np.sqrt(sample_weight)
92 | weights = weights[:, np.newaxis]
93 | y_onehot = weights * y_onehot
94 | K *= np.outer(weights, weights)
95 | if num_blocks == 1:
96 | epochs = 1
97 | else:
98 | epochs = self.epochs
99 |
100 | for e in range(epochs):
101 | shuffled_coords = np.random.choice(
102 | num_samples, num_samples, replace=False)
103 | for b in range(int(num_blocks)):
104 | residuals = y_onehot - y_hat
105 |
106 | # Form a block of K.
107 | K[idxes] += (self.C * num_samples)
108 | block = shuffled_coords[b*self.block_size:
109 | min((b+1)*self.block_size, num_samples)]
110 | K_block = K[:, block]
111 | # Dim should be block size x block size
112 | KbTKb = K_block.T.dot(K_block)
113 |
114 | if self.verbose:
115 | print('solving block {0}'.format(b))
116 | # Try linalg solve then sparse solve for handling of sparse input.
117 | try:
118 | x_block = linalg.solve(KbTKb, K_block.T.dot(residuals))
119 | except:
120 | try:
121 | x_block = spsolve(KbTKb, K_block.T.dot(residuals))
122 | except:
123 | return None
124 |
125 | # update model
126 | x[block] = x[block] + x_block
127 | K[idxes] = K[idxes] - (self.C * num_samples)
128 | y_hat = K.dot(x)
129 |
130 | y_pred = np.argmax(y_hat, axis=1)
131 | train_acc = metrics.accuracy_score(y, y_pred)
132 | if self.verbose:
133 | print('Epoch: {0}, Block: {1}, Train Accuracy: {2}'
134 | .format(e, b, train_acc))
135 | self.coef_ = x
136 |
137 | def predict(self, X_val):
138 | val_K = metrics.pairwise.pairwise_kernels(
139 | X_val, self.X_train, metric=self.kernel, gamma=self.gamma)
140 | val_pred = np.argmax(val_K.dot(self.coef_), axis=1)
141 | return self.decode_y(val_pred)
142 |
143 | def score(self, X_val, val_y):
144 | val_pred = self.predict(X_val)
145 | val_acc = metrics.accuracy_score(val_y, val_pred)
146 | return val_acc
147 |
148 | def decision_function(self, X, type='predicted'):
149 | # Return the predicted value of the best class
150 | # Margin_AL will see that a vector is returned and not a matrix and
151 | # simply select the points that have the lowest predicted value to label
152 | K = metrics.pairwise.pairwise_kernels(
153 | X, self.X_train, metric=self.kernel, gamma=self.gamma)
154 | predicted = K.dot(self.coef_)
155 | if type == 'scores':
156 | val_best = np.max(K.dot(self.coef_), axis=1)
157 | return val_best
158 | elif type == 'predicted':
159 | return predicted
160 | else:
161 | raise NotImplementedError('Invalid return type for decision function.')
162 |
163 | def get_params(self, deep=False):
164 | params = {}
165 | params['C'] = self.C
166 | params['gamma'] = self.gamma
167 | if deep:
168 | return copy.deepcopy(params)
169 | return copy.copy(params)
170 |
171 | def set_params(self, **parameters):
172 | for parameter, value in parameters.items():
173 | setattr(self, parameter, value)
174 | return self
175 |
176 | def softmax_over_predicted(self, X):
177 | val_K = metrics.pairwise.pairwise_kernels(
178 | X, self.X_train, metric=self.kernel, gamma=self.gamma)
179 | val_pred = val_K.dot(self.coef_)
180 | row_min = np.min(val_pred, axis=1)
181 | val_pred = val_pred - row_min[:, None]
182 | val_pred = np.exp(val_pred)
183 | sum_exp = np.sum(val_pred, axis=1)
184 | val_pred = val_pred/sum_exp[:, None]
185 | return val_pred
186 |
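A minimal usage sketch (not part of the repo) of the block kernel solver on a small synthetic two-class problem; decision_function is called the way the margin sampler would use it. Names and sizes are illustrative.

import numpy as np

from utils.kernel_block_solver import BlockKernelSolver

rng = np.random.RandomState(0)
X = rng.rand(300, 10)
y = (X[:, 0] + X[:, 1] > 1.0).astype(int)   # two classes from a linear rule

solver = BlockKernelSolver(C=0.1, block_size=200, epochs=2)
solver.fit(X, y)
print(solver.score(X, y))                   # training accuracy

scores = solver.decision_function(X)        # (300, 2) predicted class values
sorted_scores = np.sort(scores, 1)
margin = sorted_scores[:, -1] - sorted_scores[:, -2]
print(margin[:5])                           # small margin = uncertain point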
--------------------------------------------------------------------------------
/src/baselines/active-learning/utils/small_cnn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Implements Small CNN model in keras using tensorflow backend."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import copy
21 |
22 | import keras
23 | import keras.backend as K
24 | from keras.layers import Activation
25 | from keras.layers import Conv2D
26 | from keras.layers import Dense
27 | from keras.layers import Dropout
28 | from keras.layers import Flatten
29 | from keras.layers import MaxPooling2D
30 | from keras.models import Sequential
31 |
32 | import numpy as np
33 | import tensorflow as tf
34 |
35 |
36 | class SmallCNN(object):
37 | """Small convnet that matches sklearn api.
38 |
39 | Implements model from
40 | https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py
41 | Adapts to inputs of variable size; expects data to be a 4d tensor, with
42 | the number of observations as the first dimension and the other dimensions
43 | corresponding to length, width and number of channels of the image.
44 | """
45 |
46 | def __init__(self,
47 | random_state=1,
48 | epochs=50,
49 | batch_size=32,
50 | solver='rmsprop',
51 | learning_rate=0.001,
52 | lr_decay=0.):
53 | # params
54 | self.solver = solver
55 | self.epochs = epochs
56 | self.batch_size = batch_size
57 | self.learning_rate = learning_rate
58 | self.lr_decay = lr_decay
59 | # data
60 | self.encode_map = None
61 | self.decode_map = None
62 | self.model = None
63 | self.random_state = random_state
64 | self.n_classes = None
65 |
66 | def build_model(self, X):
67 | # assumes that data axis order is same as the backend
68 | input_shape = X.shape[1:]
69 | np.random.seed(self.random_state)
70 | tf.set_random_seed(self.random_state)
71 |
72 | model = Sequential()
73 | model.add(Conv2D(32, (3, 3), padding='same',
74 | input_shape=input_shape, name='conv1'))
75 | model.add(Activation('relu'))
76 | model.add(Conv2D(32, (3, 3), name='conv2'))
77 | model.add(Activation('relu'))
78 | model.add(MaxPooling2D(pool_size=(2, 2)))
79 | model.add(Dropout(0.25))
80 |
81 | model.add(Conv2D(64, (3, 3), padding='same', name='conv3'))
82 | model.add(Activation('relu'))
83 | model.add(Conv2D(64, (3, 3), name='conv4'))
84 | model.add(Activation('relu'))
85 | model.add(MaxPooling2D(pool_size=(2, 2)))
86 | model.add(Dropout(0.25))
87 |
88 | model.add(Flatten())
89 | model.add(Dense(512, name='dense1'))
90 | model.add(Activation('relu'))
91 | model.add(Dropout(0.5))
92 | model.add(Dense(self.n_classes, name='dense2'))
93 | model.add(Activation('softmax'))
94 |
95 | try:
96 | optimizer = getattr(keras.optimizers, self.solver)
97 | except:
98 | raise NotImplementedError('optimizer not implemented in keras')
99 | # All optimizers with the exception of nadam take decay as named arg
100 | try:
101 | opt = optimizer(lr=self.learning_rate, decay=self.lr_decay)
102 | except:
103 | opt = optimizer(lr=self.learning_rate, schedule_decay=self.lr_decay)
104 |
105 | model.compile(loss='categorical_crossentropy',
106 | optimizer=opt,
107 | metrics=['accuracy'])
108 | # Save initial weights so that model can be retrained with same
109 | # initialization
110 | self.initial_weights = copy.deepcopy(model.get_weights())
111 |
112 | self.model = model
113 |
114 | def create_y_mat(self, y):
115 | y_encode = self.encode_y(y)
116 | y_encode = np.reshape(y_encode, (len(y_encode), 1))
117 | y_mat = keras.utils.to_categorical(y_encode, self.n_classes)
118 | return y_mat
119 |
120 | # Add handling for classes that do not start counting from 0
121 | def encode_y(self, y):
122 | if self.encode_map is None:
123 | self.classes_ = sorted(list(set(y)))
124 | self.n_classes = len(self.classes_)
125 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_))))
126 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_))
127 | mapper = lambda x: self.encode_map[x]
128 | transformed_y = np.array(list(map(mapper, y)))
129 | return transformed_y
130 |
131 | def decode_y(self, y):
132 | mapper = lambda x: self.decode_map[x]
133 | transformed_y = np.array(list(map(mapper, y)))
134 | return transformed_y
135 |
136 | def fit(self, X_train, y_train, sample_weight=None):
137 | y_mat = self.create_y_mat(y_train)
138 |
139 | if self.model is None:
140 | self.build_model(X_train)
141 |
142 | # We don't want incremental fit so reset learning rate and weights
143 | K.set_value(self.model.optimizer.lr, self.learning_rate)
144 | self.model.set_weights(self.initial_weights)
145 | self.model.fit(
146 | X_train,
147 | y_mat,
148 | batch_size=self.batch_size,
149 | epochs=self.epochs,
150 | shuffle=True,
151 | sample_weight=sample_weight,
152 | verbose=0)
153 |
154 | def predict(self, X_val):
155 | predicted = self.model.predict(X_val)
156 | return predicted
157 |
158 | def score(self, X_val, val_y):
159 | y_mat = self.create_y_mat(val_y)
160 | val_acc = self.model.evaluate(X_val, y_mat)[1]
161 | return val_acc
162 |
163 | def decision_function(self, X):
164 | return self.predict(X)
165 |
166 | def transform(self, X):
167 | model = self.model
168 | inp = [model.input]
169 | activations = []
170 |
171 | # Get activations of the first dense layer.
172 | output = [layer.output for layer in model.layers if
173 | layer.name == 'dense1'][0]
174 | func = K.function(inp + [K.learning_phase()], [output])
175 | for i in range(int(X.shape[0]/self.batch_size) + 1):
176 | minibatch = X[i * self.batch_size
177 | : min(X.shape[0], (i+1) * self.batch_size)]
178 | list_inputs = [minibatch, 0.]
179 | # Learning phase. 0 = Test mode (no dropout or batch normalization)
180 | layer_output = func(list_inputs)[0]
181 | activations.append(layer_output)
182 | output = np.vstack(tuple(activations))
183 | return output
184 |
185 | def get_params(self, deep = False):
186 | params = {}
187 | params['solver'] = self.solver
188 | params['epochs'] = self.epochs
189 | params['batch_size'] = self.batch_size
190 | params['learning_rate'] = self.learning_rate
191 | params['weight_decay'] = self.lr_decay
192 | if deep:
193 | return copy.deepcopy(params)
194 | return copy.copy(params)
195 |
196 | def set_params(self, **parameters):
197 | for parameter, value in parameters.items():
198 | setattr(self, parameter, value)
199 | return self
200 |
--------------------------------------------------------------------------------
/src/baselines/anrmab.py:
--------------------------------------------------------------------------------
1 | # from src.baselines.sampling_methods.bandit_discrete import BanditDiscreteSampler
2 | import numpy as np
3 | import scipy as sc
4 | from sklearn.cluster import KMeans
5 | from sklearn.metrics.pairwise import euclidean_distances
6 | import networkx as nx
7 | def centralissimo(G):
8 | centralities = []
9 | centralities.append(nx.pagerank(G)) #centralities.append(nx.harmonic_centrality(G))
10 | L = len(centralities[0])
11 | Nc = len(centralities)
12 | cenarray = np.zeros((Nc,L))
13 | for i in range(Nc):
14 | cenarray[i][list(centralities[i].keys())]=list(centralities[i].values())
15 | normcen = (cenarray.astype(float)-np.min(cenarray,axis=1)[:,None])/(np.max(cenarray,axis=1)-np.min(cenarray,axis=1))[:,None]
16 | return normcen
17 |
18 | class ProbSampler(object):
19 | def __init__(self):
20 | pass
21 |
22 | def select_batch(self,wraped_feat):
23 | entropy = wraped_feat[0]
24 | selected = np.argmax(entropy)
25 | return selected
26 |
27 | class DegSampler(object):
28 | def __init__(self):
29 | pass
30 | def select_batch(self,wraped_feat):
31 | deg = wraped_feat[1]
32 | selected = np.argmax(deg)
33 | # print("in DegSampler {}".format(np.max(deg)))
34 | return selected
35 |
36 | class ClusterSampler(object):
37 | def __init__(self):
38 | pass
39 | def select_batch(self,wraped_feat):
40 | edscore = wraped_feat[2]
41 | selected = np.argmax(-edscore)
42 | # print("in DegSampler {}".format(np.max(deg)))
43 | return selected
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 | class BanditDiscreteSampler(object):
52 | """Wraps EXP3 around mixtures of indicated methods.
53 |
 54 | Uses EXP3 multi-armed bandit algorithm to select sampler methods.
55 | """
56 | def __init__(self,budget,num_nodes,n_arms,seed=123,
57 | reward_function = lambda AL_acc: AL_acc[-1],
58 | gamma=0.5):
59 |
60 | # self.name = 'bandit_discrete'
61 | # np.random.seed(seed)
62 | # self.seed = seed
63 | # self.initialize_samplers(samplers)
64 | # self.gamma = gamma
65 |
66 | # print(budget,num_nodes,n_arms)
67 | self.n_arms = n_arms
68 | self.reward_function = reward_function
69 | self.num_arm = float(self.n_arms)
70 | self.pmin=np.sqrt(np.log(self.num_arm) / (self.num_arm * budget))
71 |
72 | self.pull_history = []
73 | self.acc_history = []
74 | self.w = np.ones(self.n_arms)
75 | # self.x = np.zeros(self.n_arms)
76 | self.p = None
77 | # self.probs = []
78 | self.selectionhistory = []
79 | self.num_nodes = float(num_nodes)
80 | self.rewardlist = []
81 | self.phi = []
82 | self.budget = budget
83 | self.Q = []
84 | self.lastselected = []
85 |
86 | def select_batch(self, r, wraped_feat):
87 |
88 | # print(r)
89 | # self.acc_history.append(eval_acc)
90 |
91 | wraped_feat = np.exp(wraped_feat*20.) #make the prob sharper
92 | wraped_feat = wraped_feat/np.sum(wraped_feat,axis=-1,keepdims=True)
93 | if self.p is not None :
94 | self.rewardlist.append(1.0 / (self.phi[self.lastselected] * float(wraped_feat.shape[-1])) * r)
95 | # print("rhat {} {} {}".format(self.rewardlist[-1],self.phi[self.lastselected] , float(wraped_feat.shape[-1])))
96 | reward = 1/float(self.budget)*sum(self.rewardlist)
97 | self.rhat = reward*self.Q[:, self.lastselected] / self.phi[self.lastselected]
98 | self.w = self.w*np.exp(1*0.5*self.pmin*(self.rhat+1/self.p*np.sqrt(np.log(self.num_nodes/0.1)/(self.num_arm*self.budget))))
99 |
100 | self.p = (1 - self.num_arm * self.pmin) * self.w / np.sum(self.w) + self.pmin
101 | # print(self.p,np.sum(self.p),self.pmin)
102 | self.Q = wraped_feat
103 | self.phi = np.matmul(self.p.reshape((1, -1)), wraped_feat).squeeze()
104 |
105 | t= 10
106 | # self.phi = np.exp(self.phi * 100.)
107 | # self.phi = self.phi/np.sum(self.phi)
108 | # print(self.phi.shape)
109 |
110 |
111 | selected = np.random.choice(range(wraped_feat.shape[-1]), p=self.phi)
112 | self.lastselected = selected
113 |
114 | # print(wraped_feat[:,selected])
115 | return selected
116 |
117 |
118 | class AnrmabQuery(object):
119 | def __init__(self, G, budget,num_nodes,batchsize):
120 | self.q = []
121 | self.batchsize = batchsize
122 | # self.alreadyselected = [[] for i in range(batchsize)]
123 | self.G = G
124 | self.NCL = G.stat["nclass"]
125 | self.normcen = centralissimo(self.G.G)[0]
126 | self.cenperc = self.perc(self.normcen)
127 |
128 | for i in range(batchsize):
129 | self.q.append(BanditDiscreteSampler(budget=budget,num_nodes=num_nodes,n_arms=3))
130 | pass
131 |
132 | def __call__(self, output, acc, pool):
133 | ret = []
134 | for id in range(self.batchsize):
135 | selected = self.selectOneNode( self.q[id],acc[id], output[id],pool[id])
136 | ret.append(selected)
137 | ret = np.array(ret) #.reshape(-1,1)
138 | ret = ret.tolist()
139 | return ret
140 |
141 | def selectOneNode(self, q, acc, output,pool):
142 |
143 | # if self.multilabel:
144 | # probs = 1. / (1. + np.exp(-output))
145 | # entropy = multiclassentropy_numpy(probs)
146 | # else:
147 | entropy = sc.stats.entropy(output.transpose())
148 | # print("entropy shape{}".format(entropy.shape))
149 |
150 | validentropy = entropy[pool]
151 | validdeg = self.normcen[pool]
152 |
153 | kmeans = KMeans(n_clusters=self.NCL, random_state=0).fit(output)
154 | ed = euclidean_distances(output, kmeans.cluster_centers_)
155 | ed_score = np.min(ed, axis=1)
156 | valided_score = ed_score[pool]
157 |
158 |
159 | entrperc = self.perc(entropy)
160 | edprec = self.percd(ed_score)
161 |
162 |
163 | wraped_feat = np.stack([entrperc, edprec, self.cenperc])
164 |
165 | wraped_feat_valid = wraped_feat[:,pool]
166 | # print("warpedfaet {}".format(wraped_feat_valid.shape))
167 | selected = q.select_batch( acc, wraped_feat_valid)
168 |
169 | realselected = pool[selected]
170 | # print(realselected)
171 | # print(self.normcen[realselected])
172 |
173 | return realselected
174 |
175 | def percd(self,input):
176 | return 1-np.argsort(np.argsort(input,kind='stable'),kind='stable')/len(input)
177 |
178 | # calculate the percentage of elements smaller than the k-th element
179 | def perc(self,input):
180 | return 1-np.argsort(np.argsort(-input,kind='stable'),kind='stable')/len(input)
181 |
182 |
183 | def unitTest():
184 | # Smoke test for the EXP3-style BanditDiscreteSampler defined in this file.
185 | q = BanditDiscreteSampler(budget=3, num_nodes=10, n_arms=3)
186 |
187 | for i in range(3):
188 | wraped_feat = np.random.rand(3, 10) # one row of per-node scores for each arm
189 | selected = q.select_batch(0.5, wraped_feat)
190 | print(selected)
191 |
192 |
193 |
194 | if __name__ == "__main__":
195 | unitTest()
196 |
197 |
198 |
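
For reference, a quick illustration of the percentile scores computed by perc/percd above (a sketch reproducing the same formula, not part of the repository):

import numpy as np

def perc(x):
    # fraction of elements ranked below each element; larger values score closer to 1
    return 1 - np.argsort(np.argsort(-x, kind='stable'), kind='stable') / len(x)

x = np.array([0.1, 0.4, 0.4, 0.9])
print(perc(x))   # -> [0.25, 0.75, 0.5, 1.0]
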
--------------------------------------------------------------------------------
/src/baselines/coreset.py:
--------------------------------------------------------------------------------
1 | from src.baselines.sampling_methods.kcenter_greedy import kCenterGreedy
2 | import numpy as np
3 |
4 | class CoreSetQuery(object):
5 | def __init__(self, batchsize,trainsetid):
6 | self.q = []
7 | self.trainsetid = trainsetid
8 | self.batchsize = batchsize
9 | self.alreadyselected = [[] for i in range(batchsize)]
10 | for i in range(batchsize):
11 | self.q.append(kCenterGreedy())
12 | pass
13 |
14 | def __call__(self, outputs):
15 | ret = []
16 | for id,row in enumerate(self.alreadyselected):
17 | selected = self.selectOneNode(outputs[id], row, self.q[id])
18 | ret.append(selected)
19 | ret = np.array(ret) #.reshape(-1,1)
20 | ret = ret.tolist()
21 | self.alreadyselected = [x + ret[id] for id, x in enumerate(self.alreadyselected)]
22 |
23 | selectedtrueid = []
24 | for i in range(self.batchsize):
25 | print(ret[i])
26 | selectedtrueid.append(self.trainsetid[i][ret[i][0]])
27 |
28 | return selectedtrueid
29 |
30 |
31 | def selectOneNode(self, output, pool,q):
32 | selected = q.select_batch(output,pool,1)
33 | return selected
34 |
35 |
36 | def unitTest():
 37 | q = CoreSetQuery(3, [list(range(10)) for i in range(3)]) # true node ids for each of the 3 pools
38 |
39 | # pool = np.array([[2,3],[2,3],[4,6]])
40 |
41 | for i in range(3):
42 | features = np.random.randn(3, 10, 5)
43 | selected = q(features)
44 | print(selected)
45 |
46 |
47 | if __name__ == "__main__":
48 | unitTest()
49 |
50 |
51 |
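
A hypothetical usage sketch for CoreSetQuery above (shapes and ids are made up): trainsetid maps each pool position back to its true node id, and outputs holds one embedding matrix per parallel graph.

import numpy as np
from src.baselines.coreset import CoreSetQuery

trainsetid = [list(range(100, 110)), list(range(200, 210))]   # true node ids per pool
q = CoreSetQuery(batchsize=2, trainsetid=trainsetid)

outputs = np.random.randn(2, 10, 5)   # (batchsize, pool size, embedding dim)
picked = q(outputs)                   # one true node id per parallel graph
print(picked)
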
--------------------------------------------------------------------------------
/src/baselines/coreset/compute_distance_mat.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import numpy
3 | import numpy.matlib
4 | import time
5 | import pickle
6 | import bisect
7 |
  8 | dat = pickle.load(open('feature_vectors_pickled', 'rb'))
9 | data = numpy.concatenate((dat['gt_f'],dat['f']), axis=0)
10 | budget = 5000
11 |
12 | start = time.clock()
13 | num_images = data.shape[0]
14 |
15 | dist_mat = numpy.matmul(data,data.transpose())
16 |
17 | sq = numpy.array(dist_mat.diagonal()).reshape(num_images,1)
18 | dist_mat *= -2
19 | dist_mat+=sq
20 | dist_mat+=sq.transpose()
21 |
22 | elapsed = time.clock() - start
23 | print "Time spent in (distance computation) is: ", elapsed
24 | numpy.save('distances.npy', dist_mat)
25 |
26 |
--------------------------------------------------------------------------------
/src/baselines/coreset/configure.sh:
--------------------------------------------------------------------------------
1 | export GUROBI_HOME="/opt/gurobi702/linux64"
2 | export PATH="${PATH}:${GUROBI_HOME}/bin"
3 | export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${GUROBI_HOME}/lib"
4 | export GRB_LICENSE_FILE="/afs/cs.stanford.edu/u/ozansener/gurobi.lic"
5 |
6 |
--------------------------------------------------------------------------------
/src/baselines/coreset/full_solver_gurobi.py:
--------------------------------------------------------------------------------
1 | import numpy
2 | from gurobipy import *
3 | import pickle
4 | import numpy.matlib
5 | import time
6 | import pickle
7 | import bisect
8 |
9 | def solve_fac_loc(xx,yy,subset,n,budget):
10 | model = Model("k-center")
11 | x={}
12 | y={}
13 | z={}
14 | for i in range(n):
15 | # z_i: is a loss
16 | z[i] = model.addVar(obj=1, ub=0.0, vtype="B", name="z_{}".format(i))
17 |
18 | m = len(xx)
19 | for i in range(m):
20 | _x = xx[i]
21 | _y = yy[i]
22 | # y_i = 1 means i is facility, 0 means it is not
23 | if _y not in y:
24 | if _y in subset:
25 | y[_y] = model.addVar(obj=0, ub=1.0, lb=1.0, vtype="B", name="y_{}".format(_y))
26 | else:
27 | y[_y] = model.addVar(obj=0,vtype="B", name="y_{}".format(_y))
28 | #if not _x == _y:
29 | x[_x,_y] = model.addVar(obj=0, vtype="B", name="x_{},{}".format(_x,_y))
30 | model.update()
31 |
32 | coef = [1 for j in range(n)]
33 | var = [y[j] for j in range(n)]
34 | model.addConstr(LinExpr(coef,var), "=", rhs=budget+len(subset), name="k_center")
35 |
36 | for i in range(m):
37 | _x = xx[i]
38 | _y = yy[i]
39 | #if not _x == _y:
40 | model.addConstr(x[_x,_y], "<", y[_y], name="Strong_{},{}".format(_x,_y))
41 |
42 | yyy = {}
43 | for v in range(m):
44 | _x = xx[v]
45 | _y = yy[v]
46 | if _x not in yyy:
47 | yyy[_x]=[]
48 | if _y not in yyy[_x]:
49 | yyy[_x].append(_y)
50 |
51 | for _x in yyy:
52 | coef = []
53 | var = []
54 | for _y in yyy[_x]:
55 | #if not _x==_y:
56 | coef.append(1)
57 | var.append(x[_x,_y])
58 | coef.append(1)
59 | var.append(z[_x])
60 | model.addConstr(LinExpr(coef,var), "=", 1, name="Assign{}".format(_x))
61 | model.__data = x,y,z
62 | return model
63 |
64 |
 65 | data = pickle.load(open('feature_vectors_pickled', 'rb'))
66 | budget = 10000
67 |
68 | start = time.clock()
69 | num_images = data.shape[0]
70 | dist_mat = numpy.matmul(data,data.transpose())
71 |
72 | sq = numpy.array(dist_mat.diagonal()).reshape(num_images,1)
73 | dist_mat *= -2
74 | dist_mat+=sq
75 | dist_mat+=sq.transpose()
76 |
77 | elapsed = time.clock() - start
78 | print "Time spent in (distance computation) is: ", elapsed
79 |
80 | num_images = 50000
81 |
82 | # We need to get k centers start with greedy solution
83 | budget = 10000
84 | subset = [i for i in range(1)]
85 |
86 | ub= UB
87 | lb = ub/2.0
88 | max_dist=ub
89 |
90 | _x,_y = numpy.where(dist_mat<=max_dist)
91 | _d = dist_mat[_x,_y]
92 | subset = [i for i in range(1)]
93 | model = solve_fac_loc(_x,_y,subset,num_images,budget)
94 | #model.setParam( 'OutputFlag', False )
95 | x,y,z = model.__data
96 | delta=1e-7
97 | while ub-lb>delta:
98 | print "State",ub,lb
99 | cur_r = (ub+lb)/2.0
100 | viol = numpy.where(_d>cur_r)
101 | new_max_d = numpy.min(_d[_d>=cur_r])
102 | new_min_d = numpy.max(_d[_d<=cur_r])
103 | print "If it succeeds, new max is:",new_max_d,new_min_d
104 | for v in viol[0]:
105 | x[_x[v],_y[v]].UB = 0
106 |
107 | model.update()
108 | r = model.optimize()
109 | if model.getAttr(GRB.Attr.Status) == GRB.INFEASIBLE:
110 | failed=True
111 | print "Infeasible"
112 | elif sum([z[i].X for i in range(len(z))]) > 0:
113 | failed=True
114 | print "Failed"
115 | else:
116 | failed=False
117 | if failed:
118 | lb = max(cur_r,new_max_d)
119 | #failed so put edges back
120 | for v in viol[0]:
121 | x[_x[v],_y[v]].UB = 1
122 | else:
123 | print "sol founded",cur_r,lb,ub
124 | ub = min(cur_r,new_min_d)
125 | model.write("s_{}_solution_{}.sol".format(budget,cur_r))
126 |
127 |
--------------------------------------------------------------------------------
/src/baselines/coreset/gurobi_solution_parser.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import matplotlib.pyplot as plt
3 | import numpy as np
4 |
5 | import numpy
6 | import pickle
7 | import numpy.matlib
8 | import time
9 | import pickle
10 | import bisect
11 |
12 | gurobi_solution_file = 'solution_2.86083525164.sol'
13 | results = open(gurobi_solution_file).read().split('\n')
14 | results_nodes = filter(lambda x: 'y' in x,filter(lambda x:'#' not in x, results))
15 | string_to_id = lambda x:(int(x.split(' ')[0].split('_')[1]),int(x.split(' ')[1]))
 16 | result_node_ids = list(map(string_to_id, results_nodes))
17 |
18 | results_as_dict = {v[0]:v[1] for v in result_node_ids}
19 |
20 | centers = []
21 | for node_result in result_node_ids:
22 | if node_result[1] > 0:
23 | centers.append(node_result[0])
24 |
25 | pickle.dump(centers,open('centers.bn','wb'))
26 |
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/bandit_discrete.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Bandit wrapper around base AL sampling methods.
16 |
17 | Assumes adversarial multi-armed bandit setting where arms correspond to
18 | mixtures of different AL methods.
19 |
20 | Uses EXP3 algorithm to decide which AL method to use to create the next batch.
21 | Similar to Hsu & Lin 2015, Active Learning by Learning.
22 | https://www.csie.ntu.edu.tw/~htlin/paper/doc/aaai15albl.pdf
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import numpy as np
30 |
31 | from src.baselines.sampling_methods.wrapper_sampler_def import AL_MAPPING, WrapperSamplingMethod
32 |
33 |
34 | class BanditDiscreteSampler(WrapperSamplingMethod):
35 | """Wraps EXP3 around mixtures of indicated methods.
36 |
 37 | Uses EXP3 multi-armed bandit algorithm to select sampler methods.
38 | """
39 | def __init__(self,budget,seed=123,
40 | reward_function = lambda AL_acc: AL_acc[-1],
41 | gamma=0.5,
42 | samplers=[{'methods':('margin','uniform'),'weights':(0,1)},
43 | {'methods':('margin','uniform'),'weights':(1,0)}]):
44 | """Initializes sampler with indicated gamma and arms.
45 |
46 | Args:
47 | X: training data
48 | y: labels, may need to be input into base samplers
49 | seed: seed to use for random sampling
50 | reward_function: reward based on previously observed accuracies. Assumes
51 | that the input is a sequence of observed accuracies. Will ultimately be
52 | a class method and may need access to other class properties.
53 | gamma: weight on uniform mixture. Arm probability updates are a weighted
54 | mixture of uniform and an exponentially weighted distribution.
55 | Lower gamma more aggressively updates based on observed rewards.
56 | samplers: list of dicts with two fields
57 | 'samplers': list of named samplers
58 | 'weights': percentage of batch to allocate to each sampler
59 | """
60 |
61 | self.name = 'bandit_discrete'
62 | np.random.seed(seed)
63 | self.seed = seed
64 | # self.initialize_samplers(samplers)
65 | self.samplers = samplers
66 | self.gamma = gamma
67 | self.n_arms = len(samplers)
68 | self.reward_function = reward_function
69 |
70 | self.pull_history = []
71 | self.acc_history = []
72 | self.w = np.ones(self.n_arms)
73 | self.x = np.zeros(self.n_arms)
74 | self.p = self.w / (1.0 * self.n_arms)
75 | self.probs = []
76 | self.selectionhistory = []
77 | self.num_arm = float(len(self.samplers))
78 |
 79 | self.pmin = np.sqrt(np.log(self.num_arm)/(self.num_arm*budget))
80 |
81 | def update_vars_arnmab(self, arm_pulled):
82 | reward = self.reward_function(self.acc_history)
 83 | Qkstar = self.samplers[arm_pulled].valuelist[self.selectionhistory[-1]]
 84 | phistar = sum([self.p[i] * self.samplers[i].valuelist[self.selectionhistory[-1]] for i in range(len(self.samplers))])
85 | rhat = Qkstar / phistar
86 | self.w = self.w*np.exp(self.pmin/2.0*(rhat))
87 |
88 |
89 |
90 | def update_vars(self, arm_pulled):
91 | reward = self.reward_function(self.acc_history)
92 |
93 | self.x = np.zeros(self.n_arms)
94 | self.x[arm_pulled] = reward / self.p[arm_pulled]
95 | self.w = self.w * np.exp(self.gamma * self.x / self.n_arms)
96 | self.p = ((1.0 - self.gamma) * self.w / sum(self.w)
97 | + self.gamma / self.n_arms)
98 | # print(self.p)
99 | self.probs.append(self.p)
100 |
101 | def select_batch_arnmab(self, N, eval_acc, wraped_feat):
102 | self.acc_history.append(eval_acc)
103 |
104 |
105 | if len(self.pull_history) > 0:
106 | self.update_vars(self.pull_history[-1])
107 |
108 | def select_batch(self, N, eval_acc,wraped_feat):
109 | """Returns batch of datapoints sampled using mixture of AL_methods.
110 |
111 | Assumes that data has already been shuffled.
112 |
113 | Args:
114 | already_selected: index of datapoints already selected
115 | N: batch size
116 | eval_acc: accuracy of model trained after incorporating datapoints from
117 | last recommended batch
118 |
119 | Returns:
120 | indices of points selected to label
121 | """
122 |
123 | # print("eval_acc {}".format(eval_acc))
124 | # exit()
125 | # Update observed reward and arm probabilities
126 | self.acc_history.append(eval_acc)
127 | if len(self.pull_history) > 0:
128 | self.update_vars(self.pull_history[-1])
129 | # Sample an arm
130 |
131 |
132 |
133 | arm = np.random.choice(range(self.n_arms), p=self.p)
134 | self.pull_history.append(arm)
135 | # kwargs['N'] = N
136 | # kwargs['already_selected'] = already_selected
137 |
138 | # print("use arm {}".format(arm))
139 | sample = self.samplers[arm].select_batch(wraped_feat)
140 | return sample
141 |
142 | def to_dict(self):
143 | output = {}
144 | output['samplers'] = self.base_samplers
145 | output['arm_probs'] = self.probs
146 | output['pull_history'] = self.pull_history
147 | output['rewards'] = self.acc_history
148 | return output
149 |
150 |
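
A small numeric walk-through of the EXP3 update in update_vars above, with made-up reward and gamma values:

import numpy as np

gamma, n_arms = 0.5, 2
w = np.ones(n_arms)
p = w / n_arms                       # start from the uniform distribution
arm, reward = 1, 0.8                 # pretend arm 1 was pulled and earned reward 0.8

x = np.zeros(n_arms)
x[arm] = reward / p[arm]             # importance-weighted reward estimate
w = w * np.exp(gamma * x / n_arms)
p = (1.0 - gamma) * w / sum(w) + gamma / n_arms
print(p)                             # mass shifts toward the rewarded arm, with a floor of gamma/n_arms
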
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Controls imports to fill up dictionary of different sampling methods.
16 | """
17 |
18 | from functools import partial
19 | AL_MAPPING = {}
20 |
21 |
22 | def get_base_AL_mapping():
23 | from src.baselines.sampling_methods.margin_AL import MarginAL
24 | from src.baselines.sampling_methods.informative_diverse import InformativeClusterDiverseSampler
25 | from src.baselines.sampling_methods.hierarchical_clustering_AL import HierarchicalClusterAL
26 | from src.baselines.sampling_methods.uniform_sampling import UniformSampling
27 | from src.baselines.sampling_methods.represent_cluster_centers import RepresentativeClusterMeanSampling
28 | from src.baselines.sampling_methods.graph_density import GraphDensitySampler
29 | from src.baselines.sampling_methods.kcenter_greedy import kCenterGreedy
30 | AL_MAPPING['margin'] = MarginAL
31 | AL_MAPPING['informative_diverse'] = InformativeClusterDiverseSampler
32 | AL_MAPPING['hierarchical'] = HierarchicalClusterAL
33 | AL_MAPPING['uniform'] = UniformSampling
34 | AL_MAPPING['margin_cluster_mean'] = RepresentativeClusterMeanSampling
35 | AL_MAPPING['graph_density'] = GraphDensitySampler
36 | AL_MAPPING['kcenter'] = kCenterGreedy
37 |
38 |
39 | def get_all_possible_arms():
40 | from src.baselines.sampling_methods.mixture_of_samplers import MixtureOfSamplers
41 | AL_MAPPING['mixture_of_samplers'] = MixtureOfSamplers
42 |
43 |
44 | def get_wrapper_AL_mapping():
45 | from src.baselines.sampling_methods.bandit_discrete import BanditDiscreteSampler
46 | from src.baselines.sampling_methods.simulate_batch import SimulateBatchSampler
47 | AL_MAPPING['bandit_mixture'] = partial(
48 | BanditDiscreteSampler,
49 | samplers=[{
50 | 'methods': ['margin', 'uniform'],
51 | 'weights': [0, 1]
52 | }, {
53 | 'methods': ['margin', 'uniform'],
54 | 'weights': [0.25, 0.75]
55 | }, {
56 | 'methods': ['margin', 'uniform'],
57 | 'weights': [0.5, 0.5]
58 | }, {
59 | 'methods': ['margin', 'uniform'],
60 | 'weights': [0.75, 0.25]
61 | }, {
62 | 'methods': ['margin', 'uniform'],
63 | 'weights': [1, 0]
64 | }])
65 | AL_MAPPING['bandit_discrete'] = partial(
66 | BanditDiscreteSampler,
67 | samplers=[{
68 | 'methods': ['margin', 'uniform'],
69 | 'weights': [0, 1]
70 | }, {
71 | 'methods': ['margin', 'uniform'],
72 | 'weights': [1, 0]
73 | }])
74 | AL_MAPPING['simulate_batch_mixture'] = partial(
75 | SimulateBatchSampler,
76 | samplers=({
77 | 'methods': ['margin', 'uniform'],
78 | 'weights': [1, 0]
79 | }, {
80 | 'methods': ['margin', 'uniform'],
81 | 'weights': [0.5, 0.5]
82 | }, {
83 | 'methods': ['margin', 'uniform'],
84 | 'weights': [0, 1]
85 | }),
86 | n_sims=5,
87 | train_per_sim=10,
88 | return_best_sim=False)
89 | AL_MAPPING['simulate_batch_best_sim'] = partial(
90 | SimulateBatchSampler,
91 | samplers=[{
92 | 'methods': ['margin', 'uniform'],
93 | 'weights': [1, 0]
94 | }],
95 | n_sims=10,
96 | train_per_sim=10,
97 | return_type='best_sim')
98 | AL_MAPPING['simulate_batch_frequency'] = partial(
99 | SimulateBatchSampler,
100 | samplers=[{
101 | 'methods': ['margin', 'uniform'],
102 | 'weights': [1, 0]
103 | }],
104 | n_sims=10,
105 | train_per_sim=10,
106 | return_type='frequency')
107 |
108 | def get_mixture_of_samplers(name):
109 | assert 'mixture_of_samplers' in name
110 | if 'mixture_of_samplers' not in AL_MAPPING:
111 | raise KeyError('Mixture of Samplers not yet loaded.')
112 | args = name.split('-')[1:]
113 | samplers = args[0::2]
114 | weights = args[1::2]
115 | weights = [float(w) for w in weights]
116 | assert sum(weights) == 1
117 | mixture = {'methods': samplers, 'weights': weights}
118 | print(mixture)
119 | return partial(AL_MAPPING['mixture_of_samplers'], mixture=mixture)
120 |
121 |
122 | def get_AL_sampler(name):
123 | if name in AL_MAPPING and name != 'mixture_of_samplers':
124 | return AL_MAPPING[name]
125 | if 'mixture_of_samplers' in name:
126 | return get_mixture_of_samplers(name)
127 | raise NotImplementedError('The specified sampler is not available.')
128 |
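
The samplers registered above are addressed by name; a hedged usage sketch (X, y, seed are placeholders supplied by the caller):

from src.baselines.sampling_methods.constants import (
    get_base_AL_mapping, get_all_possible_arms, get_AL_sampler)

get_base_AL_mapping()
get_all_possible_arms()

uniform_cls = get_AL_sampler('uniform')
mixture_cls = get_AL_sampler('mixture_of_samplers-margin-0.5-uniform-0.5')
# sampler = mixture_cls(X, y, seed)   # the weights after each method name must sum to 1
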
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/graph_density.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Diversity promoting sampling method that uses graph density to determine
16 | most representative points.
17 |
18 | This is an implementation of the method described in
19 | https://www.mpi-inf.mpg.de/fileadmin/inf/d2/Research_projects_files/EbertCVPR2012.pdf
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import copy
27 |
28 | from sklearn.neighbors import kneighbors_graph
29 | from sklearn.metrics import pairwise_distances
30 | import numpy as np
31 | from src.baselines.sampling_methods.sampling_def import SamplingMethod
32 |
33 |
34 | class GraphDensitySampler(SamplingMethod):
35 | """Diversity promoting sampling method that uses graph density to determine
36 | most representative points.
37 | """
38 |
39 | def __init__(self, X, y, seed):
40 | self.name = 'graph_density'
41 | self.X = X
42 | self.flat_X = self.flatten_X()
43 | # Set gamma for gaussian kernel to be equal to 1/n_features
44 | self.gamma = 1. / self.X.shape[1]
45 | self.compute_graph_density()
46 |
47 | def compute_graph_density(self, n_neighbor=10):
48 | # kneighbors graph is constructed using k=10
49 | connect = kneighbors_graph(self.flat_X, n_neighbor,p=1)
50 | # Make connectivity matrix symmetric, if a point is a k nearest neighbor of
51 | # another point, make it vice versa
52 | neighbors = connect.nonzero()
53 | inds = zip(neighbors[0],neighbors[1])
54 | # Graph edges are weighted by applying gaussian kernel to manhattan dist.
55 | # By default, gamma for rbf kernel is equal to 1/n_features but may
56 | # get better results if gamma is tuned.
57 | for entry in inds:
58 | i = entry[0]
59 | j = entry[1]
60 | distance = pairwise_distances(self.flat_X[[i]],self.flat_X[[j]],metric='manhattan')
61 | distance = distance[0,0]
62 | weight = np.exp(-distance * self.gamma)
63 | connect[i,j] = weight
64 | connect[j,i] = weight
65 | self.connect = connect
66 | # Define graph density for an observation to be sum of weights for all
67 | # edges to the node representing the datapoint. Normalize sum weights
68 | # by total number of neighbors.
69 | self.graph_density = np.zeros(self.X.shape[0])
70 | for i in np.arange(self.X.shape[0]):
71 | self.graph_density[i] = connect[i,:].sum() / (connect[i,:]>0).sum()
72 | self.starting_density = copy.deepcopy(self.graph_density)
73 |
74 | def select_batch_(self, N, already_selected, **kwargs):
75 | # If a neighbor has already been sampled, reduce the graph density
76 | # for its direct neighbors to promote diversity.
77 | batch = set()
78 | self.graph_density[already_selected] = min(self.graph_density) - 1
79 | while len(batch) < N:
80 | selected = np.argmax(self.graph_density)
81 | neighbors = (self.connect[selected,:] > 0).nonzero()[1]
82 | self.graph_density[neighbors] = self.graph_density[neighbors] - self.graph_density[selected]
83 | batch.add(selected)
84 | self.graph_density[already_selected] = min(self.graph_density) - 1
85 | self.graph_density[list(batch)] = min(self.graph_density) - 1
86 | return list(batch)
87 |
88 | def to_dict(self):
89 | output = {}
90 | output['connectivity'] = self.connect
91 | output['graph_density'] = self.starting_density
92 | return output
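
A hedged usage sketch for GraphDensitySampler (random toy features; y is accepted but unused by this sampler):

import numpy as np
from src.baselines.sampling_methods.graph_density import GraphDensitySampler

X = np.random.rand(50, 8)
y = np.zeros(50)
sampler = GraphDensitySampler(X, y, seed=0)
batch = sampler.select_batch_(N=5, already_selected=[])
print(batch)   # 5 indices of high graph-density points, spread across the graph
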
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/informative_diverse.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Informative and diverse batch sampler that samples points with small margin
16 | while maintaining same distribution over clusters as entire training data.
17 |
18 | Batch is created by sorting datapoints by increasing margin and then growing
19 | the batch greedily. A point is added to the batch if the result batch still
20 | respects the constraint that the cluster distribution of the batch will
21 | match the cluster distribution of the entire training set.
22 | """
23 |
24 | from __future__ import absolute_import
25 | from __future__ import division
26 | from __future__ import print_function
27 |
28 | from sklearn.cluster import MiniBatchKMeans
29 | import numpy as np
30 | from src.baselines.sampling_methods.sampling_def import SamplingMethod
31 |
32 |
33 | class InformativeClusterDiverseSampler(SamplingMethod):
34 | """Selects batch based on informative and diverse criteria.
35 |
36 | Returns highest uncertainty lowest margin points while maintaining
37 | same distribution over clusters as entire dataset.
38 | """
39 |
40 | def __init__(self, X, y, seed):
41 | self.name = 'informative_and_diverse'
42 | self.X = X
43 | self.flat_X = self.flatten_X()
44 | # y only used for determining how many clusters there should be
45 | # probably not practical to assume we know # of classes before hand
46 | # should also probably scale with dimensionality of data
47 | self.y = y
48 | self.n_clusters = len(list(set(y)))
49 | self.cluster_model = MiniBatchKMeans(n_clusters=self.n_clusters)
50 | self.cluster_data()
51 |
52 | def cluster_data(self):
53 | # Probably okay to always use MiniBatchKMeans
54 | # Should standardize data before clustering
55 | # Can cluster on standardized data but train on raw features if desired
56 | self.cluster_model.fit(self.flat_X)
57 | unique, counts = np.unique(self.cluster_model.labels_, return_counts=True)
58 | self.cluster_prob = counts/sum(counts)
59 | self.cluster_labels = self.cluster_model.labels_
60 |
61 | def select_batch_(self, model, already_selected, N, **kwargs):
62 | """Returns a batch of size N using informative and diverse selection.
63 |
64 | Args:
65 | model: scikit learn model with decision_function implemented
66 | already_selected: index of datapoints already selected
67 | N: batch size
68 |
69 | Returns:
70 | indices of points selected to add using margin active learner
71 | """
72 | # TODO(lishal): have MarginSampler and this share margin function
73 | try:
74 | distances = model.decision_function(self.X)
75 | except:
76 | distances = model.predict_proba(self.X)
77 | if len(distances.shape) < 2:
78 | min_margin = abs(distances)
79 | else:
80 | sort_distances = np.sort(distances, 1)[:, -2:]
81 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
82 | rank_ind = np.argsort(min_margin)
83 | rank_ind = [i for i in rank_ind if i not in already_selected]
84 | new_batch_cluster_counts = [0 for _ in range(self.n_clusters)]
85 | new_batch = []
86 | for i in rank_ind:
87 | if len(new_batch) == N:
88 | break
89 | label = self.cluster_labels[i]
90 | if new_batch_cluster_counts[label] / N < self.cluster_prob[label]:
91 | new_batch.append(i)
92 | new_batch_cluster_counts[label] += 1
93 | n_slot_remaining = N - len(new_batch)
94 | batch_filler = list(set(rank_ind) - set(already_selected) - set(new_batch))
95 | new_batch.extend(batch_filler[0:n_slot_remaining])
96 | return new_batch
97 |
98 | def to_dict(self):
99 | output = {}
100 | output['cluster_membership'] = self.cluster_labels
101 | return output
102 |
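
A hedged usage sketch for InformativeClusterDiverseSampler, with a scikit-learn classifier standing in for the model (toy data, illustrative only):

import numpy as np
from sklearn.linear_model import LogisticRegression
from src.baselines.sampling_methods.informative_diverse import InformativeClusterDiverseSampler

X = np.random.rand(60, 4)
y = np.arange(60) % 3
model = LogisticRegression(max_iter=1000).fit(X, y)

sampler = InformativeClusterDiverseSampler(X, y, seed=0)
batch = sampler.select_batch_(model=model, already_selected=[], N=6)
print(batch)
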
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/kcenter_greedy.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Returns points that minimizes the maximum distance of any point to a center.
16 |
17 | Implements the k-Center-Greedy method in
18 | Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for
19 | Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017
20 |
21 | Distance metric defaults to l2 distance. Features used to calculate distance
22 | are either raw features or if a model has transform method then uses the output
23 | of model.transform(X).
24 |
25 | Can be extended to a robust k centers algorithm that ignores a certain number of
26 | outlier datapoints. Resulting centers are solution to multiple integer program.
27 | """
28 |
29 | from __future__ import absolute_import
30 | from __future__ import division
31 | from __future__ import print_function
32 |
33 | import numpy as np
34 | from sklearn.metrics import pairwise_distances
35 | from src.baselines.sampling_methods.sampling_def import SamplingMethod
36 |
37 |
38 | class kCenterGreedy(object):
39 |
40 | def __init__(self, metric='euclidean'):
41 | # self.X = X
42 | # self.y = y
43 | # self.flat_X = self.flatten_X()
44 | self.name = 'kcenter'
45 | # self.features = self.flat_X
46 | self.metric = metric
47 | self.min_distances = None
48 | self.n_obs = 0#self.X.shape[0]
49 | self.already_selected = []
50 |
51 | def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
52 | """Update min distances given cluster centers.
53 |
54 | Args:
55 | cluster_centers: indices of cluster centers
56 | only_new: only calculate distance for newly selected points and update
57 | min_distances.
 58 | reset_dist: whether to reset min_distances.
59 | """
60 |
61 | if reset_dist:
62 | self.min_distances = None
63 | if only_new:
64 | cluster_centers = [d for d in cluster_centers
65 | if d not in self.already_selected]
66 | if len(cluster_centers)>0:
67 | # Update min_distances for all examples given new cluster center.
68 | x = self.features[cluster_centers]
69 | dist = pairwise_distances(self.features, x, metric=self.metric)
70 |
71 | if self.min_distances is None:
72 | self.min_distances = np.min(dist, axis=1).reshape(-1, 1)
73 | else:
74 | self.min_distances = np.minimum(self.min_distances, dist)
75 |
76 |
77 |
78 |
79 | def select_batch(self, features, already_selected, N, **kwargs):
80 | """
81 | Diversity promoting active learning method that greedily forms a batch
82 | to minimize the maximum distance to a cluster center among all unlabeled
83 | datapoints.
84 |
85 | Args:
 86 | features: array of (possibly transformed) features for the pool of candidate points
87 | already_selected: index of datapoints already selected
88 | N: batch size
89 |
90 | Returns:
91 | indices of points selected to minimize distance to cluster centers
92 | """
93 |
94 | try:
95 | # Assumes that the transform function takes in original data and not
96 | # flattened data.
97 | # print('Getting transformed features...')
98 | self.features = features
99 | # print('Calculating distances...')
100 | self.update_distances(already_selected, only_new=False, reset_dist=True)
101 | except:
102 | # print('Using flat_X as features.')
103 | self.update_distances(already_selected, only_new=True, reset_dist=False)
104 |
105 | new_batch = []
106 |
107 | for _ in range(N):
108 | if self.already_selected is None:
109 | # Initialize centers with a randomly selected datapoint
110 | ind = np.random.choice(np.arange(self.n_obs))
111 | else:
112 | ind = np.argmax(self.min_distances)
113 |
114 | # New examples should not be in already selected since those points
115 | # should have min_distance of zero to a cluster center.
116 | assert ind not in already_selected
117 |
118 | self.update_distances([ind], only_new=True, reset_dist=False)
119 | new_batch.append(ind)
120 | # print('Maximum distance from cluster centers is %0.2f'
121 | # % max(self.min_distances))
122 |
123 |
124 | self.already_selected = already_selected
125 | # print("already selected3 {}".format(self.already_selected))
126 |
127 |
128 | return new_batch
129 |
130 |
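
A hedged usage sketch for the kCenterGreedy selector above (toy embeddings; the selection is seeded with one already-labelled point):

import numpy as np
from src.baselines.sampling_methods.kcenter_greedy import kCenterGreedy

feats = np.random.randn(40, 16)        # embeddings for the pool
kc = kCenterGreedy()
batch = kc.select_batch(features=feats, already_selected=[0], N=5)
print(batch)                           # 5 indices far from point 0 and from each other
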
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/margin_AL.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Margin based AL method.
16 |
17 | Samples in batches based on margin scores.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import numpy as np
25 | from src.baselines.sampling_methods.sampling_def import SamplingMethod
26 |
27 |
28 | class MarginAL(SamplingMethod):
29 | def __init__(self, X, y, seed):
30 | self.X = X
31 | self.y = y
32 | self.name = 'margin'
33 |
34 | def select_batch_(self, model, already_selected, N, **kwargs):
35 | """Returns batch of datapoints with smallest margin/highest uncertainty.
36 |
37 | For binary classification, can just take the absolute distance to decision
38 | boundary for each point.
39 | For multiclass classification, must consider the margin between distance for
40 | top two most likely classes.
41 |
42 | Args:
43 | model: scikit learn model with decision_function implemented
44 | already_selected: index of datapoints already selected
45 | N: batch size
46 |
47 | Returns:
48 | indices of points selected to add using margin active learner
49 | """
50 |
51 | try:
52 | distances = model.decision_function(self.X)
53 | except:
54 | distances = model.predict_proba(self.X)
55 | if len(distances.shape) < 2:
56 | min_margin = abs(distances)
57 | else:
58 | sort_distances = np.sort(distances, 1)[:, -2:]
59 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
60 | rank_ind = np.argsort(min_margin)
61 | rank_ind = [i for i in rank_ind if i not in already_selected]
62 | active_samples = rank_ind[0:N]
63 | return active_samples
64 |
65 |
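
A hedged usage sketch for MarginAL, again with a scikit-learn classifier as the model (toy data):

import numpy as np
from sklearn.linear_model import LogisticRegression
from src.baselines.sampling_methods.margin_AL import MarginAL

X = np.random.rand(30, 4)
y = np.arange(30) % 3
model = LogisticRegression(max_iter=1000).fit(X, y)

sampler = MarginAL(X, y, seed=0)
batch = sampler.select_batch_(model=model, already_selected=[0, 1], N=5)
print(batch)   # the 5 unlabelled points with the smallest top-2 class margin
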
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/mixture_of_samplers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Mixture of base sampling strategies
16 |
17 | """
18 |
19 | from __future__ import absolute_import
20 | from __future__ import division
21 | from __future__ import print_function
22 |
23 | import copy
24 |
25 | from src.baselines.sampling_methods.sampling_def import SamplingMethod
26 | from src.baselines.sampling_methods.constants import AL_MAPPING, get_base_AL_mapping
27 |
28 | get_base_AL_mapping()
29 |
30 |
31 | class MixtureOfSamplers(SamplingMethod):
32 | """Samples according to mixture of base sampling methods.
33 |
34 | If duplicate points are selected by the mixed strategies when forming the batch
35 | then the remaining slots are divided according to mixture weights and
36 | another partial batch is requested until the batch is full.
37 | """
38 | def __init__(self,
39 | X,
40 | y,
41 | seed,
42 | mixture={'methods': ('margin', 'uniform'),
 43 | 'weights': (0.5, 0.5)},
44 | samplers=None):
45 | self.X = X
46 | self.y = y
47 | self.name = 'mixture_of_samplers'
48 | self.sampling_methods = mixture['methods']
49 | self.sampling_weights = dict(zip(mixture['methods'], mixture['weights']))
50 | self.seed = seed
51 | # A list of initialized samplers is allowed as an input because
52 | # for AL_methods that search over different mixtures, may want mixtures to
53 | # have shared AL_methods so that initialization is only performed once for
54 | # computation intensive methods like HierarchicalClusteringAL and
55 | # states are shared between mixtures.
56 | # If initialized samplers are not provided, initialize them ourselves.
57 | if samplers is None:
58 | self.samplers = {}
59 | self.initialize(self.sampling_methods)
60 | else:
61 | self.samplers = samplers
62 | self.history = []
63 |
64 | def initialize(self, samplers):
65 | self.samplers = {}
66 | for s in samplers:
67 | self.samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed)
68 |
69 | def select_batch_(self, already_selected, N, **kwargs):
70 | """Returns batch of datapoints selected according to mixture weights.
71 |
72 | Args:
73 | already_included: index of datapoints already selected
74 | N: batch size
75 |
76 | Returns:
77 | indices of points selected to add using margin active learner
78 | """
79 | kwargs['already_selected'] = copy.copy(already_selected)
80 | inds = set()
81 | self.selected_by_sampler = {}
82 | for s in self.sampling_methods:
83 | self.selected_by_sampler[s] = []
84 | effective_N = 0
85 | while len(inds) < N:
86 | effective_N += N - len(inds)
87 | for s in self.sampling_methods:
88 | if len(inds) < N:
89 | batch_size = min(max(int(self.sampling_weights[s] * effective_N), 1), N)
90 | sampler = self.samplers[s]
91 | kwargs['N'] = batch_size
92 | s_inds = sampler.select_batch(**kwargs)
93 | for ind in s_inds:
94 | if ind not in self.selected_by_sampler[s]:
95 | self.selected_by_sampler[s].append(ind)
96 | s_inds = [d for d in s_inds if d not in inds]
97 | s_inds = s_inds[0 : min(len(s_inds), N-len(inds))]
98 | inds.update(s_inds)
99 | self.history.append(copy.deepcopy(self.selected_by_sampler))
100 | return list(inds)
101 |
102 | def to_dict(self):
103 | output = {}
104 | output['history'] = self.history
105 | output['samplers'] = self.sampling_methods
106 | output['mixture_weights'] = self.sampling_weights
107 | for s in self.samplers:
108 | s_output = self.samplers[s].to_dict()
109 | output[s] = s_output
110 | return output
111 |
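
A hedged construction sketch for MixtureOfSamplers (the select_batch call is commented out because the margin component needs a trained model):

import numpy as np
from src.baselines.sampling_methods.mixture_of_samplers import MixtureOfSamplers

X = np.random.rand(40, 4)
y = np.arange(40) % 2
mix = MixtureOfSamplers(X, y, seed=0,
                        mixture={'methods': ('margin', 'uniform'),
                                 'weights': (0.5, 0.5)})
# batch = mix.select_batch(model=trained_model, already_selected=[], N=6)
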
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/represent_cluster_centers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Another informative and diverse sampler that mirrors the algorithm described
16 | in Xu, et. al., Representative Sampling for Text Classification Using
17 | Support Vector Machines, 2003
18 |
19 | Batch is created by clustering points within the margin of the classifier and
20 | choosing points closest to the k centroids.
21 | """
22 |
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 |
27 | from sklearn.cluster import MiniBatchKMeans
28 | import numpy as np
29 | from src.baselines.sampling_methods.sampling_def import SamplingMethod
30 |
31 |
32 | class RepresentativeClusterMeanSampling(SamplingMethod):
33 | """Selects batch based on informative and diverse criteria.
34 |
35 | Returns points within the margin of the classifier that are closest to the
36 | k-means centers of those points.
37 | """
38 |
39 | def __init__(self, X, y, seed):
40 | self.name = 'cluster_mean'
41 | self.X = X
42 | self.flat_X = self.flatten_X()
43 | self.y = y
44 | self.seed = seed
45 |
46 | def select_batch_(self, model, N, already_selected, **kwargs):
47 | # Probably okay to always use MiniBatchKMeans
48 | # Should standardize data before clustering
49 | # Can cluster on standardized data but train on raw features if desired
50 | try:
51 | distances = model.decision_function(self.X)
52 | except:
53 | distances = model.predict_proba(self.X)
54 | if len(distances.shape) < 2:
55 | min_margin = abs(distances)
56 | else:
57 | sort_distances = np.sort(distances, 1)[:, -2:]
58 | min_margin = sort_distances[:, 1] - sort_distances[:, 0]
59 | rank_ind = np.argsort(min_margin)
60 | rank_ind = [i for i in rank_ind if i not in already_selected]
61 |
62 | distances = abs(model.decision_function(self.X))
63 | min_margin_by_class = np.min(abs(distances[already_selected]),axis=0)
64 | unlabeled_in_margin = np.array([i for i in range(len(self.y))
65 | if i not in already_selected and
 66 | any(distances[i] < min_margin_by_class)])
[... remainder of represent_cluster_centers.py truncated in this dump ...]
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/sampling_def.py:
--------------------------------------------------------------------------------
[... earlier lines of sampling_def.py truncated in this dump ...]
 39 | shape = self.X.shape
 40 | flat_X = self.X
 41 | if len(shape) > 2:
42 | flat_X = np.reshape(self.X, (shape[0],np.product(shape[1:])))
43 | return flat_X
44 |
45 |
46 | @abc.abstractmethod
47 | def select_batch_(self):
48 | return
49 |
50 | def select_batch(self, **kwargs):
51 | return self.select_batch_(**kwargs)
52 |
53 | def to_dict(self):
54 | return None
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/uniform_sampling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Uniform sampling method.
16 |
17 | Samples in batches.
18 | """
19 |
20 | from __future__ import absolute_import
21 | from __future__ import division
22 | from __future__ import print_function
23 |
24 | import numpy as np
25 |
26 | from src.baselines.sampling_methods.sampling_def import SamplingMethod
27 |
28 |
29 | class UniformSampling(SamplingMethod):
30 |
31 | def __init__(self, X, y, seed):
32 | self.X = X
33 | self.y = y
34 | self.name = 'uniform'
35 | np.random.seed(seed)
36 |
37 | def select_batch_(self, already_selected, N, **kwargs):
38 | """Returns batch of randomly sampled datapoints.
39 |
40 | Assumes that data has already been shuffled.
41 |
42 | Args:
43 | already_selected: index of datapoints already selected
44 | N: batch size
45 |
46 | Returns:
47 | indices of points selected to label
48 | """
49 |
50 | # This is uniform given the remaining pool but biased wrt the entire pool.
51 | sample = [i for i in range(self.X.shape[0]) if i not in already_selected]
52 | return sample[0:N]
53 |
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/utils/tree.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Node and Tree class to support hierarchical clustering AL method.
16 |
17 | Assumed to be binary tree.
18 |
19 | Node class is used to represent each node in a hierarchical clustering.
20 | Each node has certain properties that are used in the AL method.
21 |
22 | Tree class is used to traverse a hierarchical clustering.
23 | """
24 |
25 | from __future__ import absolute_import
26 | from __future__ import division
27 | from __future__ import print_function
28 |
29 | import copy
30 |
31 |
32 | class Node(object):
33 | """Node class for hierarchical clustering.
34 |
35 | Initialized with name and left right children.
36 | """
37 |
38 | def __init__(self, name, left=None, right=None):
39 | self.name = name
40 | self.left = left
41 | self.right = right
42 | self.is_leaf = left is None and right is None
43 | self.parent = None
44 | # Fields for hierarchical clustering AL
45 | self.score = 1.0
46 | self.split = False
47 | self.best_label = None
48 | self.weight = None
49 |
50 | def set_parent(self, parent):
51 | self.parent = parent
52 |
53 |
54 | class Tree(object):
55 | """Tree object for traversing a binary tree.
56 |
57 | Most methods apply to trees in general with the exception of get_pruning
58 | which is specific to the hierarchical clustering AL method.
59 | """
60 |
61 | def __init__(self, root, node_dict):
62 | """Initializes tree and creates all nodes in node_dict.
63 |
64 | Args:
65 | root: id of the root node
66 | node_dict: dictionary with node_id as keys and entries indicating
67 | left and right child of node respectively.
68 | """
69 | self.node_dict = node_dict
70 | self.root = self.make_tree(root)
71 | self.nodes = {}
72 | self.leaves_mapping = {}
73 | self.fill_parents()
74 | self.n_leaves = None
75 |
76 | def print_tree(self, node, max_depth):
77 | """Helper function to print out tree for debugging."""
78 | node_list = [node]
79 | output = ""
80 | level = 0
81 | while level < max_depth and len(node_list):
82 | children = set()
83 | for n in node_list:
84 | node = self.get_node(n)
85 | output += ("\t"*level+"node %d: score %.2f, weight %.2f" %
86 | (node.name, node.score, node.weight)+"\n")
87 | if node.left:
88 | children.add(node.left.name)
89 | if node.right:
90 | children.add(node.right.name)
91 | level += 1
92 | node_list = children
93 | return print(output)
94 |
95 | def make_tree(self, node_id):
96 | if node_id is not None:
97 | return Node(node_id,
98 | self.make_tree(self.node_dict[node_id][0]),
99 | self.make_tree(self.node_dict[node_id][1]))
100 |
101 | def fill_parents(self):
102 | # Setting parent and storing nodes in dict for fast access
103 | def rec(pointer, parent):
104 | if pointer is not None:
105 | self.nodes[pointer.name] = pointer
106 | pointer.set_parent(parent)
107 | rec(pointer.left, pointer)
108 | rec(pointer.right, pointer)
109 | rec(self.root, None)
110 |
111 | def get_node(self, node_id):
112 | return self.nodes[node_id]
113 |
114 | def get_ancestor(self, node):
115 | ancestors = []
116 | if isinstance(node, int):
117 | node = self.get_node(node)
118 | while node.name != self.root.name:
119 | node = node.parent
120 | ancestors.append(node.name)
121 | return ancestors
122 |
123 | def fill_weights(self):
124 | for v in self.node_dict:
125 | node = self.get_node(v)
126 | node.weight = len(self.leaves_mapping[v]) / (1.0 * self.n_leaves)
127 |
128 | def create_child_leaves_mapping(self, leaves):
129 | """DP for creating child leaves mapping.
130 |
131 | Storing in dict to save recompute.
132 | """
133 | self.n_leaves = len(leaves)
134 | for v in leaves:
135 | self.leaves_mapping[v] = [v]
136 | node_list = set([self.get_node(v).parent for v in leaves])
137 | while node_list:
138 | to_fill = copy.copy(node_list)
139 | for v in node_list:
140 | if (v.left.name in self.leaves_mapping
141 | and v.right.name in self.leaves_mapping):
142 | to_fill.remove(v)
143 | self.leaves_mapping[v.name] = (self.leaves_mapping[v.left.name] +
144 | self.leaves_mapping[v.right.name])
145 | if v.parent is not None:
146 | to_fill.add(v.parent)
147 | node_list = to_fill
148 | self.fill_weights()
149 |
150 | def get_child_leaves(self, node):
151 | return self.leaves_mapping[node]
152 |
153 | def get_pruning(self, node):
154 | if node.split:
155 | return self.get_pruning(node.left) + self.get_pruning(node.right)
156 | else:
157 | return [node.name]
158 |
159 |
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/utils/tree_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sampling_methods.utils.tree."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import unittest
22 | from sampling_methods.utils import tree
23 |
24 |
25 | class TreeTest(unittest.TestCase):
26 |
27 | def setUp(self):
28 | node_dict = {
29 | 1: (2, 3),
30 | 2: (4, 5),
31 | 3: (6, 7),
32 | 4: [None, None],
33 | 5: [None, None],
34 | 6: [None, None],
35 | 7: [None, None]
36 | }
37 | self.tree = tree.Tree(1, node_dict)
38 | self.tree.create_child_leaves_mapping([4, 5, 6, 7])
39 | node = self.tree.get_node(1)
40 | node.split = True
41 | node = self.tree.get_node(2)
42 | node.split = True
43 |
44 | def assertNode(self, node, name, left, right):
45 | self.assertEqual(node.name, name)
46 | self.assertEqual(node.left.name, left)
47 | self.assertEqual(node.right.name, right)
48 |
49 | def testTreeRootSetCorrectly(self):
50 | self.assertNode(self.tree.root, 1, 2, 3)
51 |
52 | def testGetNode(self):
53 | node = self.tree.get_node(1)
54 | assert isinstance(node, tree.Node)
55 | self.assertEqual(node.name, 1)
56 |
57 | def testFillParent(self):
58 | node = self.tree.get_node(3)
59 | self.assertEqual(node.parent.name, 1)
60 |
61 | def testGetAncestors(self):
62 | ancestors = self.tree.get_ancestor(5)
63 | self.assertTrue(all([a in ancestors for a in [1, 2]]))
64 |
65 | def testChildLeaves(self):
66 | leaves = self.tree.get_child_leaves(3)
67 | self.assertTrue(all([c in leaves for c in [6, 7]]))
68 |
69 | def testFillWeights(self):
70 | node = self.tree.get_node(3)
71 | self.assertEqual(node.weight, 0.5)
72 |
73 | def testGetPruning(self):
74 | node = self.tree.get_node(1)
75 | pruning = self.tree.get_pruning(node)
76 | self.assertTrue(all([n in pruning for n in [3, 4, 5]]))
77 |
78 | if __name__ == '__main__':
79 | unittest.main()
80 |
--------------------------------------------------------------------------------
/src/baselines/sampling_methods/wrapper_sampler_def.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 Google Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Abstract class for wrapper sampling methods that call base sampling methods.
16 |
17 | Provides interface to sampling methods that allow same signature
18 | for select_batch. Each subclass implements select_batch_ with the desired
19 | signature for readability.
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import abc
27 |
28 | from src.baselines.sampling_methods.constants import AL_MAPPING
29 | from src.baselines.sampling_methods.constants import get_all_possible_arms
30 | from src.baselines.sampling_methods.sampling_def import SamplingMethod
31 |
32 | get_all_possible_arms()
33 |
34 |
35 | class WrapperSamplingMethod(SamplingMethod):
36 | __metaclass__ = abc.ABCMeta
37 |
38 | def initialize_samplers(self, mixtures):
39 | methods = []
40 | for m in mixtures:
41 | methods += m['methods']
42 | methods = set(methods)
43 | self.base_samplers = {}
44 | for s in methods:
45 | self.base_samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed)
46 | self.samplers = []
47 | for m in mixtures:
48 | self.samplers.append(
49 | AL_MAPPING['mixture_of_samplers'](self.X, self.y, self.seed, m,
50 | self.base_samplers))
51 |
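A sketch of the mixtures argument that initialize_samplers consumes (assumption: 'margin' and 'uniform' are names registered in AL_MAPPING by the bundled constants.py; any extra keys are simply forwarded to the mixture sampler):

# Hypothetical mixture specs. 'methods' is the only key read directly here;
# each whole dict m is passed on to AL_MAPPING['mixture_of_samplers'].
mixtures = [
    {'methods': ['margin', 'uniform']},
    {'methods': ['margin']},
]
# initialize_samplers(mixtures) builds one base sampler per unique method name
# and one mixture-of-samplers wrapper per dict in the list.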
--------------------------------------------------------------------------------
/src/datasetcollecting/biggraph.py:
--------------------------------------------------------------------------------
1 | from ogb.nodeproppred.dataset_pyg import PygNodePropPredDataset
2 | import argparse
3 | import torch
4 | from src.utils.dataloader import GraphLoader
5 | import pickle as pkl
6 | def selectclass(labels):
7 | thres=4
8 | numof1or0 = labels.sum(dim=0)/labels.size(0)
9 | idxinclude = torch.where((numof1or0>0.75+(numof1or0<0.25))==0)[0]
10 | # print(len(idxinclude))
11 | labels = labels[:,idxinclude]
12 | equalsset = set()
13 | useset =set()
14 | for i in range(len(idxinclude)):
15 | if i in equalsset:
16 | continue
17 | base = labels[:,i:i+1]
18 | rate = (labels*base).sum(dim=0)/labels.sum(dim=0) # fraction of each class's 1s that are also 1 in base
19 | rate2 = ((1-labels)*(1-base)).sum(dim=0)/(1-labels).sum(dim=0) # fraction of each class's 0s that are also 0 in base
20 | #if rate is large for many classes, base doesn't have many 1s (it sits at the hierarchy's bottom)
21 | largeloc = torch.where(rate>0.9)[0].numpy().tolist()
22 | largeloc2 = torch.where(rate2>0.9)[0].numpy().tolist() # if both rates are high, the two classes are equal
23 |
24 |
25 |
26 |
27 |
28 | # print("i:{} {}; {}".format(i,largeloc,largeloc2))
29 | if len(largeloc)
--------------------------------------------------------------------------------
/src/utils/classificationnet.py:
--------------------------------------------------------------------------------
34 | support = torch.einsum("jik,kp->jip",input,self.weight)
35 | support = torch.reshape(support,[support.size(0),-1])
36 | # support = torch.mm(input, self.weight)
37 | if self.bias is not None:
38 | support = support + self.bias
39 | output = torch.spmm(adj, support)
40 | output = torch.reshape(output,[output.size(0),self.batchsize,-1])
41 | return output
42 |
43 | def __repr__(self):
44 | return self.__class__.__name__ + ' (' \
45 | + str(self.in_features) + ' -> ' \
46 | + str(self.out_features) + ')'
47 |
48 |
49 | class GCN(nn.Module):
50 |
51 | def __init__(self, nfeat, nhid, nclass, batchsize, dropout=0.5, softmax=True, bias=False):
52 | super(GCN, self).__init__()
53 | self.dropout_rate = dropout
54 | self.softmax = softmax
55 | self.batchsize = batchsize
56 | self.gc1 = GraphConvolution(nfeat, nhid, batchsize, bias=bias)
57 | self.gc2 = GraphConvolution(nhid, nclass, batchsize, bias=bias)
58 | self.dropout = nn.Dropout(p=dropout,inplace=False)
59 |
60 | def forward(self, x, adj):
61 | x = x.expand([self.batchsize]+list(x.size())).transpose(0,1)
62 | x = self.dropout(x)
63 | x = F.relu(self.gc1(x, adj))
64 | x = self.dropout(x)
65 | x = self.gc2(x, adj)
66 | x = x.transpose(0,1).transpose(1,2)
67 | if self.softmax:
68 | return F.log_softmax(x, dim=1)
69 | else:
70 | return x.squeeze()
71 |
72 | def reset(self):
73 | self.gc1.reset_parameters()
74 | self.gc2.reset_parameters()
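A standalone shape-tracing sketch (assumptions: weight has shape [in_features, out_features] and adj is an N x N sparse tensor) of the batched propagation performed by GraphConvolution.forward and GCN.forward above:

import torch

N, B, F_in, F_out = 5, 2, 4, 3                      # nodes, batchsize, in/out feature dims
x = torch.randn(N, F_in)
weight = torch.randn(F_in, F_out)
adj = torch.eye(N).to_sparse()                      # stand-in for the normalized adjacency

x = x.expand([B, N, F_in]).transpose(0, 1)          # [N, B, F_in], as in GCN.forward
support = torch.einsum("jik,kp->jip", x, weight)    # [N, B, F_out]
support = support.reshape(N, -1)                    # [N, B*F_out] so spmm sees a 2-D matrix
out = torch.spmm(adj, support).reshape(N, B, -1)    # back to [N, B, F_out]
print(out.shape)                                    # torch.Size([5, 2, 3])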
--------------------------------------------------------------------------------
/src/utils/common.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import sys
3 |
4 |
5 | def ConfigRootLogger(name='', version=None, level="info"):
6 | logger = logging.getLogger(name)
7 | if level == "debug":
8 | logger.setLevel(logging.DEBUG)
9 | elif level == "info":
10 | logger.setLevel(logging.INFO)
11 | elif level == "warning":
12 | logger.setLevel(logging.WARNING)
13 | format = logging.Formatter("T[%(asctime)s]-V[{}]-POS[%(module)s."
14 | "%(funcName)s(line %(lineno)s)]-PID[%(process)d] %(levelname)s"
15 | ">>> %(message)s ".format(version),"%H:%M:%S")
16 | stdout_format = logging.Formatter("T[%(asctime)s]-V[{}]-POS[%(module)s."
17 | "%(funcName)s(line %(lineno)s)]-PID[%(process)d] %(levelname)s"
18 | ">>> %(message)s ".format(version),"%H:%M:%S")
19 |
20 | file_handler = logging.FileHandler("log.txt")
21 | file_handler.setFormatter(format)
22 | stream_handler = logging.StreamHandler(sys.stdout)
23 | stream_handler.setFormatter(stdout_format)
24 | logger.addHandler(file_handler)
25 | logger.addHandler(stream_handler)
26 | return logger
27 |
28 |
29 | ConfigRootLogger("main","1",level="info")
30 | logger = logging.getLogger("main")
31 |
32 |
33 |
34 | def logargs(args,tablename="",width=120 ):
35 | length = 1
36 | L=[]
37 | l= "|"
38 | for id,arg in enumerate(vars(args)):
39 | name,value = arg, str(getattr(args, arg))
40 | nv = name+":"+value
41 | if length +(len(nv)+2)>width:
42 | L.append(l)
43 | l = "|"
44 | length = 1
45 | l += nv + " |"
46 | length += (len(nv)+2)
47 | if id+1 == len(vars(args)):
48 | L.append(l)
49 | printstr = niceprint(L)
50 | logger.info("{}:\n{}".format(tablename,printstr))
51 |
52 | def niceprint(L,mark="-"):
53 | printstr = []
54 | printstr.append(mark*len(L[0]))
55 | printstr.append(L[0])
56 | for id in range(1,len(L)):
57 | printstr.append(mark*max(len(L[id-1]),len(L[id])))
58 | printstr.append(L[id])
59 | printstr.append(mark*len(L[-1]))
60 | printstr = "\n".join(printstr)
61 | return printstr
62 |
63 | def logdicts(dic,tablename="",width=120 ):
64 | length = 1
65 | L=[]
66 | l= "|"
67 | tup = dic.items()
68 | for id,arg in enumerate(tup):
69 | name,value = arg
70 | nv = name+":"+str(value)
71 | if length +(len(nv)+2)>width:
72 | L.append(l)
73 | l = "|"
74 | length = 1
75 | l += nv + " |"
76 | length += (len(nv)+2)
77 | if id+1 == len(tup):
78 | L.append(l)
79 | printstr = niceprint(L)
80 |
81 | logger.info("{}:\n{}".format(tablename,printstr))
82 |
83 |
84 |
85 |
86 |
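A short sketch of what logdicts renders (illustrative values): key:value cells are packed into |-delimited rows of at most `width` characters and boxed by niceprint:

logdicts({"name": "cora", "nnode": 2708, "nedge": 5278}, tablename="dataset stat")
# Logged roughly as:
# dataset stat:
# ------------------------------------
# |name:cora |nnode:2708 |nedge:5278 |
# ------------------------------------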
--------------------------------------------------------------------------------
/src/utils/const.py:
--------------------------------------------------------------------------------
1 | MIN_EPSILON = 1e-10
2 | MAX_EXP = 70
3 |
--------------------------------------------------------------------------------
/src/utils/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import networkx as nx
3 | import pickle as pkl
4 | import torch
5 | import numpy as np
6 | from collections import OrderedDict
7 | import time
8 | from src.utils.common import *
9 | from src.utils.utils import *
10 |
11 |
12 | class GraphLoader(object):
13 |
14 | def __init__(self,name,root = "./data",undirected=True, hasX=True,hasY=True,header=True,sparse=True,multigraphindex=None,args=None):
15 |
16 | self.name = name
17 | self.undirected = undirected
18 | self.hasX = hasX
19 | self.hasY = hasY
20 | self.header = header
21 | self.sparse = sparse
22 | self.dirname = os.path.join(root,name)
23 | if name == "reddit1401":
24 | self.prefix = os.path.join(root, name,multigraphindex,multigraphindex)
25 | else:
26 | self.prefix = os.path.join(root,name,name)
27 | self._load()
28 | self._registerStat()
29 | self.printStat()
30 |
31 |
32 | def _loadConfig(self):
33 | file_name = os.path.join(self.dirname,"bestconfig.txt")
34 | f = open(file_name,'r')
35 | L = f.readlines()
36 | L = [x.strip().split() for x in L]
37 | self.bestconfig = {x[0]:x[1] for x in L if len(x)!=0}
38 |
39 |
40 | def _loadGraph(self, header = True):
41 | """
42 | load file in form:
43 | --------------------
44 | NUM_Of_NODE\n
45 | v1 v2\n
46 | v3 v4\n
47 | --------------------
48 | """
49 | file_name = self.prefix+".edgelist"
50 | if not header:
51 | logger.warning("You are reading an edgelist with no explicit number of nodes")
52 | if self.undirected:
53 | G = nx.Graph()
54 | else:
55 | G = nx.DiGraph()
56 | with open(file_name) as f:
57 | L = f.readlines()
58 | if header:
59 | num_node = int(L[0].strip())
60 | L = L[1:]
61 | edge_list = [[int(x) for x in e.strip().split()] for e in L]
62 | nodeset = set([x for e in edge_list for x in e])
63 | # if header:
64 | # assert min(nodeset) == 0 and max(nodeset) + 1 == num_node, "input standard violated {} ,{}".format(num_node,max(nodeset))
65 |
66 | if header:
67 | G.add_nodes_from([x for x in range(num_node)])
68 | else:
69 | G.add_nodes_from([x for x in range(max(nodeset)+1)])
70 | G.add_edges_from(edge_list)
71 | self.G = G
72 |
73 | def _loadX(self):
74 | self.X = pkl.load(open(self.prefix + ".x.pkl", 'rb'))
75 | self.X = self.X.astype(np.float32)
76 | if self.name in ["coauthor_phy","corafull"]:
77 | self.X = self.X[:,:2000] # coauthor_phy's feature matrix is too large to fit in memory.
78 |
79 | def _loadY(self):
80 | self.Y = pkl.load(open(self.prefix+".y.pkl",'rb'))#.astype(np.float32)
81 |
82 | def _getAdj(self):
83 | self.adj = nx.adjacency_matrix(self.G).astype(np.float32)
84 |
85 | def _toTensor(self,device=None):
86 | if device is None:
87 | if self.sparse:
88 | self.adj = sparse_mx_to_torch_sparse_tensor(self.adj).cuda()
89 | self.normadj = sparse_mx_to_torch_sparse_tensor(self.normadj).cuda()
90 | else:
91 | self.adj = torch.from_numpy(self.adj).cuda()
92 | self.normadj = torch.from_numpy(self.normadj).cuda()
93 | self.X = torch.from_numpy(self.X).cuda()
94 | self.Y = torch.from_numpy(self.Y).cuda()
95 |
96 | def _load(self):
97 | self._loadGraph(header=self.header)
98 | self._loadConfig()
99 | if self.hasX:
100 | self._loadX()
101 | if self.hasY:
102 | self._loadY()
103 | self._getAdj()
104 |
105 | def _registerStat(self):
106 | L=OrderedDict()
107 | L["name"] = self.name
108 | L["nnode"] = self.G.number_of_nodes()
109 | L["nedge"] = self.G.number_of_edges()
110 | L["nfeat"] = self.X.shape[1]
111 | L["nclass"] = self.Y.max() + 1
112 | L["sparse"] = self.sparse
113 | L["multilabel"] = False
114 | L.update(self.bestconfig)
115 | self.stat = L
116 |
117 | def process(self):
118 | if int(self.bestconfig['feature_normalize']):
119 | self.X = column_normalize(preprocess_features(self.X)) # take some time
120 |
121 | # self.X = self.X - self.X.min(axis=0)
122 | # print(np.where(self.X))
123 | # exit()
124 | # print(self.X[:3,:].tolist())
125 | self.normadj = preprocess_adj(self.adj)
126 | if not self.sparse:
127 | self.adj = self.adj.todense()
128 | self.normadj = self.normadj.todense()
129 | self._toTensor()
130 |
131 | self.normdeg = self._getNormDeg()
132 |
133 |
134 | def printStat(self):
135 | logdicts(self.stat,tablename="dataset stat")
136 |
137 | def _getNormDeg(self):
138 | self.deg = torch.sparse.sum(self.adj, dim=1).to_dense()
139 | normdeg =self.deg/ self.deg.max()
140 | return normdeg
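A sketch of the on-disk layout GraphLoader expects for a dataset named "cora" (contents illustrative; the edgelist format is the one documented in _loadGraph, and bestconfig.txt holds the whitespace-separated key/value pairs read by _loadConfig, e.g. the feature_normalize flag used in process()):

# data/cora/cora.edgelist    first line: number of nodes, then one "u v" pair per line
# data/cora/cora.x.pkl       pickled numpy feature matrix, shape [nnode, nfeat]
# data/cora/cora.y.pkl       pickled numpy label array
# data/cora/bestconfig.txt   e.g. a line such as: feature_normalize 1

from src.utils.dataloader import GraphLoader
G = GraphLoader("cora", root="./data")
G.process()   # normalizes features/adjacency and moves everything to GPU (requires CUDA)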
--------------------------------------------------------------------------------
/src/utils/env.py:
--------------------------------------------------------------------------------
1 | import torch.multiprocessing as mp
2 | import time
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | from src.utils.player import Player
8 | from src.utils.utils import *
9 |
10 |
11 | def logprob2Prob(logprobs,multilabel=False):
12 | if multilabel:
13 | probs = torch.sigmoid(logprobs)
14 | else:
15 | probs = F.softmax(logprobs, dim=2)
16 | return probs
17 |
18 | def normalizeEntropy(entro,classnum): # needed because different numbers of classes yield different maximum entropy
19 | maxentro = np.log(float(classnum))
20 | entro = entro/maxentro
21 | return entro
22 |
23 | def prob2Logprob(probs,multilabel=False):
24 | if multilabel:
25 | raise NotImplementedError("multilabel for prob2Logprob is not implemented")
26 | else:
27 | logprobs = torch.log(probs)
28 | return logprobs
29 |
30 | def perc(input):
31 | # percentile rank along the node dimension: the bigger the value, the bigger the result
32 | numnode = input.size(-2)
33 | res = torch.argsort(torch.argsort(input, dim=-2), dim=-2) / float(numnode)
34 | return res
35 |
36 | def degprocess(deg):
37 | # deg = torch.log(1+deg)
38 | #return deg/20.
39 | return torch.clamp_max(deg / 20., 1.)
40 |
41 | def localdiversity(probs,adj,deg):
42 | indices = adj.coalesce().indices()
43 | N =adj.size()[0]
44 | classnum = probs.size()[-1]
45 | maxentro = np.log(float(classnum))
46 | edgeprobs = probs[:,indices.transpose(0,1),:]
47 | headprobs = edgeprobs[:,:,0,:]
48 | tailprobs = edgeprobs[:,:,1,:]
49 | kl_ht = (torch.sum(torch.log(torch.clamp_min(tailprobs,1e-10))*tailprobs,dim=-1) - \
50 | torch.sum(torch.log(torch.clamp_min(headprobs,1e-10))*tailprobs,dim=-1)).transpose(0,1)
51 | kl_th = (torch.sum(torch.log(torch.clamp_min(headprobs,1e-10))*headprobs,dim=-1) - \
52 | torch.sum(torch.log(torch.clamp_min(tailprobs,1e-10))*headprobs,dim=-1)).transpose(0,1)
53 | sparse_output_kl_ht = torch.sparse.FloatTensor(indices,kl_ht,size=torch.Size([N,N,kl_ht.size(-1)]))
54 | sparse_output_kl_th = torch.sparse.FloatTensor(indices,kl_th,size=torch.Size([N,N,kl_th.size(-1)]))
55 | sum_kl_ht = torch.sparse.sum(sparse_output_kl_ht,dim=1).to_dense().transpose(0,1)
56 | sum_kl_th = torch.sparse.sum(sparse_output_kl_th,dim=1).to_dense().transpose(0,1)
57 | mean_kl_ht = sum_kl_ht/(deg+1e-10)
58 | mean_kl_th = sum_kl_th/(deg+1e-10)
59 | # normalize
60 | mean_kl_ht = mean_kl_ht / mean_kl_ht.max(dim=1, keepdim=True).values
61 | mean_kl_th = mean_kl_th / mean_kl_th.max(dim=1, keepdim=True).values
62 | return mean_kl_ht,mean_kl_th
63 |
64 |
65 | class Env(object):
66 | ## an environment for multiple players testing the policy at the same time
67 | def __init__(self,players,args):
68 | '''
69 | players: a list containing the main players, one per task (or a single player for one task)
70 | '''
71 | self.players = players
72 | self.args = args
73 | self.nplayer = len(self.players)
74 | self.graphs = [p.G for p in self.players]
75 | featdim =-1
76 | self.statedim = self.getState(0).size(featdim)
77 |
78 |
79 | def step(self,actions,playerid=0):
80 | p = self.players[playerid]
81 | p.query(actions)
82 | p.trainOnce()
83 | reward = p.validation(test=False, rerun=False)
84 | return reward
85 |
86 |
87 | def getState(self,playerid=0):
88 | p = self.players[playerid]
89 | output = logprob2Prob(p.allnodes_output.transpose(1,2),multilabel=p.G.stat["multilabel"])
90 | state = self.makeState(output,p.trainmask,p.G.deg,playerid)
91 | return state
92 |
93 |
94 | def reset(self,playerid=0):
95 | self.players[playerid].reset(fix_test=False)
96 |
97 |
98 | def makeState(self, probs, selected, deg,playerid, adj=None, multilabel=False ):
99 | entro = entropy(probs, multilabel=multilabel)
100 | entro = normalizeEntropy(entro,probs.size(-1)) ## in order to transfer
101 | deg = degprocess(deg.expand([probs.size(0)]+list(deg.size())))
102 |
103 | features = []
104 | if self.args.use_entropy:
105 | features.append(entro)
106 | if self.args.use_degree:
107 | features.append(deg)
108 | if self.args.use_local_diversity:
109 | mean_kl_ht,mean_kl_th = localdiversity(probs,self.players[playerid].G.adj,self.players[playerid].G.deg)
110 | features.extend([mean_kl_ht, mean_kl_th])
111 | if self.args.use_select:
112 | features.append(selected)
113 | state = torch.stack(features, dim=-1)
114 |
115 | return state
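A toy sketch of the per-node features that makeState stacks (single player, illustrative tensors): entropy is taken over the class dimension and divided by log(nclass) so it stays in [0, 1] across datasets, and degree is clipped at 20:

import numpy as np
import torch
import torch.nn.functional as F

probs = F.softmax(torch.randn(1, 6, 3), dim=2)          # [batchsize, nnode, nclass]
entro = -(probs.clamp(min=1e-7).log() * probs).sum(2)   # raw entropy, as in utils.entropy
entro = entro / np.log(3.0)                             # normalizeEntropy: divide by log(nclass)
deg = torch.tensor([1., 4., 40., 2., 7., 20.])
deg = torch.clamp_max(deg.expand(1, 6) / 20., 1.)       # degprocess
state = torch.stack([entro, deg], dim=-1)               # [batchsize, nnode, nfeature]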
--------------------------------------------------------------------------------
/src/utils/player.py:
--------------------------------------------------------------------------------
1 | # individual player who takes the action and evaluates the effect
2 | import torch
3 | import torch.nn as nn
4 | import time
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 | from src.utils.common import *
9 | from src.utils.classificationnet import GCN
10 | from src.utils.utils import *
11 |
12 |
13 | class Player(nn.Module):
14 |
15 | def __init__(self,G,args,rank=0):
16 |
17 | super(Player,self).__init__()
18 | self.G = G
19 | self.args = args
20 | self.rank = rank
21 | self.batchsize = args.batchsize
22 |
23 | if self.G.stat['multilabel']:
24 | self.net = GCN(self.G.stat['nfeat'],args.nhid,self.G.stat['nclass'],args.batchsize,args.dropout,False,bias=True).cuda()
25 | self.loss_func=F.binary_cross_entropy_with_logits
26 | else:
27 | self.net = GCN(self.G.stat['nfeat'],args.nhid,self.G.stat['nclass'],args.batchsize,args.dropout,True).cuda()
28 | self.loss_func=F.nll_loss
29 |
30 | self.fulllabel = self.G.Y.expand([self.batchsize]+list(self.G.Y.size()))
31 |
32 | self.reset(fix_test=False) #initialize
33 | self.count = 0
34 |
35 |
36 | def makeValTestMask(self, fix_test=True):
37 | #if fix_test:
38 | # assert False
39 | valmask=torch.zeros((self.batchsize,self.G.stat['nnode'])).to(torch.float).cuda()
40 | testmask = torch.zeros((self.batchsize,self.G.stat['nnode'])).to(torch.float).cuda()
41 | valid = []
42 | testid = []
43 | vallabel = []
44 | testlabel = []
45 | for i in range(self.batchsize):
46 | base = np.array([x for x in range(self.G.stat["nnode"])])
47 | if fix_test:
48 | testid_=[x for x in range(self.G.stat["nnode"] - self.args.ntest,self.G.stat["nnode"])]
49 | else:
50 | testid_ = np.sort(np.random.choice(base, size=self.args.ntest, replace=False)).tolist()
51 | testmask[i, testid_] = 1.
52 | testid.append(testid_)
53 | testlabel.append(self.G.Y[testid_])
54 | s = set(testid_)
55 | base= [x for x in range(self.G.stat["nnode"]) if x not in s ]
56 | valid_ = np.sort(np.random.choice(base, size=self.args.nval, replace=False)).tolist()
57 | valmask[i,valid_]=1.
58 | valid.append(valid_)
59 | vallabel.append(self.G.Y[valid_])
60 | self.valid = torch.tensor(valid).cuda()
61 | self.testid = torch.tensor(testid).cuda()
62 | self.vallabel = torch.stack(vallabel).cuda()
63 | self.testlabel = torch.stack(testlabel).cuda()
64 | self.valmask=valmask
65 | self.testmask=testmask
66 |
67 |
68 | def lossWeighting(self,epoch):
69 | return min(epoch,10.)/10.
70 |
71 |
72 | def query(self,nodes):
73 | self.trainmask[[x for x in range(self.batchsize)],nodes] = 1.
74 |
75 |
76 | def getPool(self,reduce=True):
77 | mask = self.testmask+self.valmask+self.trainmask
78 | row,col = torch.where(mask<0.1)
79 | if reduce:
80 | row, col = row.cpu().numpy(),col.cpu().numpy()
81 | pool = []
82 | for i in range(self.batchsize):
83 | pool.append(col[row==i])
84 | return pool
85 | else:
86 | return row,col
87 |
88 |
89 | def trainOnce(self,log=False):
90 | nlabeled = torch.sum(self.trainmask)/self.batchsize
91 | self.net.train()
92 | self.opt.zero_grad()
93 | output = self.net(self.G.X,self.G.normadj)
94 | # print(output.size())
95 | # exit()
96 | if self.G.stat["multilabel"]:
97 | output_trans = output.transpose(1,2)
98 | # print(output_trans[:,-20:,:])
99 |
100 | losses = self.loss_func(output_trans,self.fulllabel,reduction="none").sum(dim=2)
101 | else:
102 | losses = self.loss_func(output,self.fulllabel,reduction="none")
103 | loss = torch.sum(losses*self.trainmask)/nlabeled*self.lossWeighting(float(nlabeled.cpu()))
104 | loss.backward()
105 | self.opt.step()
106 | #if log:
107 | #logger.info("nnodes selected:{},loss:{}".format(nlabeled,loss.detach().cpu().numpy()))
108 | self.allnodes_output=output.detach()
109 | return output
110 |
111 |
112 | def validation(self,test=False,rerun=True):
113 | if test:
114 | mask = self.testmask
115 | labels= self.testlabel
116 | index = self.testid
117 | else:
118 | mask = self.valmask
119 | labels = self.vallabel
120 | index = self.valid
121 | if rerun:
122 | self.net.eval()
123 | output = self.net(self.G.X,self.G.normadj)
124 | else:
125 | output = self.allnodes_output
126 | if self.G.stat["multilabel"]:
127 | # logger.info("output of classification {}".format(output))
128 | output_trans = output.transpose(1,2)
129 | losses_val = self.loss_func(output_trans,self.fulllabel,reduction="none").mean(dim=2)
130 | else:
131 | losses_val = self.loss_func(output,self.fulllabel,reduction="none")
132 | loss_val = torch.sum(losses_val*mask,dim =1,keepdim=True)/torch.sum(mask,dim =1,keepdim=True)
133 | acc= []
134 | for i in range(self.batchsize):
135 | pred_val = (output[i][:,index[i]]).transpose(0,1)
136 | # logger.info("pred_val {}".format(pred_val))
137 | acc.append(accuracy(pred_val,labels[i]))
138 |
139 | # logger.info("validation acc {}".format(acc))
140 | return list(zip(*acc))
141 |
142 |
143 | def trainRemain(self):
144 | for i in range(self.args.remain_epoch):
145 | self.trainOnce()
146 |
147 |
148 | def reset(self,resplit=True,fix_test=True):
149 | if resplit:
150 | self.makeValTestMask(fix_test=fix_test)
151 | self.trainmask = torch.zeros((self.batchsize,self.G.stat['nnode'])).to(torch.float).cuda()
152 | self.net.reset()
153 | self.opt = torch.optim.Adam(self.net.parameters(),lr=self.args.lr,weight_decay=5e-4)
154 | self.allnodes_output = self.net(self.G.X,self.G.normadj).detach()
155 |
156 |
157 |
158 | import argparse
159 | def parse_args():
160 | parser = argparse.ArgumentParser()
161 | parser.add_argument("--dropout",type=float,default=0.5)
162 | parser.add_argument("--ntest",type=int,default=1000)
163 | parser.add_argument("--nval",type=int,default=500)
164 | parser.add_argument("--nhid", type=int, default=64)
165 | parser.add_argument("--lr", type=float, default=3e-2)
166 | parser.add_argument("--batchsize", type=int, default=2)
167 | parser.add_argument("--budget", type=int, default=20, help="budget per class")
168 | parser.add_argument("--dataset", type=str, default="cora")
169 | parser.add_argument("--remain_epoch", type=int, default=35, help="continues training $remain_epoch")
170 |
171 | args = parser.parse_args()
172 | return args
173 |
174 | if __name__=="__main__":
175 | from src.utils.dataloader import GraphLoader
176 | args = parse_args()
177 | G = GraphLoader("cora")
178 | G.process()
179 | p = Player(G,args)
180 | p.query([2,3])
181 | p.query([4,6])
182 |
183 | p.trainOnce()
184 |
185 | print(p.trainmask[:,:10])
186 | print(p.allnodes_output[0].size())
--------------------------------------------------------------------------------
/src/utils/query.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.distributions import Categorical
5 | import torch.nn.functional as F
6 |
7 |
8 | def choose(p,pool):
9 | return np.random.choice(pool,size=1,replace=False,p=p)
10 |
11 |
12 | class RandomQuery(object):
13 | def __init__(self):
14 | pass
15 |
16 | def __call__(self,pool):
17 | ret = []
18 | for row in pool:
19 | p = np.ones(len(row))
20 | p/=p.sum()
21 | ret.append(choose(p,row))
22 | ret = np.concatenate(ret)
23 | return ret
24 |
25 |
26 | class ProbQuery(object):
27 | def __init__(self,type="soft"):
28 | self.type = type
29 |
30 | def __call__(self,probs,pool):
31 | if self.type == "soft":
32 | return self.softquery(probs,pool)
33 | elif self.type == "hard":
34 | return self.hardquery(probs,pool)
35 |
36 | def softquery(self,logits,pool):
37 | batchsize = logits.size(0)
38 | valid_logits = logits[pool].reshape(batchsize,-1)
39 | max_logits = torch.max(valid_logits,dim=1,keepdim=True)[0].detach()
40 | valid_logits = valid_logits - max_logits #torch.clamp(valid_logits,max = MAX_EXP)
41 | valid_probs = F.softmax(valid_logits, dim=1)
42 | pool = pool[1].reshape(batchsize, -1)
43 | assert pool.size() == valid_probs.size()
44 | m = Categorical(valid_probs)
45 | action_inpool = m.sample()
46 | action = pool[[x for x in range(batchsize)], action_inpool]
47 | return action
48 |
49 | def hardquery(self,logits,pool):
50 | batchsize = logits.size(0)
51 | valid_logits = logits[pool].reshape(batchsize,-1)
52 | max_logits = torch.max(valid_logits,dim=1,keepdim=True)[0].detach()
53 | valid_logits = valid_logits - max_logits #torch.clamp(valid_logits,max = MAX_EXP)
54 | valid_probs = F.softmax(valid_logits, dim=1)
55 |
56 | pool = pool[1].reshape(batchsize, -1)
57 | action_inpool = torch.argmax(valid_probs,dim=1)
58 | action = pool[[x for x in range(batchsize)], action_inpool]
59 | return action
60 |
61 |
62 | def unitTestProbQuery():
63 | probs = F.softmax(torch.randn(4,7)*3,dim=1)
64 | mask = torch.zeros_like(probs)
65 | mask[:,1] = 1
66 | pool = torch.where(mask==0)
67 | print(probs,pool)
68 |
69 | q = ProbQuery(type = "soft")
70 | action = q(probs,pool)
71 | print(action)
72 |
73 | q.type = "hard"
74 | action = q(probs,pool)
75 | print(action)
76 |
77 |
78 | def selectActions(self,logits,pool):
79 | valid_logits = logits[pool].reshape(self.args.batchsize,-1)
80 | max_logits = torch.max(valid_logits,dim=1,keepdim=True)[0].detach()
81 | if self.globel_number %10==0:
82 | logger.info(max_logits)
83 | valid_logits = valid_logits - max_logits #torch.clamp(valid_logits,max = MAX_EXP)
84 |
85 | # valid_logprobs = F.log_softmax(valid_logits,dim=1)
86 | valid_probs = F.softmax(valid_logits, dim=1)
87 | pool = pool[1].reshape(self.args.batchsize,-1)
88 | assert pool.size()==valid_probs.size()
89 |
90 | m = Categorical(valid_probs)
91 | action_inpool = m.sample()
92 | logprob = m.log_prob(action_inpool)
93 | action = pool[[x for x in range(self.args.batchsize)],action_inpool]
94 |
95 | return action,logprob
96 | if __name__=="__main__":
97 | unitTestProbQuery()
98 |
--------------------------------------------------------------------------------
/src/utils/rewardshaper.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from src.utils.common import *
3 |
4 |
5 | class RewardShaper(object):
6 | def __init__(self,args):
7 | self.metric = args.metric
8 | self.shaping = args.shaping
9 | self.entcoef = args.entcoef
10 | self.gamma = 0.5
11 | self.rate = 0.05
12 | self.alpha = args.frweight # tunes the ratio between the final reward and the intermediate rewards
13 | if "0" in self.shaping and "1" in self.shaping:
14 | raise ValueError("arguments invalid")
15 | self.hashistorymean,self.hashistoryvar = False,False
16 | if "3" in self.shaping:
17 | self.historymean = np.zeros((1000,args.batchsize))
18 | if "4" in self.shaping:
19 | self.histrvar = np.zeros((1000,args.batchsize))
20 | pass
21 |
22 | def reshape(self,rewards_all,finalrewards_all,logprobs):
23 | rewards_sub, finalrewards = self._roughProcess(rewards_all,finalrewards_all)
24 | self.componentRatio(rewards_sub,finalrewards,logprobs)
25 | rewards = np.zeros_like(rewards_sub)
26 | if "0" in self.shaping:
27 | rewards += rewards_sub
28 | if "1" in self.shaping:
29 | for i in range(rewards_sub.shape[0]-1,0,-1):
30 | rewards_sub[i-1] += self.gamma*rewards_sub[i]
31 | rewards += rewards_sub
32 | if "2" in self.shaping: #2
33 | rewards += finalrewards*self.alpha
34 |
35 | if "3" in self.shaping:
36 | if not self.hashistorymean:
37 | self.historymean[:rewards_sub.shape[0], :] += rewards.mean(1,keepdims=True)
38 | self.hashistorymean = True
39 | else:
40 | self.historymean[:rewards_sub.shape[0], :] = self.historymean[:rewards_sub.shape[0], :] * (1 - self.rate) + self.rate * rewards.mean(1,keepdims=True)
41 | rewards = rewards - self.historymean[:rewards.shape[0], :]
42 |
43 | if "4" in self.shaping:
44 | if not self.hashistoryvar:
45 | self.histrvar[:rewards_sub.shape[0], :] += (rewards**2).mean(1,keepdims=True)
46 | self.hashistoryvar = True
47 | else:
48 | self.histrvar[:rewards_sub.shape[0], :] = self.histrvar[:rewards.shape[0], :] * (
49 | 1 - self.rate) + self.rate * (rewards**2).mean(1,keepdims=True)
50 | rewards = rewards/np.power(self.histrvar[:rewards.shape[0], :],0.5)
51 | return rewards
52 |
53 | def _roughProcess(self,rewards_all,finalrewards_all):
54 | mic,mac = [np.array(x) for x in list(zip(*rewards_all))]
55 | finalmic,finalmac = [np.array(x) for x in finalrewards_all]
56 | if self.metric == "microf1":
57 | rewards = mic
58 | finalrewards = finalmic
59 | elif self.metric == "macrof1":
60 | rewards = mac
61 | finalrewards = finalmac
62 | elif self.metric == "mix":
63 | rewards = (mic+mac)/2
64 | finalrewards = (finalmic+finalmac)/2
65 | else:
66 | raise NotImplementedError("metric <{}> is not implemented".format(self.metric))
67 | rewards_sub = rewards[1:,:]-rewards[:-1,:]
68 | return rewards_sub,finalrewards
69 |
70 | def componentRatio(self,rewards_sub,finalrewards_all,logprobs):
71 | r_mean = np.mean(np.abs(rewards_sub))
72 | f_mean = np.mean(finalrewards_all)
73 | lp_mean = np.mean(np.abs(logprobs))
74 | f_ratio = f_mean/r_mean*self.alpha
75 | lp_ratio = lp_mean/r_mean*self.entcoef
76 | logger.debug("rmean {:.4f},fratio {:.2f}x{:.4f}={:.3f}, "
77 | "lpratio {:.1f}x{:.5f}={:.3f}".format(r_mean,
78 | f_mean/r_mean,self.alpha,f_ratio,
79 | lp_mean/r_mean,self.entcoef,lp_ratio))
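The shaping string is a set of flags read by reshape: "0" keeps the raw per-step validation gains, "1" replaces them with gamma-discounted cumulative gains ("0" and "1" are mutually exclusive), "2" adds the alpha-weighted final reward, "3" subtracts a running mean baseline, and "4" divides by a running estimate of the reward scale (square root of the running mean of squared rewards). A toy sketch of the "1" accumulation with illustrative numbers:

import numpy as np

gamma = 0.5
rewards_sub = np.array([[0.10], [0.04], [0.02]])    # per-step accuracy gains, one trajectory
for i in range(rewards_sub.shape[0] - 1, 0, -1):    # backward pass, as in reshape()
    rewards_sub[i - 1] += gamma * rewards_sub[i]
# rewards_sub is now [[0.125], [0.05], [0.02]]:
# 0.02 stays, 0.04 + 0.5*0.02 = 0.05, 0.10 + 0.5*0.05 = 0.125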
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.sparse as sp
3 | import warnings
4 | from sklearn.metrics import f1_score
5 | import torch
6 | import math
7 |
8 | from src.utils.common import *
9 |
10 |
11 | def preprocess_features(features):
12 | """Row-normalize feature matrix and convert to tuple representation"""
13 | rowsum = np.array(features.sum(1))
14 | r_inv = np.power(rowsum, -1).flatten()
15 | r_inv[np.isinf(r_inv)] = 0.
16 | r_mat_inv = np.diag(r_inv)
17 | features = r_mat_inv.dot(features)
18 | return features
19 |
20 | def column_normalize(tens):
21 | ret = tens - tens.mean(axis=0)
22 | return ret
23 |
24 | def normalize_adj(adj):
25 | """Symmetrically normalize adjacency matrix."""
26 | adj = sp.coo_matrix(adj)
27 | rowsum = np.array(adj.sum(1))
28 | d_inv_sqrt = np.power(rowsum, -0.5).flatten()
29 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
30 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
31 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo()
32 |
33 | def preprocess_adj(adj):
34 | """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation."""
35 | adj_add_diag=adj + sp.eye(adj.shape[0])
36 | adj_normalized = normalize_adj(adj_add_diag)
37 | return adj_normalized.astype(np.float32) #sp.coo_matrix(adj_unnorm)
38 |
39 | def sparse_mx_to_torch_sparse_tensor(sparse_mx):
40 | """Convert a scipy sparse matrix to a torch sparse tensor."""
41 | sparse_mx = sparse_mx.tocoo().astype(np.float32)
42 | indices = torch.from_numpy(
43 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
44 | values = torch.from_numpy(sparse_mx.data)
45 | shape = torch.Size(sparse_mx.shape)
46 | return torch.sparse.FloatTensor(indices, values, shape)
47 |
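A worked sketch (toy 3-node path graph) exercising the preprocess_adj defined above, i.e. the GCN renormalization D^{-1/2}(A+I)D^{-1/2} with degrees taken from A+I:

import numpy as np
import scipy.sparse as sp

A = sp.csr_matrix(np.array([[0, 1, 0],
                            [1, 0, 1],
                            [0, 1, 0]], dtype=np.float32))
A_hat = A + sp.eye(3)                        # add self-loops
deg = np.asarray(A_hat.sum(1)).flatten()     # [2., 3., 2.]
norm = preprocess_adj(A).toarray()
# entry (0, 1) = 1/sqrt(2*3) ~ 0.408, self-loop entry (0, 0) = 1/sqrt(2*2) = 0.5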
48 | ##=========================================================================
49 |
50 | def accuracy(y_pred, labels):
51 | if len(labels.size())==1:
52 | y_pred = y_pred.max(1)[1].type_as(labels)
53 | y_pred=y_pred.cpu().detach().numpy()
54 | labels=labels.cpu().numpy()
55 |
56 |
57 | elif len(labels.size())==2:
58 | # print("rawy_pred",y_pred)
59 | y_pred=(y_pred > 0.).cpu().detach().numpy()
60 | labels=labels.cpu().numpy()
61 |
62 | # y_pred = np.zeros_like(y_pred)
63 |
64 | # print("y_pred",y_pred[:10,:])
65 | # print("labels",labels[:10,:])
66 | # exit()
67 |
68 |
69 |
70 | with warnings.catch_warnings():
71 | warnings.simplefilter("ignore")
72 | mic,mac=f1_score(labels, y_pred, average="micro"), f1_score(labels, y_pred, average="macro")
73 | return mic,mac
74 |
75 | def mean_std(L):
76 | if type(L)==np.ndarray:
77 | L=L.tolist()
78 | m=sum(L)/float(len(L))
79 | bias=[(x-m)**2 for x in L]
80 | std=math.sqrt(sum(bias)/float(len(L)-1))
81 | return [float(m)*100.,float(std)*100.]
82 |
83 | ##==========================================================================
84 |
85 | def entropy(tens,multilabel=False):
86 | if multilabel:#Todo
87 | reverse=1-tens
88 | ent_1= -torch.log(torch.clamp(tens, min=1e-7)) * tens
89 | ent_2= -torch.log(torch.clamp(reverse,min=1e-7))*reverse
90 | ent=ent_1+ent_2
91 | entropy=torch.mean(ent,dim=1)
92 | else:
93 | assert type(tens)==torch.Tensor and len(tens.size())==3,"calculating entropy of wrong size"
94 | entropy = - torch.log(torch.clamp(tens, min=1e-7)) * tens
95 | entropy = torch.sum(entropy, dim=2)
96 | return entropy
97 |
98 |
99 | ##==========================================================================
100 |
101 |
102 | class AverageMeter(object):
103 | def __init__(self,name='',ave_step=10):
104 | self.name = name
105 | self.ave_step = ave_step
106 | self.history =[]
107 | self.history_extrem = None
108 | self.S=5
109 |
110 | def update(self,data):
111 | if data is not None:
112 | self.history.append(data)
113 |
114 | def __call__(self):
115 | if len(self.history) == 0:
116 | value = None
117 | else:
118 | cal=self.history[-self.ave_step:]
119 | value = sum(cal)/float(len(cal))
120 | return value
121 |
122 | def should_save(self):
123 | if len(self.history)>self.S*2 and sum(self.history[-self.S:])/float(self.S)> sum(self.history[-self.S*2:])/float(self.S*2):
124 | if self.history_extrem is None :
125 | self.history_extrem =sum(self.history[-self.S:])/float(self.S)
126 | return False
127 | else:
128 | if self.history_extrem < sum(self.history[-self.S:])/float(self.S):
129 | self.history_extrem = sum(self.history[-self.S:])/float(self.S)
130 | return True
131 | else:
132 | return False
133 | else:
134 | return False
135 |
136 |
137 | #===========================================================
138 |
139 | def inspect_grad(model):
140 | name_grad = [(x[0], x[1].grad) for x in model.named_parameters() if x[1].grad is not None]
141 | name, grad = zip(*name_grad)
142 | assert not len(grad) == 0, "no layer requires grad"
143 | mean_grad = [torch.mean(x) for x in grad]
144 | max_grad = [torch.max(x) for x in grad]
145 | min_grad = [torch.min(x) for x in grad]
146 | logger.info("name {}, mean_max min {}".format(name,list(zip(mean_grad, max_grad, min_grad))))
147 |
148 | def inspect_weight(model):
149 | name_weight = [x[1] for x in model.named_parameters() if x[1].grad is not None]
150 | print("network_weight:{}".format(name_weight))
151 |
152 |
153 | #==============================================================
154 |
155 | def common_rate(counts,prediction,seq):
156 | summation = counts.sum(dim=1, keepdim=True)
157 | squaresum = (counts ** 2).sum(dim=1, keepdim=True)
158 | ret = (summation ** 2 - squaresum) / (summation * (summation - 1)+1)
159 | # print("here1")
160 | equal_rate=counts[seq,prediction].reshape(-1,1)/(summation+1)
161 | # print(ret,equal_rate)
162 | return ret,equal_rate
163 |
164 |
--------------------------------------------------------------------------------