├── README.md ├── data ├── citeseer │ ├── bestconfig.txt │ ├── citeseer.edgelist │ ├── citeseer.x.pkl │ ├── citeseer.y.pkl │ ├── processed │ │ └── data.pt │ └── raw │ │ ├── ind.citeseer.allx │ │ ├── ind.citeseer.ally │ │ ├── ind.citeseer.graph │ │ ├── ind.citeseer.test.index │ │ ├── ind.citeseer.tx │ │ ├── ind.citeseer.ty │ │ ├── ind.citeseer.x │ │ └── ind.citeseer.y ├── cora │ ├── bestconfig.txt │ ├── cora.edgelist │ ├── cora.x.pkl │ ├── cora.y.pkl │ ├── processed │ │ └── data.pt │ └── raw │ │ ├── ind.cora.allx │ │ ├── ind.cora.ally │ │ ├── ind.cora.graph │ │ ├── ind.cora.test.index │ │ ├── ind.cora.tx │ │ ├── ind.cora.ty │ │ ├── ind.cora.x │ │ └── ind.cora.y ├── pubmed │ ├── bestconfig.txt │ ├── processed │ │ └── data.pt │ ├── pubmed.edgelist │ ├── pubmed.x.pkl │ ├── pubmed.y.pkl │ └── raw │ │ ├── ind.pubmed.allx │ │ ├── ind.pubmed.ally │ │ ├── ind.pubmed.graph │ │ ├── ind.pubmed.test.index │ │ ├── ind.pubmed.tx │ │ ├── ind.pubmed.ty │ │ ├── ind.pubmed.x │ │ └── ind.pubmed.y └── reddit1401 │ ├── bestconfig.txt │ ├── graph1 │ ├── graph1.edgelist │ ├── graph1.x.pkl │ └── graph1.y.pkl │ ├── graph2 │ ├── graph2.edgelist │ ├── graph2.x.pkl │ └── graph2.y.pkl │ ├── graph3 │ ├── graph3.edgelist │ ├── graph3.x.pkl │ └── graph3.y.pkl │ ├── graph4 │ ├── graph4.edgelist │ ├── graph4.x.pkl │ └── graph4.y.pkl │ ├── graph5 │ ├── graph5.edgelist │ ├── graph5.x.pkl │ └── graph5.y.pkl │ └── graphforvis │ ├── graphforvis.edgelist │ ├── graphforvis.x.pkl │ └── graphforvis.y.pkl ├── mainfig_website.png ├── requirements.txt ├── saved_models ├── pretrain_citeseer+pubmed.pkl ├── pretrain_cora+citeseer.pkl ├── pretrain_cora+pubmed.pkl └── pretrain_reddit1+2.pkl └── src ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc └── train.cpython-36.pyc ├── baselines ├── __init__.py ├── active-learning │ ├── CONTRIBUTING.md │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── requirements.txt │ ├── run_experiment.py │ ├── sampling_methods │ │ ├── __init__.py │ │ ├── bandit_discrete.py │ │ ├── constants.py │ │ ├── graph_density.py │ │ ├── hierarchical_clustering_AL.py │ │ ├── informative_diverse.py │ │ ├── kcenter_greedy.py │ │ ├── margin_AL.py │ │ ├── mixture_of_samplers.py │ │ ├── represent_cluster_centers.py │ │ ├── sampling_def.py │ │ ├── simulate_batch.py │ │ ├── uniform_sampling.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── tree.py │ │ │ └── tree_test.py │ │ └── wrapper_sampler_def.py │ └── utils │ │ ├── __init__.py │ │ ├── allconv.py │ │ ├── chart_data.py │ │ ├── create_data.py │ │ ├── kernel_block_solver.py │ │ ├── small_cnn.py │ │ └── utils.py ├── age.py ├── anrmab.py ├── coreset.py ├── coreset │ ├── compute_distance_mat.py │ ├── configure.sh │ ├── full_solver_gurobi.py │ └── gurobi_solution_parser.py └── sampling_methods │ ├── __init__.py │ ├── bandit_discrete.py │ ├── constants.py │ ├── graph_density.py │ ├── hierarchical_clustering_AL.py │ ├── informative_diverse.py │ ├── kcenter_greedy.py │ ├── margin_AL.py │ ├── mixture_of_samplers.py │ ├── represent_cluster_centers.py │ ├── sampling_def.py │ ├── simulate_batch.py │ ├── uniform_sampling.py │ ├── utils │ ├── __init__.py │ ├── tree.py │ └── tree_test.py │ └── wrapper_sampler_def.py ├── datasetcollecting ├── biggraph.py ├── collecting.py ├── loadreddit.py └── mergeedgelist.py ├── test.py ├── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── classificationnet.cpython-36.pyc ├── common.cpython-36.pyc ├── const.cpython-36.pyc ├── dataloader.cpython-36.pyc ├── env.cpython-36.pyc ├── player.cpython-36.pyc ├── 
policynet.cpython-36.pyc ├── rewardshaper.cpython-36.pyc └── utils.cpython-36.pyc ├── classificationnet.py ├── common.py ├── const.py ├── dataloader.py ├── env.py ├── player.py ├── policynet.py ├── query.py ├── rewardshaper.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Graph Policy Network for Transferable Active Learning on Graphs 2 | This is the code of the paper **G**raph **P**olicy network for transferable **A**ctive learning on graphs (GPA). 3 | 4 | ## Dependencies 5 | matplotlib==2.2.3 6 | networkx==2.4 7 | scikit-learn==0.21.2 8 | numpy==1.16.3 9 | scipy==1.2.1 10 | torch==1.3.1 11 | 12 | ## Data 13 | We have provided Cora, Pubmed, Citeseer, and Reddit1401, whose data format has been processed so that it can be directly consumed by our code. Reddit1401 is collected from the Reddit data source (where 1401 means January 2014) and preprocessed by us. If you use these graphs in your work, please cite our paper. For the Coauthor_CS and Coauthor_Phy datasets, we don't provide the processed data because they are too large for GitHub repos. If you are interested, please email shengdinghu@gmail.com for the processed data. 14 | 15 | ## Train 16 | 17 | Use ```train.py``` to train the active learning policy on multiple labeled training graphs. Assume that we have two labeled training graphs ```A``` and ```B``` with query budgets of ```x``` and ```y``` respectively, and we want to save the trained model in ```temp.pkl```; then use the following command: 18 | ``` 19 | python -m src.train --datasets A+B --budgets x+y --save 1 --savename temp 20 | ``` 21 | Please refer to the source code to see how to set the other arguments. 22 | 23 | ## Test 24 | Use ```test.py``` to test the learned active learning policy on unlabeled test graphs. Assume that we have an unlabeled test graph ```G``` with a query budget of ```z```, and we want to test the policy stored in ```temp.pkl```; then use the following command: 25 | ``` 26 | python -m src.test --method 3 --modelname temp --datasets G --budgets z 27 | ``` 28 | Please refer to the source code to see how to set the other arguments. 29 | 30 | ## Pre-trained Models and Results 31 | We provide several pre-trained models with their test results on the unlabeled test graphs. 32 | For transferable active learning on graphs from the **same** domain, we train on Reddit {1, 2} and test on Reddit {3, 4, 5}. The pre-trained model is saved in ```models/pretrain_reddit1+2.pkl```. The test results are 33 | 34 | | Metric | Reddit 3 | Reddit 4 | Reddit 5 | 35 | | :---: | :---:| :---:| :---: | 36 | | Micro-F1 | 92.51 | 91.49 | 90.71 | 37 | | Macro-F1 | 92.22 | 89.57 | 90.32 | 38 | 39 | For transferable active learning on graphs across **different** domains, we provide three pre-trained models trained on different training graphs as follows: 40 | 41 | 1. Train on Cora + Citeseer, and test on the remaining graphs. The pre-trained model is saved in ```models/pretrain_cora+citeseer.pkl```. The test results are 42 | 43 | | Metric | Pubmed | Reddit 1 | Reddit 2 | Reddit 3 | Reddit 4 | Reddit 5 | Physics | CS | 44 | | :---: | :---:| :---:| :---: | :---: | :---:| :---:| :---: | :---: | 45 | | Micro-F1 | 77.44 | 88.16 | 95.25 | 92.09 | 91.37 | 90.71 | 87.91 | 87.64 | 46 | | Macro-F1 | 75.28 | 87.84 | 95.04 | 91.77 | 89.50 | 90.30 | 82.57 | 84.45 | 47 | 48 | 2. Train on Cora + Pubmed, and test on the remaining graphs. The pre-trained model is saved in ```models/pretrain_cora+pubmed.pkl```.
The test results are 49 | 50 | | Metric | Citeseer | Reddit 1 | Reddit 2 | Reddit 3 | Reddit 4 | Reddit 5 | Physics | CS | 51 | | :---: | :---:| :---:| :---: | :---: | :---:| :---:| :---: | :---: | 52 | | Micro-F1 | 65.76 | 88.14 | 95.14 | 92.08 | 91.05 | 90.38 | 87.14 | 88.15 | 53 | | Macro-F1 | 57.52 | 87.86 | 94.93 | 91.78 | 89.08 | 89.92 | 81.04 | 85.24 | 54 | 55 | 3. Train on Citeseer + Pubmed, and test on the remaining graphs. The pre-trained model is saved in ```models/pretrain_citeseer+pubmed.pkl```. The test results are 56 | 57 | | Metric | Cora | Reddit 1 | Reddit 2 | Reddit 3 | Reddit 4 | Reddit 5 | Physics | CS | 58 | | :---: | :---:| :---:| :---: | :---: | :---:| :---:| :---: | :---: | 59 | | Micro-F1 | 73.40 | 87.57 | 95.08 | 92.07 | 90.99 | 90.53 | 87.06 | 87.00 | 60 | | Macro-F1 | 71.22 | 87.11 | 94.87 | 91.74 | 88.97 | 90.14 | 81.20 | 83.90 | 61 | -------------------------------------------------------------------------------- /data/citeseer/bestconfig.txt: -------------------------------------------------------------------------------- 1 | feature_normalize 1 2 | 3 | -------------------------------------------------------------------------------- /data/citeseer/citeseer.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/citeseer.x.pkl -------------------------------------------------------------------------------- /data/citeseer/citeseer.y.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/citeseer.y.pkl -------------------------------------------------------------------------------- /data/citeseer/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/processed/data.pt -------------------------------------------------------------------------------- /data/citeseer/raw/ind.citeseer.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.allx -------------------------------------------------------------------------------- /data/citeseer/raw/ind.citeseer.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.ally -------------------------------------------------------------------------------- /data/citeseer/raw/ind.citeseer.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.graph -------------------------------------------------------------------------------- /data/citeseer/raw/ind.citeseer.tx: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.tx -------------------------------------------------------------------------------- /data/citeseer/raw/ind.citeseer.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.ty -------------------------------------------------------------------------------- /data/citeseer/raw/ind.citeseer.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.x -------------------------------------------------------------------------------- /data/citeseer/raw/ind.citeseer.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/citeseer/raw/ind.citeseer.y -------------------------------------------------------------------------------- /data/cora/bestconfig.txt: -------------------------------------------------------------------------------- 1 | feature_normalize 0 2 | 3 | -------------------------------------------------------------------------------- /data/cora/cora.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/cora.x.pkl -------------------------------------------------------------------------------- /data/cora/cora.y.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/cora.y.pkl -------------------------------------------------------------------------------- /data/cora/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/processed/data.pt -------------------------------------------------------------------------------- /data/cora/raw/ind.cora.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.allx -------------------------------------------------------------------------------- /data/cora/raw/ind.cora.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.ally -------------------------------------------------------------------------------- /data/cora/raw/ind.cora.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.graph -------------------------------------------------------------------------------- /data/cora/raw/ind.cora.tx: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.tx -------------------------------------------------------------------------------- /data/cora/raw/ind.cora.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.ty -------------------------------------------------------------------------------- /data/cora/raw/ind.cora.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.x -------------------------------------------------------------------------------- /data/cora/raw/ind.cora.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/cora/raw/ind.cora.y -------------------------------------------------------------------------------- /data/pubmed/bestconfig.txt: -------------------------------------------------------------------------------- 1 | feature_normalize 1 2 | 3 | -------------------------------------------------------------------------------- /data/pubmed/processed/data.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/processed/data.pt -------------------------------------------------------------------------------- /data/pubmed/pubmed.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/pubmed.x.pkl -------------------------------------------------------------------------------- /data/pubmed/pubmed.y.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/pubmed.y.pkl -------------------------------------------------------------------------------- /data/pubmed/raw/ind.pubmed.allx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.allx -------------------------------------------------------------------------------- /data/pubmed/raw/ind.pubmed.ally: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.ally -------------------------------------------------------------------------------- /data/pubmed/raw/ind.pubmed.graph: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.graph 
-------------------------------------------------------------------------------- /data/pubmed/raw/ind.pubmed.tx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.tx -------------------------------------------------------------------------------- /data/pubmed/raw/ind.pubmed.ty: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.ty -------------------------------------------------------------------------------- /data/pubmed/raw/ind.pubmed.x: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.x -------------------------------------------------------------------------------- /data/pubmed/raw/ind.pubmed.y: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/pubmed/raw/ind.pubmed.y -------------------------------------------------------------------------------- /data/reddit1401/bestconfig.txt: -------------------------------------------------------------------------------- 1 | feature_normalize 0 2 | -------------------------------------------------------------------------------- /data/reddit1401/graph1/graph1.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph1/graph1.x.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph1/graph1.y.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph1/graph1.y.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph2/graph2.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph2/graph2.x.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph2/graph2.y.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph2/graph2.y.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph3/graph3.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph3/graph3.x.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph3/graph3.y.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph3/graph3.y.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph4/graph4.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph4/graph4.x.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph4/graph4.y.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph4/graph4.y.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph5/graph5.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph5/graph5.x.pkl -------------------------------------------------------------------------------- /data/reddit1401/graph5/graph5.y.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graph5/graph5.y.pkl -------------------------------------------------------------------------------- /data/reddit1401/graphforvis/graphforvis.edgelist: -------------------------------------------------------------------------------- 1 | 85 2 | 16 32 3 | 72 31 4 | 1 31 5 | 69 35 6 | 78 35 7 | 69 74 8 | 35 29 9 | 15 70 10 | 32 70 11 | 32 7 12 | 68 7 13 | 56 30 14 | 32 30 15 | 74 48 16 | 72 9 17 | 34 9 18 | 79 9 19 | 54 42 20 | 78 61 21 | 35 61 22 | 61 17 23 | 72 55 24 | 31 55 25 | 1 67 26 | 9 67 27 | 55 67 28 | 15 84 29 | 2 84 30 | 70 84 31 | 29 84 32 | 4 84 33 | 32 84 34 | 30 84 35 | 48 25 36 | 31 25 37 | 17 25 38 | 35 25 39 | 74 25 40 | 27 25 41 | 61 25 42 | 78 25 43 | 25 53 44 | 2 40 45 | 32 40 46 | 25 40 47 | 62 40 48 | 78 0 49 | 72 10 50 | 37 10 51 | 2 46 52 | 1 12 53 | 35 77 54 | 25 77 55 | 75 77 56 | 60 77 57 | 51 77 58 | 40 77 59 | 31 77 60 | 74 77 61 | 10 77 62 | 73 77 63 | 76 77 64 | 69 77 65 | 23 77 66 | 22 77 67 | 44 77 68 | 4 77 69 | 84 77 70 | 17 77 71 | 27 77 72 | 77 18 73 | 60 18 74 | 77 36 75 | 51 36 76 | 18 36 77 | 77 39 78 | 77 58 79 | 77 63 80 | 36 63 81 | 18 63 82 | 77 83 83 | 1 6 84 | 54 6 85 | 42 6 86 | 67 6 87 | 37 6 88 | 74 26 89 | 35 26 90 | 25 26 91 | 27 26 92 | 69 26 93 | 35 43 94 | 72 65 95 | 9 65 96 | 37 65 97 | 10 65 98 | 32 14 99 | 35 71 100 | 77 21 101 | 74 49 102 | 77 66 103 | 77 80 104 | 26 80 105 | 35 80 106 | 25 80 107 | 74 80 108 | 49 80 109 | 61 8 110 | 25 8 111 | 80 8 112 | 74 8 113 | 26 8 114 | 73 11 115 | 84 33 116 | 60 59 117 | 17 57 118 | 25 57 119 | 35 57 120 | 69 57 121 | 84 41 122 | 14 64 123 | 77 82 124 | 32 5 125 | 2 5 126 | 50 5 127 | 25 28 128 | 77 47 129 | 63 47 130 | 77 81 131 | 18 81 132 | 65 3 133 | 67 3 134 | 72 3 135 | 2 13 136 | 5 13 137 | 44 24 138 | 81 24 139 | 76 24 140 | 77 24 141 | 40 45 142 | 25 19 143 | 77 19 144 | 49 19 145 | 35 19 146 | 57 19 147 | 80 19 148 | 77 38 149 | 18 38 150 | 35 52 151 | 77 52 152 | 18 52 153 
| 57 52 154 | 19 52 155 | 61 52 156 | 40 20 157 | 52 20 158 | -------------------------------------------------------------------------------- /data/reddit1401/graphforvis/graphforvis.x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graphforvis/graphforvis.x.pkl -------------------------------------------------------------------------------- /data/reddit1401/graphforvis/graphforvis.y.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/data/reddit1401/graphforvis/graphforvis.y.pkl -------------------------------------------------------------------------------- /mainfig_website.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/mainfig_website.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==2.2.3 2 | networkx==2.4 3 | scikit-learn==0.21.2 4 | numpy==1.16.3 5 | scipy==1.2.1 6 | torch==1.3.1 7 | 8 | -------------------------------------------------------------------------------- /saved_models/pretrain_citeseer+pubmed.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/saved_models/pretrain_citeseer+pubmed.pkl -------------------------------------------------------------------------------- /saved_models/pretrain_cora+citeseer.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/saved_models/pretrain_cora+citeseer.pkl -------------------------------------------------------------------------------- /saved_models/pretrain_cora+pubmed.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/saved_models/pretrain_cora+pubmed.pkl -------------------------------------------------------------------------------- /saved_models/pretrain_reddit1+2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/saved_models/pretrain_reddit1+2.pkl -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/src/__init__.py -------------------------------------------------------------------------------- /src/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/src/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /src/__pycache__/train.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/src/__pycache__/train.cpython-36.pyc -------------------------------------------------------------------------------- /src/baselines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ShengdingHu/GraphPolicyNetworkActiveLearning/8231b60a0afa86aae1849ed107752106bbe77142/src/baselines/__init__.py -------------------------------------------------------------------------------- /src/baselines/active-learning/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution, 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | -------------------------------------------------------------------------------- /src/baselines/active-learning/README.md: -------------------------------------------------------------------------------- 1 | # Active Learning Playground 2 | 3 | ## Introduction 4 | 5 | This is a python module for experimenting with different active learning 6 | algorithms. There are a few key components to running active learning 7 | experiments: 8 | 9 | * Main experiment script is 10 | [`run_experiment.py`](run_experiment.py) 11 | with many flags for different run options. 12 | 13 | * Supported datasets can be downloaded to a specified directory by running 14 | [`utils/create_data.py`](utils/create_data.py). 15 | 16 | * Supported active learning methods are in 17 | [`sampling_methods`](sampling_methods/). 18 | 19 | Below I will go into each component in more detail. 20 | 21 | DISCLAIMER: This is not an official Google product. 22 | 23 | ## Setup 24 | The dependencies are in [`requirements.txt`](requirements.txt). Please make sure these packages are 25 | installed before running experiments. If GPU capable `tensorflow` is desired, please follow 26 | instructions [here](https://www.tensorflow.org/install/). 27 | 28 | It is highly suggested that you install all dependencies into a separate `virtualenv` for 29 | easy package management. 30 | 31 | ## Getting benchmark datasets 32 | 33 | By default the datasets are saved to `/tmp/data`. 
You can specify another directory via the 33 | `--save_dir` flag. 34 | 35 | 36 | Redownloading all the datasets will be very time-consuming, so please be patient. 37 | You can specify a subset of the data to download by passing in a comma-separated 38 | string of datasets via the `--datasets` flag. 39 | 40 | ## Running experiments 41 | 42 | There are a few key flags for 43 | [`run_experiment.py`](run_experiment.py): 44 | 45 | * `dataset`: name of the dataset, must match the save name used in 46 | `create_data.py`. Must also exist in the data_dir. 47 | 48 | * `sampling_method`: active learning method to use. Must be specified in 49 | [`sampling_methods/constants.py`](sampling_methods/constants.py). 50 | 51 | * `warmstart_size`: initial batch of uniformly sampled examples to use as seed 52 | data. Float indicates percentage of total training data and integer 53 | indicates raw size. 54 | 55 | * `batch_size`: number of datapoints to request in each batch. Float indicates 56 | percentage of total training data and integer indicates raw size. 57 | 58 | * `score_method`: model to use to evaluate the performance of the sampling 59 | method. Must be in the `get_model` method of 60 | [`utils/utils.py`](utils/utils.py). 61 | 62 | * `data_dir`: directory with saved datasets. 63 | 64 | * `save_dir`: directory to save results. 65 | 66 | This is just a subset of all the flags. There are also options for 67 | preprocessing, introducing labeling noise, dataset subsampling, and using a 68 | different model to select than to score/evaluate. 69 | 70 | ## Available active learning methods 71 | 72 | All named active learning methods are in 73 | [`sampling_methods/constants.py`](sampling_methods/constants.py). 74 | 75 | You can also specify a mixture of active learning methods by following the 76 | pattern of `[sampling_method]-[mixture_weight]` separated by dashes; e.g. 77 | `mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34`. 78 | 79 | Some supported sampling methods include: 80 | 81 | * Uniform: samples are selected via uniform sampling. 82 | 83 | * Margin: uncertainty based sampling method. 84 | 85 | * Informative and diverse: margin and cluster based sampling method. 86 | 87 | * k-center greedy: representative strategy that greedily forms a batch of 88 | points to minimize maximum distance from a labeled point. 89 | 90 | * Graph density: representative strategy that selects points in dense regions 91 | of the pool. 92 | 93 | * Exp3 bandit: meta-active learning method that tries to learn the optimal 94 | sampling method using a popular multi-armed bandit algorithm. 95 | 96 | ### Adding new active learning methods 97 | 98 | Implement either a base sampler that inherits from 99 | [`SamplingMethod`](sampling_methods/sampling_def.py) 100 | or a meta-sampler that calls base samplers and inherits from 101 | [`WrapperSamplingMethod`](sampling_methods/wrapper_sampler_def.py). 102 | 103 | The only method that must be implemented by any sampler is `select_batch_`, 104 | which can have arbitrary named arguments. The only restriction is that the name 105 | for the same input must be consistent across all the samplers (i.e. the indices 106 | for already selected examples all have the same name across samplers). Adding a 107 | new named argument that hasn't been used in other sampling methods will require 108 | feeding that into the `select_batch` call in 109 | [`run_experiment.py`](run_experiment.py).
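For illustration, a minimal base sampler might look like the sketch below. The class and its name (`RandomTieBreakMarginAL`) are hypothetical and not part of this repository; it mirrors the structure of the bundled `margin_AL.py` and assumes the `(X, y, seed)` constructor and `select_batch_(model, already_selected, N)` arguments used by the existing samplers.

```python
import numpy as np

from sampling_methods.sampling_def import SamplingMethod


class RandomTieBreakMarginAL(SamplingMethod):
  """Margin sampling with random tie-breaking (illustrative sketch only)."""

  def __init__(self, X, y, seed):
    self.X = X
    self.y = y
    self.name = 'random_tiebreak_margin'
    np.random.seed(seed)

  def select_batch_(self, model, already_selected, N, **kwargs):
    # Same uncertainty score as MarginAL: gap between the two most likely classes.
    try:
      distances = model.decision_function(self.X)
    except AttributeError:
      distances = model.predict_proba(self.X)
    if len(distances.shape) < 2:
      min_margin = abs(distances)
    else:
      sort_distances = np.sort(distances, 1)[:, -2:]
      min_margin = sort_distances[:, 1] - sort_distances[:, 0]
    # Tiny random perturbation so ties are broken differently across runs.
    min_margin = min_margin + 1e-6 * np.random.rand(len(min_margin))
    rank_ind = np.argsort(min_margin)
    rank_ind = [i for i in rank_ind if i not in already_selected]
    return rank_ind[0:N]
```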
110 | 111 | After implementing your sampler, be sure to add it to 112 | [`constants.py`](sampling_methods/constants.py) 113 | so that it can be called from 114 | [`run_experiment.py`](run_experiment.py). 115 | 116 | ## Available models 117 | 118 | All available models are in the `get_model` method of 119 | [`utils/utils.py`](utils/utils.py). 120 | 121 | Supported methods: 122 | 123 | * Linear SVM: scikit method with grid search wrapper for regularization 124 | parameter. 125 | 126 | * Kernel SVM: scikit method with grid search wrapper for regularization 127 | parameter. 128 | 129 | * Logistic Regression: scikit method with grid search wrapper for 130 | regularization parameter. 131 | 132 | * Small CNN: 4-layer CNN optimized using rmsprop implemented in Keras with 133 | tensorflow backend. 134 | 135 | * Kernel Least Squares Classification: block gradient descent solver that can 136 | use multiple cores so is often faster than scikit Kernel SVM. 137 | 138 | ### Adding new models 139 | 140 | New models must follow the scikit-learn API and implement the following methods: 141 | 142 | * `fit(X, y[, sample_weight])`: fit the model to the input features and 143 | target. 144 | 145 | * `predict(X)`: predict the value of the input features. 146 | 147 | * `score(X, y)`: returns target metric given test features and test targets. 148 | 149 | * `decision_function(X)` (optional): return class probabilities, distance to 150 | decision boundaries, or other metric that can be used by margin sampler as a 151 | measure of uncertainty. 152 | 153 | See 154 | [`small_cnn.py`](utils/small_cnn.py) 155 | for an example. 156 | 157 | After implementing your new model, be sure to add it to the `get_model` method of 158 | [`utils/utils.py`](utils/utils.py). 159 | 160 | Currently models must be added on a one-off basis and not all scikit-learn 161 | classifiers are supported due to the need for user input on whether and how to 162 | tune the hyperparameters of the model. However, it is very easy to add a 163 | scikit-learn model with hyperparameter search wrapped around as a supported 164 | model. 165 | 166 | ## Collecting results and charting 167 | 168 | The 169 | [`utils/chart_data.py`](utils/chart_data.py) 170 | script handles processing of data and charting for a specified dataset and 171 | source directory. 172 | -------------------------------------------------------------------------------- /src/baselines/active-learning/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | -------------------------------------------------------------------------------- /src/baselines/active-learning/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | numpy>=1.13 3 | scipy>=0.19 4 | pandas>=0.20 5 | scikit-learn>=0.19 6 | matplotlib>=2.0.2 7 | 8 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/bandit_discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Bandit wrapper around base AL sampling methods. 16 | 17 | Assumes adversarial multi-armed bandit setting where arms correspond to 18 | mixtures of different AL methods. 19 | 20 | Uses EXP3 algorithm to decide which AL method to use to create the next batch. 21 | Similar to Hsu & Lin 2015, Active Learning by Learning. 22 | https://www.csie.ntu.edu.tw/~htlin/paper/doc/aaai15albl.pdf 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import numpy as np 30 | 31 | from sampling_methods.wrapper_sampler_def import AL_MAPPING, WrapperSamplingMethod 32 | 33 | 34 | class BanditDiscreteSampler(WrapperSamplingMethod): 35 | """Wraps EXP3 around mixtures of indicated methods. 36 | 37 | Uses EXP3 mult-armed bandit algorithm to select sampler methods. 38 | """ 39 | def __init__(self, 40 | X, 41 | y, 42 | seed, 43 | reward_function = lambda AL_acc: AL_acc[-1], 44 | gamma=0.5, 45 | samplers=[{'methods':('margin','uniform'),'weights':(0,1)}, 46 | {'methods':('margin','uniform'),'weights':(1,0)}]): 47 | """Initializes sampler with indicated gamma and arms. 48 | 49 | Args: 50 | X: training data 51 | y: labels, may need to be input into base samplers 52 | seed: seed to use for random sampling 53 | reward_function: reward based on previously observed accuracies. Assumes 54 | that the input is a sequence of observed accuracies. 
Will ultimately be 55 | a class method and may need access to other class properties. 56 | gamma: weight on uniform mixture. Arm probability updates are a weighted 57 | mixture of uniform and an exponentially weighted distribution. 58 | Lower gamma more aggressively updates based on observed rewards. 59 | samplers: list of dicts with two fields 60 | 'samplers': list of named samplers 61 | 'weights': percentage of batch to allocate to each sampler 62 | """ 63 | 64 | self.name = 'bandit_discrete' 65 | np.random.seed(seed) 66 | self.X = X 67 | self.y = y 68 | self.seed = seed 69 | self.initialize_samplers(samplers) 70 | 71 | self.gamma = gamma 72 | self.n_arms = len(samplers) 73 | self.reward_function = reward_function 74 | 75 | self.pull_history = [] 76 | self.acc_history = [] 77 | self.w = np.ones(self.n_arms) 78 | self.x = np.zeros(self.n_arms) 79 | self.p = self.w / (1.0 * self.n_arms) 80 | self.probs = [] 81 | 82 | def update_vars(self, arm_pulled): 83 | reward = self.reward_function(self.acc_history) 84 | self.x = np.zeros(self.n_arms) 85 | self.x[arm_pulled] = reward / self.p[arm_pulled] 86 | self.w = self.w * np.exp(self.gamma * self.x / self.n_arms) 87 | self.p = ((1.0 - self.gamma) * self.w / sum(self.w) 88 | + self.gamma / self.n_arms) 89 | print(self.p) 90 | self.probs.append(self.p) 91 | 92 | def select_batch_(self, already_selected, N, eval_acc, **kwargs): 93 | """Returns batch of datapoints sampled using mixture of AL_methods. 94 | 95 | Assumes that data has already been shuffled. 96 | 97 | Args: 98 | already_selected: index of datapoints already selected 99 | N: batch size 100 | eval_acc: accuracy of model trained after incorporating datapoints from 101 | last recommended batch 102 | 103 | Returns: 104 | indices of points selected to label 105 | """ 106 | # Update observed reward and arm probabilities 107 | self.acc_history.append(eval_acc) 108 | if len(self.pull_history) > 0: 109 | self.update_vars(self.pull_history[-1]) 110 | # Sample an arm 111 | arm = np.random.choice(range(self.n_arms), p=self.p) 112 | self.pull_history.append(arm) 113 | kwargs['N'] = N 114 | kwargs['already_selected'] = already_selected 115 | sample = self.samplers[arm].select_batch(**kwargs) 116 | return sample 117 | 118 | def to_dict(self): 119 | output = {} 120 | output['samplers'] = self.base_samplers 121 | output['arm_probs'] = self.probs 122 | output['pull_history'] = self.pull_history 123 | output['rewards'] = self.acc_history 124 | return output 125 | 126 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Controls imports to fill up dictionary of different sampling methods. 
16 | """ 17 | 18 | from functools import partial 19 | AL_MAPPING = {} 20 | 21 | 22 | def get_base_AL_mapping(): 23 | from sampling_methods.margin_AL import MarginAL 24 | from sampling_methods.informative_diverse import InformativeClusterDiverseSampler 25 | from sampling_methods.hierarchical_clustering_AL import HierarchicalClusterAL 26 | from sampling_methods.uniform_sampling import UniformSampling 27 | from sampling_methods.represent_cluster_centers import RepresentativeClusterMeanSampling 28 | from sampling_methods.graph_density import GraphDensitySampler 29 | from sampling_methods.kcenter_greedy import kCenterGreedy 30 | AL_MAPPING['margin'] = MarginAL 31 | AL_MAPPING['informative_diverse'] = InformativeClusterDiverseSampler 32 | AL_MAPPING['hierarchical'] = HierarchicalClusterAL 33 | AL_MAPPING['uniform'] = UniformSampling 34 | AL_MAPPING['margin_cluster_mean'] = RepresentativeClusterMeanSampling 35 | AL_MAPPING['graph_density'] = GraphDensitySampler 36 | AL_MAPPING['kcenter'] = kCenterGreedy 37 | 38 | 39 | def get_all_possible_arms(): 40 | from sampling_methods.mixture_of_samplers import MixtureOfSamplers 41 | AL_MAPPING['mixture_of_samplers'] = MixtureOfSamplers 42 | 43 | 44 | def get_wrapper_AL_mapping(): 45 | from sampling_methods.bandit_discrete import BanditDiscreteSampler 46 | from sampling_methods.simulate_batch import SimulateBatchSampler 47 | AL_MAPPING['bandit_mixture'] = partial( 48 | BanditDiscreteSampler, 49 | samplers=[{ 50 | 'methods': ['margin', 'uniform'], 51 | 'weights': [0, 1] 52 | }, { 53 | 'methods': ['margin', 'uniform'], 54 | 'weights': [0.25, 0.75] 55 | }, { 56 | 'methods': ['margin', 'uniform'], 57 | 'weights': [0.5, 0.5] 58 | }, { 59 | 'methods': ['margin', 'uniform'], 60 | 'weights': [0.75, 0.25] 61 | }, { 62 | 'methods': ['margin', 'uniform'], 63 | 'weights': [1, 0] 64 | }]) 65 | AL_MAPPING['bandit_discrete'] = partial( 66 | BanditDiscreteSampler, 67 | samplers=[{ 68 | 'methods': ['margin', 'uniform'], 69 | 'weights': [0, 1] 70 | }, { 71 | 'methods': ['margin', 'uniform'], 72 | 'weights': [1, 0] 73 | }]) 74 | AL_MAPPING['simulate_batch_mixture'] = partial( 75 | SimulateBatchSampler, 76 | samplers=({ 77 | 'methods': ['margin', 'uniform'], 78 | 'weights': [1, 0] 79 | }, { 80 | 'methods': ['margin', 'uniform'], 81 | 'weights': [0.5, 0.5] 82 | }, { 83 | 'methods': ['margin', 'uniform'], 84 | 'weights': [0, 1] 85 | }), 86 | n_sims=5, 87 | train_per_sim=10, 88 | return_best_sim=False) 89 | AL_MAPPING['simulate_batch_best_sim'] = partial( 90 | SimulateBatchSampler, 91 | samplers=[{ 92 | 'methods': ['margin', 'uniform'], 93 | 'weights': [1, 0] 94 | }], 95 | n_sims=10, 96 | train_per_sim=10, 97 | return_type='best_sim') 98 | AL_MAPPING['simulate_batch_frequency'] = partial( 99 | SimulateBatchSampler, 100 | samplers=[{ 101 | 'methods': ['margin', 'uniform'], 102 | 'weights': [1, 0] 103 | }], 104 | n_sims=10, 105 | train_per_sim=10, 106 | return_type='frequency') 107 | 108 | def get_mixture_of_samplers(name): 109 | assert 'mixture_of_samplers' in name 110 | if 'mixture_of_samplers' not in AL_MAPPING: 111 | raise KeyError('Mixture of Samplers not yet loaded.') 112 | args = name.split('-')[1:] 113 | samplers = args[0::2] 114 | weights = args[1::2] 115 | weights = [float(w) for w in weights] 116 | assert sum(weights) == 1 117 | mixture = {'methods': samplers, 'weights': weights} 118 | print(mixture) 119 | return partial(AL_MAPPING['mixture_of_samplers'], mixture=mixture) 120 | 121 | 122 | def get_AL_sampler(name): 123 | if name in AL_MAPPING and name != 
'mixture_of_samplers': 124 | return AL_MAPPING[name] 125 | if 'mixture_of_samplers' in name: 126 | return get_mixture_of_samplers(name) 127 | raise NotImplementedError('The specified sampler is not available.') 128 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/graph_density.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Diversity promoting sampling method that uses graph density to determine 16 | most representative points. 17 | 18 | This is an implementation of the method described in 19 | https://www.mpi-inf.mpg.de/fileadmin/inf/d2/Research_projects_files/EbertCVPR2012.pdf 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import copy 27 | 28 | from sklearn.neighbors import kneighbors_graph 29 | from sklearn.metrics import pairwise_distances 30 | import numpy as np 31 | from sampling_methods.sampling_def import SamplingMethod 32 | 33 | 34 | class GraphDensitySampler(SamplingMethod): 35 | """Diversity promoting sampling method that uses graph density to determine 36 | most representative points. 37 | """ 38 | 39 | def __init__(self, X, y, seed): 40 | self.name = 'graph_density' 41 | self.X = X 42 | self.flat_X = self.flatten_X() 43 | # Set gamma for gaussian kernel to be equal to 1/n_features 44 | self.gamma = 1. / self.X.shape[1] 45 | self.compute_graph_density() 46 | 47 | def compute_graph_density(self, n_neighbor=10): 48 | # kneighbors graph is constructed using k=10 49 | connect = kneighbors_graph(self.flat_X, n_neighbor,p=1) 50 | # Make connectivity matrix symmetric, if a point is a k nearest neighbor of 51 | # another point, make it vice versa 52 | neighbors = connect.nonzero() 53 | inds = zip(neighbors[0],neighbors[1]) 54 | # Graph edges are weighted by applying gaussian kernel to manhattan dist. 55 | # By default, gamma for rbf kernel is equal to 1/n_features but may 56 | # get better results if gamma is tuned. 57 | for entry in inds: 58 | i = entry[0] 59 | j = entry[1] 60 | distance = pairwise_distances(self.flat_X[[i]],self.flat_X[[j]],metric='manhattan') 61 | distance = distance[0,0] 62 | weight = np.exp(-distance * self.gamma) 63 | connect[i,j] = weight 64 | connect[j,i] = weight 65 | self.connect = connect 66 | # Define graph density for an observation to be sum of weights for all 67 | # edges to the node representing the datapoint. Normalize sum weights 68 | # by total number of neighbors. 
69 | self.graph_density = np.zeros(self.X.shape[0]) 70 | for i in np.arange(self.X.shape[0]): 71 | self.graph_density[i] = connect[i,:].sum() / (connect[i,:]>0).sum() 72 | self.starting_density = copy.deepcopy(self.graph_density) 73 | 74 | def select_batch_(self, N, already_selected, **kwargs): 75 | # If a neighbor has already been sampled, reduce the graph density 76 | # for its direct neighbors to promote diversity. 77 | batch = set() 78 | self.graph_density[already_selected] = min(self.graph_density) - 1 79 | while len(batch) < N: 80 | selected = np.argmax(self.graph_density) 81 | neighbors = (self.connect[selected,:] > 0).nonzero()[1] 82 | self.graph_density[neighbors] = self.graph_density[neighbors] - self.graph_density[selected] 83 | batch.add(selected) 84 | self.graph_density[already_selected] = min(self.graph_density) - 1 85 | self.graph_density[list(batch)] = min(self.graph_density) - 1 86 | return list(batch) 87 | 88 | def to_dict(self): 89 | output = {} 90 | output['connectivity'] = self.connect 91 | output['graph_density'] = self.starting_density 92 | return output -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/informative_diverse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Informative and diverse batch sampler that samples points with small margin 16 | while maintaining same distribution over clusters as entire training data. 17 | 18 | Batch is created by sorting datapoints by increasing margin and then growing 19 | the batch greedily. A point is added to the batch if the result batch still 20 | respects the constraint that the cluster distribution of the batch will 21 | match the cluster distribution of the entire training set. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | from sklearn.cluster import MiniBatchKMeans 29 | import numpy as np 30 | from sampling_methods.sampling_def import SamplingMethod 31 | 32 | 33 | class InformativeClusterDiverseSampler(SamplingMethod): 34 | """Selects batch based on informative and diverse criteria. 35 | 36 | Returns highest uncertainty lowest margin points while maintaining 37 | same distribution over clusters as entire dataset. 
38 | """ 39 | 40 | def __init__(self, X, y, seed): 41 | self.name = 'informative_and_diverse' 42 | self.X = X 43 | self.flat_X = self.flatten_X() 44 | # y only used for determining how many clusters there should be 45 | # probably not practical to assume we know # of classes before hand 46 | # should also probably scale with dimensionality of data 47 | self.y = y 48 | self.n_clusters = len(list(set(y))) 49 | self.cluster_model = MiniBatchKMeans(n_clusters=self.n_clusters) 50 | self.cluster_data() 51 | 52 | def cluster_data(self): 53 | # Probably okay to always use MiniBatchKMeans 54 | # Should standardize data before clustering 55 | # Can cluster on standardized data but train on raw features if desired 56 | self.cluster_model.fit(self.flat_X) 57 | unique, counts = np.unique(self.cluster_model.labels_, return_counts=True) 58 | self.cluster_prob = counts/sum(counts) 59 | self.cluster_labels = self.cluster_model.labels_ 60 | 61 | def select_batch_(self, model, already_selected, N, **kwargs): 62 | """Returns a batch of size N using informative and diverse selection. 63 | 64 | Args: 65 | model: scikit learn model with decision_function implemented 66 | already_selected: index of datapoints already selected 67 | N: batch size 68 | 69 | Returns: 70 | indices of points selected to add using margin active learner 71 | """ 72 | # TODO(lishal): have MarginSampler and this share margin function 73 | try: 74 | distances = model.decision_function(self.X) 75 | except: 76 | distances = model.predict_proba(self.X) 77 | if len(distances.shape) < 2: 78 | min_margin = abs(distances) 79 | else: 80 | sort_distances = np.sort(distances, 1)[:, -2:] 81 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 82 | rank_ind = np.argsort(min_margin) 83 | rank_ind = [i for i in rank_ind if i not in already_selected] 84 | new_batch_cluster_counts = [0 for _ in range(self.n_clusters)] 85 | new_batch = [] 86 | for i in rank_ind: 87 | if len(new_batch) == N: 88 | break 89 | label = self.cluster_labels[i] 90 | if new_batch_cluster_counts[label] / N < self.cluster_prob[label]: 91 | new_batch.append(i) 92 | new_batch_cluster_counts[label] += 1 93 | n_slot_remaining = N - len(new_batch) 94 | batch_filler = list(set(rank_ind) - set(already_selected) - set(new_batch)) 95 | new_batch.extend(batch_filler[0:n_slot_remaining]) 96 | return new_batch 97 | 98 | def to_dict(self): 99 | output = {} 100 | output['cluster_membership'] = self.cluster_labels 101 | return output 102 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/kcenter_greedy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Returns points that minimizes the maximum distance of any point to a center. 16 | 17 | Implements the k-Center-Greedy method in 18 | Ozan Sener and Silvio Savarese. 
A Geometric Approach to Active Learning for 19 | Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017 20 | 21 | Distance metric defaults to l2 distance. Features used to calculate distance 22 | are either raw features or if a model has transform method then uses the output 23 | of model.transform(X). 24 | 25 | Can be extended to a robust k centers algorithm that ignores a certain number of 26 | outlier datapoints. Resulting centers are solution to multiple integer program. 27 | """ 28 | 29 | from __future__ import absolute_import 30 | from __future__ import division 31 | from __future__ import print_function 32 | 33 | import numpy as np 34 | from sklearn.metrics import pairwise_distances 35 | from sampling_methods.sampling_def import SamplingMethod 36 | 37 | 38 | class kCenterGreedy(SamplingMethod): 39 | 40 | def __init__(self, X, y, seed, metric='euclidean'): 41 | self.X = X 42 | self.y = y 43 | self.flat_X = self.flatten_X() 44 | self.name = 'kcenter' 45 | self.features = self.flat_X 46 | self.metric = metric 47 | self.min_distances = None 48 | self.n_obs = self.X.shape[0] 49 | self.already_selected = [] 50 | 51 | def update_distances(self, cluster_centers, only_new=True, reset_dist=False): 52 | """Update min distances given cluster centers. 53 | 54 | Args: 55 | cluster_centers: indices of cluster centers 56 | only_new: only calculate distance for newly selected points and update 57 | min_distances. 58 | rest_dist: whether to reset min_distances. 59 | """ 60 | 61 | if reset_dist: 62 | self.min_distances = None 63 | if only_new: 64 | cluster_centers = [d for d in cluster_centers 65 | if d not in self.already_selected] 66 | if cluster_centers: 67 | # Update min_distances for all examples given new cluster center. 68 | x = self.features[cluster_centers] 69 | dist = pairwise_distances(self.features, x, metric=self.metric) 70 | 71 | if self.min_distances is None: 72 | self.min_distances = np.min(dist, axis=1).reshape(-1,1) 73 | else: 74 | self.min_distances = np.minimum(self.min_distances, dist) 75 | 76 | def select_batch_(self, model, already_selected, N, **kwargs): 77 | """ 78 | Diversity promoting active learning method that greedily forms a batch 79 | to minimize the maximum distance to a cluster center among all unlabeled 80 | datapoints. 81 | 82 | Args: 83 | model: model with scikit-like API with decision_function implemented 84 | already_selected: index of datapoints already selected 85 | N: batch size 86 | 87 | Returns: 88 | indices of points selected to minimize distance to cluster centers 89 | """ 90 | 91 | try: 92 | # Assumes that the transform function takes in original data and not 93 | # flattened data. 94 | print('Getting transformed features...') 95 | self.features = model.transform(self.X) 96 | print('Calculating distances...') 97 | self.update_distances(already_selected, only_new=False, reset_dist=True) 98 | except: 99 | print('Using flat_X as features.') 100 | self.update_distances(already_selected, only_new=True, reset_dist=False) 101 | 102 | new_batch = [] 103 | 104 | for _ in range(N): 105 | if self.already_selected is None: 106 | # Initialize centers with a randomly selected datapoint 107 | ind = np.random.choice(np.arange(self.n_obs)) 108 | else: 109 | ind = np.argmax(self.min_distances) 110 | # New examples should not be in already selected since those points 111 | # should have min_distance of zero to a cluster center. 
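        # Greedy farthest-first step: pick the point whose distance to its
        # nearest already-chosen center is largest (the classic 2-approximation
        # heuristic for k-center). Once update_distances() below registers it
        # as a center, its own min_distance drops to zero, so it cannot be
        # chosen again.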
112 | assert ind not in already_selected 113 | 114 | self.update_distances([ind], only_new=True, reset_dist=False) 115 | new_batch.append(ind) 116 | print('Maximum distance from cluster centers is %0.2f' 117 | % max(self.min_distances)) 118 | 119 | 120 | self.already_selected = already_selected 121 | 122 | return new_batch 123 | 124 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/margin_AL.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Margin based AL method. 16 | 17 | Samples in batches based on margin scores. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import numpy as np 25 | from sampling_methods.sampling_def import SamplingMethod 26 | 27 | 28 | class MarginAL(SamplingMethod): 29 | def __init__(self, X, y, seed): 30 | self.X = X 31 | self.y = y 32 | self.name = 'margin' 33 | 34 | def select_batch_(self, model, already_selected, N, **kwargs): 35 | """Returns batch of datapoints with smallest margin/highest uncertainty. 36 | 37 | For binary classification, can just take the absolute distance to decision 38 | boundary for each point. 39 | For multiclass classification, must consider the margin between distance for 40 | top two most likely classes. 41 | 42 | Args: 43 | model: scikit learn model with decision_function implemented 44 | already_selected: index of datapoints already selected 45 | N: batch size 46 | 47 | Returns: 48 | indices of points selected to add using margin active learner 49 | """ 50 | 51 | try: 52 | distances = model.decision_function(self.X) 53 | except: 54 | distances = model.predict_proba(self.X) 55 | if len(distances.shape) < 2: 56 | min_margin = abs(distances) 57 | else: 58 | sort_distances = np.sort(distances, 1)[:, -2:] 59 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 60 | rank_ind = np.argsort(min_margin) 61 | rank_ind = [i for i in rank_ind if i not in already_selected] 62 | active_samples = rank_ind[0:N] 63 | return active_samples 64 | 65 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/mixture_of_samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Mixture of base sampling strategies 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import copy 24 | 25 | from sampling_methods.sampling_def import SamplingMethod 26 | from sampling_methods.constants import AL_MAPPING, get_base_AL_mapping 27 | 28 | get_base_AL_mapping() 29 | 30 | 31 | class MixtureOfSamplers(SamplingMethod): 32 |   """Samples according to a mixture of base sampling methods. 33 | 34 |   If duplicate points are selected by the mixed strategies when forming the batch, 35 |   then the remaining slots are divided according to mixture weights and 36 |   another partial batch is requested until the batch is full. 37 |   """ 38 |   def __init__(self, 39 |                X, 40 |                y, 41 |                seed, 42 |                mixture={'methods': ('margin', 'uniform'), 43 |                         'weights': (0.5, 0.5)}, 44 |                samplers=None): 45 |     self.X = X 46 |     self.y = y 47 |     self.name = 'mixture_of_samplers' 48 |     self.sampling_methods = mixture['methods'] 49 |     self.sampling_weights = dict(zip(mixture['methods'], mixture['weights'])) 50 |     self.seed = seed 51 |     # A list of initialized samplers is allowed as an input because 52 |     # for AL_methods that search over different mixtures, may want mixtures to 53 |     # have shared AL_methods so that initialization is only performed once for 54 |     # computation intensive methods like HierarchicalClusteringAL and 55 |     # states are shared between mixtures. 56 |     # If initialized samplers are not provided, initialize them ourselves. 57 |     if samplers is None: 58 |       self.samplers = {} 59 |       self.initialize(self.sampling_methods) 60 |     else: 61 |       self.samplers = samplers 62 |     self.history = [] 63 | 64 |   def initialize(self, samplers): 65 |     self.samplers = {} 66 |     for s in samplers: 67 |       self.samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed) 68 | 69 |   def select_batch_(self, already_selected, N, **kwargs): 70 |     """Returns batch of datapoints selected according to mixture weights.
71 | 72 | Args: 73 | already_included: index of datapoints already selected 74 | N: batch size 75 | 76 | Returns: 77 | indices of points selected to add using margin active learner 78 | """ 79 | kwargs['already_selected'] = copy.copy(already_selected) 80 | inds = set() 81 | self.selected_by_sampler = {} 82 | for s in self.sampling_methods: 83 | self.selected_by_sampler[s] = [] 84 | effective_N = 0 85 | while len(inds) < N: 86 | effective_N += N - len(inds) 87 | for s in self.sampling_methods: 88 | if len(inds) < N: 89 | batch_size = min(max(int(self.sampling_weights[s] * effective_N), 1), N) 90 | sampler = self.samplers[s] 91 | kwargs['N'] = batch_size 92 | s_inds = sampler.select_batch(**kwargs) 93 | for ind in s_inds: 94 | if ind not in self.selected_by_sampler[s]: 95 | self.selected_by_sampler[s].append(ind) 96 | s_inds = [d for d in s_inds if d not in inds] 97 | s_inds = s_inds[0 : min(len(s_inds), N-len(inds))] 98 | inds.update(s_inds) 99 | self.history.append(copy.deepcopy(self.selected_by_sampler)) 100 | return list(inds) 101 | 102 | def to_dict(self): 103 | output = {} 104 | output['history'] = self.history 105 | output['samplers'] = self.sampling_methods 106 | output['mixture_weights'] = self.sampling_weights 107 | for s in self.samplers: 108 | s_output = self.samplers[s].to_dict() 109 | output[s] = s_output 110 | return output 111 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/represent_cluster_centers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Another informative and diverse sampler that mirrors the algorithm described 16 | in Xu, et. al., Representative Sampling for Text Classification Using 17 | Support Vector Machines, 2003 18 | 19 | Batch is created by clustering points within the margin of the classifier and 20 | choosing points closest to the k centroids. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | from sklearn.cluster import MiniBatchKMeans 28 | import numpy as np 29 | from sampling_methods.sampling_def import SamplingMethod 30 | 31 | 32 | class RepresentativeClusterMeanSampling(SamplingMethod): 33 | """Selects batch based on informative and diverse criteria. 34 | 35 | Returns points within the margin of the classifier that are closest to the 36 | k-means centers of those points. 
37 | """ 38 | 39 | def __init__(self, X, y, seed): 40 | self.name = 'cluster_mean' 41 | self.X = X 42 | self.flat_X = self.flatten_X() 43 | self.y = y 44 | self.seed = seed 45 | 46 | def select_batch_(self, model, N, already_selected, **kwargs): 47 | # Probably okay to always use MiniBatchKMeans 48 | # Should standardize data before clustering 49 | # Can cluster on standardized data but train on raw features if desired 50 | try: 51 | distances = model.decision_function(self.X) 52 | except: 53 | distances = model.predict_proba(self.X) 54 | if len(distances.shape) < 2: 55 | min_margin = abs(distances) 56 | else: 57 | sort_distances = np.sort(distances, 1)[:, -2:] 58 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 59 | rank_ind = np.argsort(min_margin) 60 | rank_ind = [i for i in rank_ind if i not in already_selected] 61 | 62 | distances = abs(model.decision_function(self.X)) 63 | min_margin_by_class = np.min(abs(distances[already_selected]),axis=0) 64 | unlabeled_in_margin = np.array([i for i in range(len(self.y)) 65 | if i not in already_selected and 66 | any(distances[i] 2: 42 | flat_X = np.reshape(self.X, (shape[0],np.product(shape[1:]))) 43 | return flat_X 44 | 45 | 46 | @abc.abstractmethod 47 | def select_batch_(self): 48 | return 49 | 50 | def select_batch(self, **kwargs): 51 | return self.select_batch_(**kwargs) 52 | 53 | def to_dict(self): 54 | return None -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/uniform_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Uniform sampling method. 16 | 17 | Samples in batches. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import numpy as np 25 | 26 | from sampling_methods.sampling_def import SamplingMethod 27 | 28 | 29 | class UniformSampling(SamplingMethod): 30 | 31 | def __init__(self, X, y, seed): 32 | self.X = X 33 | self.y = y 34 | self.name = 'uniform' 35 | np.random.seed(seed) 36 | 37 | def select_batch_(self, already_selected, N, **kwargs): 38 | """Returns batch of randomly sampled datapoints. 39 | 40 | Assumes that data has already been shuffled. 41 | 42 | Args: 43 | already_selected: index of datapoints already selected 44 | N: batch size 45 | 46 | Returns: 47 | indices of points selected to label 48 | """ 49 | 50 | # This is uniform given the remaining pool but biased wrt the entire pool. 51 | sample = [i for i in range(self.X.shape[0]) if i not in already_selected] 52 | return sample[0:N] 53 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/utils/tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Node and Tree class to support hierarchical clustering AL method. 16 | 17 | Assumed to be binary tree. 18 | 19 | Node class is used to represent each node in a hierarchical clustering. 20 | Each node has certain properties that are used in the AL method. 21 | 22 | Tree class is used to traverse a hierarchical clustering. 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import copy 30 | 31 | 32 | class Node(object): 33 | """Node class for hierarchical clustering. 34 | 35 | Initialized with name and left right children. 36 | """ 37 | 38 | def __init__(self, name, left=None, right=None): 39 | self.name = name 40 | self.left = left 41 | self.right = right 42 | self.is_leaf = left is None and right is None 43 | self.parent = None 44 | # Fields for hierarchical clustering AL 45 | self.score = 1.0 46 | self.split = False 47 | self.best_label = None 48 | self.weight = None 49 | 50 | def set_parent(self, parent): 51 | self.parent = parent 52 | 53 | 54 | class Tree(object): 55 | """Tree object for traversing a binary tree. 56 | 57 | Most methods apply to trees in general with the exception of get_pruning 58 | which is specific to the hierarchical clustering AL method. 59 | """ 60 | 61 | def __init__(self, root, node_dict): 62 | """Initializes tree and creates all nodes in node_dict. 63 | 64 | Args: 65 | root: id of the root node 66 | node_dict: dictionary with node_id as keys and entries indicating 67 | left and right child of node respectively. 
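        For example, a complete binary tree with root 1 and leaves 4-7 would be
        described as below (essentially the fixture used later in tree_test.py;
        leaves map to a pair of None children):

          node_dict = {1: (2, 3), 2: (4, 5), 3: (6, 7),
                       4: (None, None), 5: (None, None),
                       6: (None, None), 7: (None, None)}
          tree = Tree(1, node_dict)
          tree.create_child_leaves_mapping([4, 5, 6, 7])  # needed before weights are used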
68 | """ 69 | self.node_dict = node_dict 70 | self.root = self.make_tree(root) 71 | self.nodes = {} 72 | self.leaves_mapping = {} 73 | self.fill_parents() 74 | self.n_leaves = None 75 | 76 | def print_tree(self, node, max_depth): 77 | """Helper function to print out tree for debugging.""" 78 | node_list = [node] 79 | output = "" 80 | level = 0 81 | while level < max_depth and len(node_list): 82 | children = set() 83 | for n in node_list: 84 | node = self.get_node(n) 85 | output += ("\t"*level+"node %d: score %.2f, weight %.2f" % 86 | (node.name, node.score, node.weight)+"\n") 87 | if node.left: 88 | children.add(node.left.name) 89 | if node.right: 90 | children.add(node.right.name) 91 | level += 1 92 | node_list = children 93 | return print(output) 94 | 95 | def make_tree(self, node_id): 96 | if node_id is not None: 97 | return Node(node_id, 98 | self.make_tree(self.node_dict[node_id][0]), 99 | self.make_tree(self.node_dict[node_id][1])) 100 | 101 | def fill_parents(self): 102 | # Setting parent and storing nodes in dict for fast access 103 | def rec(pointer, parent): 104 | if pointer is not None: 105 | self.nodes[pointer.name] = pointer 106 | pointer.set_parent(parent) 107 | rec(pointer.left, pointer) 108 | rec(pointer.right, pointer) 109 | rec(self.root, None) 110 | 111 | def get_node(self, node_id): 112 | return self.nodes[node_id] 113 | 114 | def get_ancestor(self, node): 115 | ancestors = [] 116 | if isinstance(node, int): 117 | node = self.get_node(node) 118 | while node.name != self.root.name: 119 | node = node.parent 120 | ancestors.append(node.name) 121 | return ancestors 122 | 123 | def fill_weights(self): 124 | for v in self.node_dict: 125 | node = self.get_node(v) 126 | node.weight = len(self.leaves_mapping[v]) / (1.0 * self.n_leaves) 127 | 128 | def create_child_leaves_mapping(self, leaves): 129 | """DP for creating child leaves mapping. 130 | 131 | Storing in dict to save recompute. 132 | """ 133 | self.n_leaves = len(leaves) 134 | for v in leaves: 135 | self.leaves_mapping[v] = [v] 136 | node_list = set([self.get_node(v).parent for v in leaves]) 137 | while node_list: 138 | to_fill = copy.copy(node_list) 139 | for v in node_list: 140 | if (v.left.name in self.leaves_mapping 141 | and v.right.name in self.leaves_mapping): 142 | to_fill.remove(v) 143 | self.leaves_mapping[v.name] = (self.leaves_mapping[v.left.name] + 144 | self.leaves_mapping[v.right.name]) 145 | if v.parent is not None: 146 | to_fill.add(v.parent) 147 | node_list = to_fill 148 | self.fill_weights() 149 | 150 | def get_child_leaves(self, node): 151 | return self.leaves_mapping[node] 152 | 153 | def get_pruning(self, node): 154 | if node.split: 155 | return self.get_pruning(node.left) + self.get_pruning(node.right) 156 | else: 157 | return [node.name] 158 | 159 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/utils/tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sampling_methods.utils.tree.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import unittest 22 | from sampling_methods.utils import tree 23 | 24 | 25 | class TreeTest(unittest.TestCase): 26 | 27 | def setUp(self): 28 | node_dict = { 29 | 1: (2, 3), 30 | 2: (4, 5), 31 | 3: (6, 7), 32 | 4: [None, None], 33 | 5: [None, None], 34 | 6: [None, None], 35 | 7: [None, None] 36 | } 37 | self.tree = tree.Tree(1, node_dict) 38 | self.tree.create_child_leaves_mapping([4, 5, 6, 7]) 39 | node = self.tree.get_node(1) 40 | node.split = True 41 | node = self.tree.get_node(2) 42 | node.split = True 43 | 44 | def assertNode(self, node, name, left, right): 45 | self.assertEqual(node.name, name) 46 | self.assertEqual(node.left.name, left) 47 | self.assertEqual(node.right.name, right) 48 | 49 | def testTreeRootSetCorrectly(self): 50 | self.assertNode(self.tree.root, 1, 2, 3) 51 | 52 | def testGetNode(self): 53 | node = self.tree.get_node(1) 54 | assert isinstance(node, tree.Node) 55 | self.assertEqual(node.name, 1) 56 | 57 | def testFillParent(self): 58 | node = self.tree.get_node(3) 59 | self.assertEqual(node.parent.name, 1) 60 | 61 | def testGetAncestors(self): 62 | ancestors = self.tree.get_ancestor(5) 63 | self.assertTrue(all([a in ancestors for a in [1, 2]])) 64 | 65 | def testChildLeaves(self): 66 | leaves = self.tree.get_child_leaves(3) 67 | self.assertTrue(all([c in leaves for c in [6, 7]])) 68 | 69 | def testFillWeights(self): 70 | node = self.tree.get_node(3) 71 | self.assertEqual(node.weight, 0.5) 72 | 73 | def testGetPruning(self): 74 | node = self.tree.get_node(1) 75 | pruning = self.tree.get_pruning(node) 76 | self.assertTrue(all([n in pruning for n in [3, 4, 5]])) 77 | 78 | if __name__ == '__main__': 79 | unittest.main() 80 | -------------------------------------------------------------------------------- /src/baselines/active-learning/sampling_methods/wrapper_sampler_def.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Abstract class for wrapper sampling methods that call base sampling methods. 16 | 17 | Provides interface to sampling methods that allow same signature 18 | for select_batch. Each subclass implements select_batch_ with the desired 19 | signature for readability. 
20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import abc 27 | 28 | from sampling_methods.constants import AL_MAPPING 29 | from sampling_methods.constants import get_all_possible_arms 30 | from sampling_methods.sampling_def import SamplingMethod 31 | 32 | get_all_possible_arms() 33 | 34 | 35 | class WrapperSamplingMethod(SamplingMethod): 36 | __metaclass__ = abc.ABCMeta 37 | 38 | def initialize_samplers(self, mixtures): 39 | methods = [] 40 | for m in mixtures: 41 | methods += m['methods'] 42 | methods = set(methods) 43 | self.base_samplers = {} 44 | for s in methods: 45 | self.base_samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed) 46 | self.samplers = [] 47 | for m in mixtures: 48 | self.samplers.append( 49 | AL_MAPPING['mixture_of_samplers'](self.X, self.y, self.seed, m, 50 | self.base_samplers)) 51 | -------------------------------------------------------------------------------- /src/baselines/active-learning/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /src/baselines/active-learning/utils/allconv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
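# Usage sketch for WrapperSamplingMethod.initialize_samplers (defined above in
# wrapper_sampler_def.py): a subclass passes a list of mixture dicts in the same
# format MixtureOfSamplers expects, where each method name is assumed to be a
# key of AL_MAPPING (e.g. 'margin', 'uniform'):
#
#   mixtures = [{'methods': ('margin', 'uniform'), 'weights': (0.5, 0.5)},
#               {'methods': ('margin', 'uniform'), 'weights': (0.7, 0.3)}]
#   self.initialize_samplers(mixtures)   # builds base samplers + one mixture sampler per dict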
14 | 15 | """Implements allconv model in keras using tensorflow backend.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import copy 21 | 22 | import keras 23 | import keras.backend as K 24 | from keras.layers import Activation 25 | from keras.layers import Conv2D 26 | from keras.layers import Dropout 27 | from keras.layers import GlobalAveragePooling2D 28 | from keras.models import Sequential 29 | 30 | import numpy as np 31 | import tensorflow as tf 32 | 33 | 34 | class AllConv(object): 35 | """allconv network that matches sklearn api.""" 36 | 37 | def __init__(self, 38 | random_state=1, 39 | epochs=50, 40 | batch_size=32, 41 | solver='rmsprop', 42 | learning_rate=0.001, 43 | lr_decay=0.): 44 | # params 45 | self.solver = solver 46 | self.epochs = epochs 47 | self.batch_size = batch_size 48 | self.learning_rate = learning_rate 49 | self.lr_decay = lr_decay 50 | # data 51 | self.encode_map = None 52 | self.decode_map = None 53 | self.model = None 54 | self.random_state = random_state 55 | self.n_classes = None 56 | 57 | def build_model(self, X): 58 | # assumes that data axis order is same as the backend 59 | input_shape = X.shape[1:] 60 | np.random.seed(self.random_state) 61 | tf.set_random_seed(self.random_state) 62 | 63 | model = Sequential() 64 | model.add(Conv2D(96, (3, 3), padding='same', 65 | input_shape=input_shape, name='conv1')) 66 | model.add(Activation('relu')) 67 | model.add(Conv2D(96, (3, 3), name='conv2', padding='same')) 68 | model.add(Activation('relu')) 69 | model.add(Conv2D(96, (3, 3), strides=(2, 2), padding='same', name='conv3')) 70 | model.add(Activation('relu')) 71 | model.add(Dropout(0.5)) 72 | 73 | model.add(Conv2D(192, (3, 3), name='conv4', padding='same')) 74 | model.add(Activation('relu')) 75 | model.add(Conv2D(192, (3, 3), name='conv5', padding='same')) 76 | model.add(Activation('relu')) 77 | model.add(Conv2D(192, (3, 3), strides=(2, 2), name='conv6', padding='same')) 78 | model.add(Activation('relu')) 79 | model.add(Dropout(0.5)) 80 | 81 | model.add(Conv2D(192, (3, 3), name='conv7', padding='same')) 82 | model.add(Activation('relu')) 83 | model.add(Conv2D(192, (1, 1), name='conv8', padding='valid')) 84 | model.add(Activation('relu')) 85 | model.add(Conv2D(10, (1, 1), name='conv9', padding='valid')) 86 | 87 | model.add(GlobalAveragePooling2D()) 88 | model.add(Activation('softmax', name='activation_top')) 89 | model.summary() 90 | 91 | try: 92 | optimizer = getattr(keras.optimizers, self.solver) 93 | except: 94 | raise NotImplementedError('optimizer not implemented in keras') 95 | # All optimizers with the exception of nadam take decay as named arg 96 | try: 97 | opt = optimizer(lr=self.learning_rate, decay=self.lr_decay) 98 | except: 99 | opt = optimizer(lr=self.learning_rate, schedule_decay=self.lr_decay) 100 | 101 | model.compile(loss='categorical_crossentropy', 102 | optimizer=opt, 103 | metrics=['accuracy']) 104 | # Save initial weights so that model can be retrained with same 105 | # initialization 106 | self.initial_weights = copy.deepcopy(model.get_weights()) 107 | 108 | self.model = model 109 | 110 | def create_y_mat(self, y): 111 | y_encode = self.encode_y(y) 112 | y_encode = np.reshape(y_encode, (len(y_encode), 1)) 113 | y_mat = keras.utils.to_categorical(y_encode, self.n_classes) 114 | return y_mat 115 | 116 | # Add handling for classes that do not start counting from 0 117 | def encode_y(self, y): 118 | if self.encode_map is None: 119 | self.classes_ = 
sorted(list(set(y))) 120 | self.n_classes = len(self.classes_) 121 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_)))) 122 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_)) 123 | mapper = lambda x: self.encode_map[x] 124 | transformed_y = np.array(map(mapper, y)) 125 | return transformed_y 126 | 127 | def decode_y(self, y): 128 | mapper = lambda x: self.decode_map[x] 129 | transformed_y = np.array(map(mapper, y)) 130 | return transformed_y 131 | 132 | def fit(self, X_train, y_train, sample_weight=None): 133 | y_mat = self.create_y_mat(y_train) 134 | 135 | if self.model is None: 136 | self.build_model(X_train) 137 | 138 | # We don't want incremental fit so reset learning rate and weights 139 | K.set_value(self.model.optimizer.lr, self.learning_rate) 140 | self.model.set_weights(self.initial_weights) 141 | self.model.fit( 142 | X_train, 143 | y_mat, 144 | batch_size=self.batch_size, 145 | epochs=self.epochs, 146 | shuffle=True, 147 | sample_weight=sample_weight, 148 | verbose=0) 149 | 150 | def predict(self, X_val): 151 | predicted = self.model.predict(X_val) 152 | return predicted 153 | 154 | def score(self, X_val, val_y): 155 | y_mat = self.create_y_mat(val_y) 156 | val_acc = self.model.evaluate(X_val, y_mat)[1] 157 | return val_acc 158 | 159 | def decision_function(self, X): 160 | return self.predict(X) 161 | 162 | def transform(self, X): 163 | model = self.model 164 | inp = [model.input] 165 | activations = [] 166 | 167 | # Get activations of the last conv layer. 168 | output = [layer.output for layer in model.layers if 169 | layer.name == 'conv9'][0] 170 | func = K.function(inp + [K.learning_phase()], [output]) 171 | for i in range(int(X.shape[0]/self.batch_size) + 1): 172 | minibatch = X[i * self.batch_size 173 | : min(X.shape[0], (i+1) * self.batch_size)] 174 | list_inputs = [minibatch, 0.] 175 | # Learning phase. 0 = Test mode (no dropout or batch normalization) 176 | layer_output = func(list_inputs)[0] 177 | activations.append(layer_output) 178 | output = np.vstack(tuple(activations)) 179 | output = np.reshape(output, (output.shape[0],np.product(output.shape[1:]))) 180 | return output 181 | 182 | def get_params(self, deep = False): 183 | params = {} 184 | params['solver'] = self.solver 185 | params['epochs'] = self.epochs 186 | params['batch_size'] = self.batch_size 187 | params['learning_rate'] = self.learning_rate 188 | params['weight_decay'] = self.lr_decay 189 | if deep: 190 | return copy.deepcopy(params) 191 | return copy.copy(params) 192 | 193 | def set_params(self, **parameters): 194 | for parameter, value in parameters.items(): 195 | setattr(self, parameter, value) 196 | return self 197 | -------------------------------------------------------------------------------- /src/baselines/active-learning/utils/chart_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
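# Usage sketch for the AllConv wrapper defined above in utils/allconv.py
# (assumes a Keras/TensorFlow 1.x environment; X_train/X_val are 4-d image
# tensors and y_train/y_val integer label vectors -- placeholder names):
#
#   net = AllConv(random_state=1, epochs=5, batch_size=32)
#   net.fit(X_train, y_train)        # builds the network on first call, then trains
#   acc = net.score(X_val, y_val)    # accuracy via keras evaluate
#   feats = net.transform(X_val)     # flattened activations of layer 'conv9'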
14 | 15 | """Experiment charting script. 16 | """ 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | import pickle 24 | 25 | import numpy as np 26 | import matplotlib.pyplot as plt 27 | from matplotlib.backends.backend_pdf import PdfPages 28 | 29 | from absl import app 30 | from absl import flags 31 | from tensorflow import gfile 32 | 33 | flags.DEFINE_string('source_dir', 34 | '/tmp/toy_experiments', 35 | 'Directory with the output to analyze.') 36 | flags.DEFINE_string('save_dir', '/tmp/active_learning', 37 | 'Directory to save charts.') 38 | flags.DEFINE_string('dataset', 'letter', 'Dataset to analyze.') 39 | flags.DEFINE_string( 40 | 'sampling_methods', 41 | ('uniform,margin,informative_diverse,' 42 | 'pred_expert_advice_trip_agg,' 43 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34'), 44 | 'Comma separated string of sampling methods to include in chart.') 45 | flags.DEFINE_string('scoring_methods', 'logistic,kernel_ls', 46 | 'Comma separated string of scoring methods to chart.') 47 | flags.DEFINE_bool('normalize', False, 'Chart runs using normalized data.') 48 | flags.DEFINE_bool('standardize', True, 'Chart runs using standardized data.') 49 | 50 | FLAGS = flags.FLAGS 51 | 52 | 53 | def combine_results(files, diff=False): 54 | all_results = {} 55 | for f in files: 56 | data = pickle.load(gfile.FastGFile(f, 'r')) 57 | for k in data: 58 | if isinstance(k, tuple): 59 | data[k].pop('noisy_targets') 60 | data[k].pop('indices') 61 | data[k].pop('selected_inds') 62 | data[k].pop('sampler_output') 63 | key = list(k) 64 | seed = key[-1] 65 | key = key[0:10] 66 | key = tuple(key) 67 | if key in all_results: 68 | if seed not in all_results[key]['random_seeds']: 69 | all_results[key]['random_seeds'].append(seed) 70 | for field in [f for f in data[k] if f != 'n_points']: 71 | all_results[key][field] = np.vstack( 72 | (all_results[key][field], data[k][field])) 73 | else: 74 | all_results[key] = data[k] 75 | all_results[key]['random_seeds'] = [seed] 76 | else: 77 | all_results[k] = data[k] 78 | return all_results 79 | 80 | 81 | def plot_results(all_results, score_method, norm, stand, sampler_filter): 82 | colors = { 83 | 'margin': 84 | 'gold', 85 | 'uniform': 86 | 'k', 87 | 'informative_diverse': 88 | 'r', 89 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34': 90 | 'b', 91 | 'pred_expert_advice_trip_agg': 92 | 'g' 93 | } 94 | labels = { 95 | 'margin': 96 | 'margin', 97 | 'uniform': 98 | 'uniform', 99 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34': 100 | 'margin:0.33,informative_diverse:0.33, uniform:0.34', 101 | 'informative_diverse': 102 | 'informative and diverse', 103 | 'pred_expert_advice_trip_agg': 104 | 'expert: margin,informative_diverse,uniform' 105 | } 106 | markers = { 107 | 'margin': 108 | 'None', 109 | 'uniform': 110 | 'None', 111 | 'mixture_of_samplers-margin-0.33-informative_diverse-0.33-uniform-0.34': 112 | '>', 113 | 'informative_diverse': 114 | 'None', 115 | 'pred_expert_advice_trip_agg': 116 | 'p' 117 | } 118 | fields = all_results['tuple_keys'] 119 | fields = dict(zip(fields, range(len(fields)))) 120 | 121 | for k in sorted(all_results.keys()): 122 | sampler = k[fields['sampler']] 123 | if (isinstance(k, tuple) and 124 | k[fields['score_method']] == score_method and 125 | k[fields['standardize']] == stand and 126 | k[fields['normalize']] == norm and 127 | (sampler_filter is None or sampler in 
sampler_filter)): 128 | results = all_results[k] 129 | n_trials = results['accuracy'].shape[0] 130 | x = results['data_sizes'][0] 131 | mean_acc = np.mean(results['accuracy'], axis=0) 132 | CI_acc = np.std(results['accuracy'], axis=0) / np.sqrt(n_trials) * 2.96 133 | if sampler == 'uniform': 134 | plt.plot( 135 | x, 136 | mean_acc, 137 | linewidth=1, 138 | label=labels[sampler], 139 | color=colors[sampler], 140 | linestyle='--' 141 | ) 142 | plt.fill_between( 143 | x, 144 | mean_acc - CI_acc, 145 | mean_acc + CI_acc, 146 | color=colors[sampler], 147 | alpha=0.2 148 | ) 149 | else: 150 | plt.plot( 151 | x, 152 | mean_acc, 153 | linewidth=1, 154 | label=labels[sampler], 155 | color=colors[sampler], 156 | marker=markers[sampler], 157 | markeredgecolor=colors[sampler] 158 | ) 159 | plt.fill_between( 160 | x, 161 | mean_acc - CI_acc, 162 | mean_acc + CI_acc, 163 | color=colors[sampler], 164 | alpha=0.2 165 | ) 166 | plt.legend(loc=4) 167 | 168 | 169 | def get_between(filename, start, end): 170 | start_ind = filename.find(start) + len(start) 171 | end_ind = filename.rfind(end) 172 | return filename[start_ind:end_ind] 173 | 174 | 175 | def get_sampling_method(dataset, filename): 176 | return get_between(filename, dataset + '_', '/') 177 | 178 | 179 | def get_scoring_method(filename): 180 | return get_between(filename, 'results_score_', '_select_') 181 | 182 | 183 | def get_normalize(filename): 184 | return get_between(filename, '_norm_', '_stand_') == 'True' 185 | 186 | 187 | def get_standardize(filename): 188 | return get_between( 189 | filename, '_stand_', filename[filename.rfind('_'):]) == 'True' 190 | 191 | 192 | def main(argv): 193 | del argv # Unused. 194 | if not gfile.Exists(FLAGS.save_dir): 195 | gfile.MkDir(FLAGS.save_dir) 196 | charting_filepath = os.path.join(FLAGS.save_dir, 197 | FLAGS.dataset + '_charts.pdf') 198 | sampling_methods = FLAGS.sampling_methods.split(',') 199 | scoring_methods = FLAGS.scoring_methods.split(',') 200 | files = gfile.Glob( 201 | os.path.join(FLAGS.source_dir, FLAGS.dataset + '*/results*.pkl')) 202 | files = [ 203 | f for f in files 204 | if (get_sampling_method(FLAGS.dataset, f) in sampling_methods and 205 | get_scoring_method(f) in scoring_methods and 206 | get_normalize(f) == FLAGS.normalize and 207 | get_standardize(f) == FLAGS.standardize) 208 | ] 209 | 210 | print('Reading in %d files...' % len(files)) 211 | all_results = combine_results(files) 212 | pdf = PdfPages(charting_filepath) 213 | 214 | print('Plotting charts...') 215 | plt.style.use('ggplot') 216 | for m in scoring_methods: 217 | plot_results( 218 | all_results, 219 | m, 220 | FLAGS.normalize, 221 | FLAGS.standardize, 222 | sampler_filter=sampling_methods) 223 | plt.title('Dataset: %s, Score Method: %s' % (FLAGS.dataset, m)) 224 | pdf.savefig() 225 | plt.close() 226 | pdf.close() 227 | 228 | 229 | if __name__ == '__main__': 230 | app.run(main) 231 | -------------------------------------------------------------------------------- /src/baselines/active-learning/utils/kernel_block_solver.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Block kernel lsqr solver for multi-class classification.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import copy 21 | import math 22 | 23 | import numpy as np 24 | import scipy.linalg as linalg 25 | from scipy.sparse.linalg import spsolve 26 | from sklearn import metrics 27 | 28 | 29 | class BlockKernelSolver(object): 30 | """Inspired by algorithm from https://arxiv.org/pdf/1602.05310.pdf.""" 31 | # TODO(lishal): save preformed kernel matrix and reuse if possible 32 | # perhaps not possible if want to keep scikitlearn signature 33 | 34 | def __init__(self, 35 | random_state=1, 36 | C=0.1, 37 | block_size=4000, 38 | epochs=3, 39 | verbose=False, 40 | gamma=None): 41 | self.block_size = block_size 42 | self.epochs = epochs 43 | self.C = C 44 | self.kernel = 'rbf' 45 | self.coef_ = None 46 | self.verbose = verbose 47 | self.encode_map = None 48 | self.decode_map = None 49 | self.gamma = gamma 50 | self.X_train = None 51 | self.random_state = random_state 52 | 53 | def encode_y(self, y): 54 | # Handles classes that do not start counting from 0. 55 | if self.encode_map is None: 56 | self.classes_ = sorted(list(set(y))) 57 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_)))) 58 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_)) 59 | mapper = lambda x: self.encode_map[x] 60 | transformed_y = np.array(map(mapper, y)) 61 | return transformed_y 62 | 63 | def decode_y(self, y): 64 | mapper = lambda x: self.decode_map[x] 65 | transformed_y = np.array(map(mapper, y)) 66 | return transformed_y 67 | 68 | def fit(self, X_train, y_train, sample_weight=None): 69 | """Form K and solve (K + lambda * I)x = y in a block-wise fashion.""" 70 | np.random.seed(self.random_state) 71 | self.X_train = X_train 72 | n_features = X_train.shape[1] 73 | y = self.encode_y(y_train) 74 | if self.gamma is None: 75 | self.gamma = 1./n_features 76 | K = metrics.pairwise.pairwise_kernels( 77 | X_train, metric=self.kernel, gamma=self.gamma) 78 | if self.verbose: 79 | print('Finished forming kernel matrix.') 80 | 81 | # compute some constants 82 | num_classes = len(list(set(y))) 83 | num_samples = K.shape[0] 84 | num_blocks = math.ceil(num_samples*1.0/self.block_size) 85 | x = np.zeros((K.shape[0], num_classes)) 86 | y_hat = np.zeros((K.shape[0], num_classes)) 87 | onehot = lambda x: np.eye(num_classes)[x] 88 | y_onehot = np.array(map(onehot, y)) 89 | idxes = np.diag_indices(num_samples) 90 | if sample_weight is not None: 91 | weights = np.sqrt(sample_weight) 92 | weights = weights[:, np.newaxis] 93 | y_onehot = weights * y_onehot 94 | K *= np.outer(weights, weights) 95 | if num_blocks == 1: 96 | epochs = 1 97 | else: 98 | epochs = self.epochs 99 | 100 | for e in range(epochs): 101 | shuffled_coords = np.random.choice( 102 | num_samples, num_samples, replace=False) 103 | for b in range(int(num_blocks)): 104 | residuals = y_onehot - y_hat 105 | 106 | # Form a block of K. 
107 | K[idxes] += (self.C * num_samples) 108 | block = shuffled_coords[b*self.block_size: 109 | min((b+1)*self.block_size, num_samples)] 110 | K_block = K[:, block] 111 | # Dim should be block size x block size 112 | KbTKb = K_block.T.dot(K_block) 113 | 114 | if self.verbose: 115 | print('solving block {0}'.format(b)) 116 | # Try linalg solve then sparse solve for handling of sparse input. 117 | try: 118 | x_block = linalg.solve(KbTKb, K_block.T.dot(residuals)) 119 | except: 120 | try: 121 | x_block = spsolve(KbTKb, K_block.T.dot(residuals)) 122 | except: 123 | return None 124 | 125 | # update model 126 | x[block] = x[block] + x_block 127 | K[idxes] = K[idxes] - (self.C * num_samples) 128 | y_hat = K.dot(x) 129 | 130 | y_pred = np.argmax(y_hat, axis=1) 131 | train_acc = metrics.accuracy_score(y, y_pred) 132 | if self.verbose: 133 | print('Epoch: {0}, Block: {1}, Train Accuracy: {2}' 134 | .format(e, b, train_acc)) 135 | self.coef_ = x 136 | 137 | def predict(self, X_val): 138 | val_K = metrics.pairwise.pairwise_kernels( 139 | X_val, self.X_train, metric=self.kernel, gamma=self.gamma) 140 | val_pred = np.argmax(val_K.dot(self.coef_), axis=1) 141 | return self.decode_y(val_pred) 142 | 143 | def score(self, X_val, val_y): 144 | val_pred = self.predict(X_val) 145 | val_acc = metrics.accuracy_score(val_y, val_pred) 146 | return val_acc 147 | 148 | def decision_function(self, X, type='predicted'): 149 | # Return the predicted value of the best class 150 | # Margin_AL will see that a vector is returned and not a matrix and 151 | # simply select the points that have the lowest predicted value to label 152 | K = metrics.pairwise.pairwise_kernels( 153 | X, self.X_train, metric=self.kernel, gamma=self.gamma) 154 | predicted = K.dot(self.coef_) 155 | if type == 'scores': 156 | val_best = np.max(K.dot(self.coef_), axis=1) 157 | return val_best 158 | elif type == 'predicted': 159 | return predicted 160 | else: 161 | raise NotImplementedError('Invalid return type for decision function.') 162 | 163 | def get_params(self, deep=False): 164 | params = {} 165 | params['C'] = self.C 166 | params['gamma'] = self.gamma 167 | if deep: 168 | return copy.deepcopy(params) 169 | return copy.copy(params) 170 | 171 | def set_params(self, **parameters): 172 | for parameter, value in parameters.items(): 173 | setattr(self, parameter, value) 174 | return self 175 | 176 | def softmax_over_predicted(self, X): 177 | val_K = metrics.pairwise.pairwise_kernels( 178 | X, self.X_train, metric=self.kernel, gamma=self.gamma) 179 | val_pred = val_K.dot(self.coef_) 180 | row_min = np.min(val_pred, axis=1) 181 | val_pred = val_pred - row_min[:, None] 182 | val_pred = np.exp(val_pred) 183 | sum_exp = np.sum(val_pred, axis=1) 184 | val_pred = val_pred/sum_exp[:, None] 185 | return val_pred 186 | -------------------------------------------------------------------------------- /src/baselines/active-learning/utils/small_cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Implements Small CNN model in keras using tensorflow backend.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import copy 21 | 22 | import keras 23 | import keras.backend as K 24 | from keras.layers import Activation 25 | from keras.layers import Conv2D 26 | from keras.layers import Dense 27 | from keras.layers import Dropout 28 | from keras.layers import Flatten 29 | from keras.layers import MaxPooling2D 30 | from keras.models import Sequential 31 | 32 | import numpy as np 33 | import tensorflow as tf 34 | 35 | 36 | class SmallCNN(object): 37 | """Small convnet that matches sklearn api. 38 | 39 | Implements model from 40 | https://github.com/fchollet/keras/blob/master/examples/cifar10_cnn.py 41 | Adapts for inputs of variable size, expects data to be 4d tensor, with 42 | # of obserations as first dimension and other dimensions to correspond to 43 | length width and # of channels in image. 44 | """ 45 | 46 | def __init__(self, 47 | random_state=1, 48 | epochs=50, 49 | batch_size=32, 50 | solver='rmsprop', 51 | learning_rate=0.001, 52 | lr_decay=0.): 53 | # params 54 | self.solver = solver 55 | self.epochs = epochs 56 | self.batch_size = batch_size 57 | self.learning_rate = learning_rate 58 | self.lr_decay = lr_decay 59 | # data 60 | self.encode_map = None 61 | self.decode_map = None 62 | self.model = None 63 | self.random_state = random_state 64 | self.n_classes = None 65 | 66 | def build_model(self, X): 67 | # assumes that data axis order is same as the backend 68 | input_shape = X.shape[1:] 69 | np.random.seed(self.random_state) 70 | tf.set_random_seed(self.random_state) 71 | 72 | model = Sequential() 73 | model.add(Conv2D(32, (3, 3), padding='same', 74 | input_shape=input_shape, name='conv1')) 75 | model.add(Activation('relu')) 76 | model.add(Conv2D(32, (3, 3), name='conv2')) 77 | model.add(Activation('relu')) 78 | model.add(MaxPooling2D(pool_size=(2, 2))) 79 | model.add(Dropout(0.25)) 80 | 81 | model.add(Conv2D(64, (3, 3), padding='same', name='conv3')) 82 | model.add(Activation('relu')) 83 | model.add(Conv2D(64, (3, 3), name='conv4')) 84 | model.add(Activation('relu')) 85 | model.add(MaxPooling2D(pool_size=(2, 2))) 86 | model.add(Dropout(0.25)) 87 | 88 | model.add(Flatten()) 89 | model.add(Dense(512, name='dense1')) 90 | model.add(Activation('relu')) 91 | model.add(Dropout(0.5)) 92 | model.add(Dense(self.n_classes, name='dense2')) 93 | model.add(Activation('softmax')) 94 | 95 | try: 96 | optimizer = getattr(keras.optimizers, self.solver) 97 | except: 98 | raise NotImplementedError('optimizer not implemented in keras') 99 | # All optimizers with the exception of nadam take decay as named arg 100 | try: 101 | opt = optimizer(lr=self.learning_rate, decay=self.lr_decay) 102 | except: 103 | opt = optimizer(lr=self.learning_rate, schedule_decay=self.lr_decay) 104 | 105 | model.compile(loss='categorical_crossentropy', 106 | optimizer=opt, 107 | metrics=['accuracy']) 108 | # Save initial weights so that model can be retrained with same 109 | # initialization 110 | self.initial_weights = copy.deepcopy(model.get_weights()) 111 | 112 | self.model = model 113 | 114 | def create_y_mat(self, y): 115 | y_encode = self.encode_y(y) 116 | y_encode = np.reshape(y_encode, (len(y_encode), 1)) 117 | y_mat = keras.utils.to_categorical(y_encode, self.n_classes) 118 | return y_mat 119 | 120 | # Add 
handling for classes that do not start counting from 0 121 | def encode_y(self, y): 122 | if self.encode_map is None: 123 | self.classes_ = sorted(list(set(y))) 124 | self.n_classes = len(self.classes_) 125 | self.encode_map = dict(zip(self.classes_, range(len(self.classes_)))) 126 | self.decode_map = dict(zip(range(len(self.classes_)), self.classes_)) 127 | mapper = lambda x: self.encode_map[x] 128 | transformed_y = np.array(map(mapper, y)) 129 | return transformed_y 130 | 131 | def decode_y(self, y): 132 | mapper = lambda x: self.decode_map[x] 133 | transformed_y = np.array(map(mapper, y)) 134 | return transformed_y 135 | 136 | def fit(self, X_train, y_train, sample_weight=None): 137 | y_mat = self.create_y_mat(y_train) 138 | 139 | if self.model is None: 140 | self.build_model(X_train) 141 | 142 | # We don't want incremental fit so reset learning rate and weights 143 | K.set_value(self.model.optimizer.lr, self.learning_rate) 144 | self.model.set_weights(self.initial_weights) 145 | self.model.fit( 146 | X_train, 147 | y_mat, 148 | batch_size=self.batch_size, 149 | epochs=self.epochs, 150 | shuffle=True, 151 | sample_weight=sample_weight, 152 | verbose=0) 153 | 154 | def predict(self, X_val): 155 | predicted = self.model.predict(X_val) 156 | return predicted 157 | 158 | def score(self, X_val, val_y): 159 | y_mat = self.create_y_mat(val_y) 160 | val_acc = self.model.evaluate(X_val, y_mat)[1] 161 | return val_acc 162 | 163 | def decision_function(self, X): 164 | return self.predict(X) 165 | 166 | def transform(self, X): 167 | model = self.model 168 | inp = [model.input] 169 | activations = [] 170 | 171 | # Get activations of the first dense layer. 172 | output = [layer.output for layer in model.layers if 173 | layer.name == 'dense1'][0] 174 | func = K.function(inp + [K.learning_phase()], [output]) 175 | for i in range(int(X.shape[0]/self.batch_size) + 1): 176 | minibatch = X[i * self.batch_size 177 | : min(X.shape[0], (i+1) * self.batch_size)] 178 | list_inputs = [minibatch, 0.] 179 | # Learning phase. 
0 = Test mode (no dropout or batch normalization) 180 | layer_output = func(list_inputs)[0] 181 | activations.append(layer_output) 182 | output = np.vstack(tuple(activations)) 183 | return output 184 | 185 | def get_params(self, deep = False): 186 | params = {} 187 | params['solver'] = self.solver 188 | params['epochs'] = self.epochs 189 | params['batch_size'] = self.batch_size 190 | params['learning_rate'] = self.learning_rate 191 | params['weight_decay'] = self.lr_decay 192 | if deep: 193 | return copy.deepcopy(params) 194 | return copy.copy(params) 195 | 196 | def set_params(self, **parameters): 197 | for parameter, value in parameters.items(): 198 | setattr(self, parameter, value) 199 | return self 200 | -------------------------------------------------------------------------------- /src/baselines/anrmab.py: -------------------------------------------------------------------------------- 1 | # from src.baselines.sampling_methods.bandit_discrete import BanditDiscreteSampler 2 | import numpy as np 3 | import scipy as sc 4 | from sklearn.cluster import KMeans 5 | from sklearn.metrics.pairwise import euclidean_distances 6 | import networkx as nx 7 | def centralissimo(G): 8 | centralities = [] 9 | centralities.append(nx.pagerank(G)) #centralities.append(nx.harmonic_centrality(G)) 10 | L = len(centralities[0]) 11 | Nc = len(centralities) 12 | cenarray = np.zeros((Nc,L)) 13 | for i in range(Nc): 14 | cenarray[i][list(centralities[i].keys())]=list(centralities[i].values()) 15 | normcen = (cenarray.astype(float)-np.min(cenarray,axis=1)[:,None])/(np.max(cenarray,axis=1)-np.min(cenarray,axis=1))[:,None] 16 | return normcen 17 | 18 | class ProbSampler(object): 19 | def __init__(self): 20 | pass 21 | 22 | def select_batch(self,wraped_feat): 23 | entropy = wraped_feat[0] 24 | selected = np.argmax(entropy) 25 | return selected 26 | 27 | class DegSampler(object): 28 | def __init__(self): 29 | pass 30 | def select_batch(self,wraped_feat): 31 | deg = wraped_feat[1] 32 | selected = np.argmax(deg) 33 | # print("in DegSampler {}".format(np.max(deg))) 34 | return selected 35 | 36 | class ClusterSampler(object): 37 | def __init__(self): 38 | pass 39 | def select_batch(self,wraped_feat): 40 | edscore = wraped_feat[2] 41 | selected = np.argmax(-edscore) 42 | # print("in DegSampler {}".format(np.max(deg))) 43 | return selected 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | class BanditDiscreteSampler(object): 52 | """Wraps EXP3 around mixtures of indicated methods. 53 | 54 | Uses EXP3 mult-armed bandit algorithm to select sampler methods. 
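    In this file the three arms are the entropy, cluster-distance and
    PageRank-centrality heuristics assembled by AnrmabQuery below. A simplified
    reading of the update performed in select_batch (the actual code also adds
    an exploration bonus inside the exponent):

        w_i <- w_i * exp(0.5 * pmin * rhat_i)             # exponential weights
        p_i  = (1 - n_arms * pmin) * w_i / sum(w) + pmin  # mix with uniform exploration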
55 | """ 56 | def __init__(self,budget,num_nodes,n_arms,seed=123, 57 | reward_function = lambda AL_acc: AL_acc[-1], 58 | gamma=0.5): 59 | 60 | # self.name = 'bandit_discrete' 61 | # np.random.seed(seed) 62 | # self.seed = seed 63 | # self.initialize_samplers(samplers) 64 | # self.gamma = gamma 65 | 66 | # print(budget,num_nodes,n_arms) 67 | self.n_arms = n_arms 68 | self.reward_function = reward_function 69 | self.num_arm = float(self.n_arms) 70 | self.pmin=np.sqrt(np.log(self.num_arm) / (self.num_arm * budget)) 71 | 72 | self.pull_history = [] 73 | self.acc_history = [] 74 | self.w = np.ones(self.n_arms) 75 | # self.x = np.zeros(self.n_arms) 76 | self.p = None 77 | # self.probs = [] 78 | self.selectionhistory = [] 79 | self.num_nodes = float(num_nodes) 80 | self.rewardlist = [] 81 | self.phi = [] 82 | self.budget = budget 83 | self.Q = [] 84 | self.lastselected = [] 85 | 86 | def select_batch(self, r, wraped_feat): 87 | 88 | # print(r) 89 | # self.acc_history.append(eval_acc) 90 | 91 | wraped_feat = np.exp(wraped_feat*20.) #make the prob sharper 92 | wraped_feat = wraped_feat/np.sum(wraped_feat,axis=-1,keepdims=True) 93 | if self.p is not None : 94 | self.rewardlist.append(1.0 / (self.phi[self.lastselected] * float(wraped_feat.shape[-1])) * r) 95 | # print("rhat {} {} {}".format(self.rewardlist[-1],self.phi[self.lastselected] , float(wraped_feat.shape[-1]))) 96 | reward = 1/float(self.budget)*sum(self.rewardlist) 97 | self.rhat = reward*self.Q[:, self.lastselected] / self.phi[self.lastselected] 98 | self.w = self.w*np.exp(1*0.5*self.pmin*(self.rhat+1/self.p*np.sqrt(np.log(self.num_nodes/0.1)/(self.num_arm*self.budget)))) 99 | 100 | self.p = (1 - self.num_arm * self.pmin) * self.w / np.sum(self.w) + self.pmin 101 | # print(self.p,np.sum(self.p),self.pmin) 102 | self.Q = wraped_feat 103 | self.phi = np.matmul(self.p.reshape((1, -1)), wraped_feat).squeeze() 104 | 105 | t= 10 106 | # self.phi = np.exp(self.phi * 100.) 107 | # self.phi = self.phi/np.sum(self.phi) 108 | # print(self.phi.shape) 109 | 110 | 111 | selected = np.random.choice(range(wraped_feat.shape[-1]), p=self.phi) 112 | self.lastselected = selected 113 | 114 | # print(wraped_feat[:,selected]) 115 | return selected 116 | 117 | 118 | class AnrmabQuery(object): 119 | def __init__(self, G, budget,num_nodes,batchsize): 120 | self.q = [] 121 | self.batchsize = batchsize 122 | # self.alreadyselected = [[] for i in range(batchsize)] 123 | self.G = G 124 | self.NCL = G.stat["nclass"] 125 | self.normcen = centralissimo(self.G.G)[0] 126 | self.cenperc = self.perc(self.normcen) 127 | 128 | for i in range(batchsize): 129 | self.q.append(BanditDiscreteSampler(budget=budget,num_nodes=num_nodes,n_arms=3)) 130 | pass 131 | 132 | def __call__(self, output, acc, pool): 133 | ret = [] 134 | for id in range(self.batchsize): 135 | selected = self.selectOneNode( self.q[id],acc[id], output[id],pool[id]) 136 | ret.append(selected) 137 | ret = np.array(ret) #.reshape(-1,1) 138 | ret = ret.tolist() 139 | return ret 140 | 141 | def selectOneNode(self, q, acc, output,pool): 142 | 143 | # if self.multilabel: 144 | # probs = 1. / (1. 
+ np.exp(-output)) 145 | # entropy = multiclassentropy_numpy(probs) 146 | # else: 147 | entropy = sc.stats.entropy(output.transpose()) 148 | # print("entropy shape{}".format(entropy.shape)) 149 | 150 | validentropy = entropy[pool] 151 | validdeg = self.normcen[pool] 152 | 153 | kmeans = KMeans(n_clusters=self.NCL, random_state=0).fit(output) 154 | ed = euclidean_distances(output, kmeans.cluster_centers_) 155 | ed_score = np.min(ed, axis=1) 156 | valided_score = ed_score[pool] 157 | 158 | 159 | entrperc = self.perc(entropy) 160 | edprec = self.percd(ed_score) 161 | 162 | 163 | wraped_feat = np.stack([entrperc, edprec, self.cenperc]) 164 | 165 | wraped_feat_valid = wraped_feat[:,pool] 166 | # print("warpedfaet {}".format(wraped_feat_valid.shape)) 167 | selected = q.select_batch( acc, wraped_feat_valid) 168 | 169 | realselected = pool[selected] 170 | # print(realselected) 171 | # print(self.normcen[realselected]) 172 | 173 | return realselected 174 | 175 | def percd(self,input): 176 | return 1-np.argsort(np.argsort(input,kind='stable'),kind='stable')/len(input) 177 | 178 | # calculate the percentage of elements smaller than the k-th element 179 | def perc(self,input): 180 | return 1-np.argsort(np.argsort(-input,kind='stable'),kind='stable')/len(input) 181 | 182 | 183 | def unitTest(): 184 | q = CoreSetQuery(3) 185 | 186 | # pool = np.array([[2,3],[2,3],[4,6]]) 187 | 188 | for i in range(3): 189 | features = np.random.randn(3, 10, 5) 190 | selected = q(features) 191 | print(selected) 192 | 193 | 194 | if __name__ == "__main__": 195 | unitTest() 196 | 197 | 198 | -------------------------------------------------------------------------------- /src/baselines/coreset.py: -------------------------------------------------------------------------------- 1 | from src.baselines.sampling_methods.kcenter_greedy import kCenterGreedy 2 | import numpy as np 3 | 4 | class CoreSetQuery(object): 5 | def __init__(self, batchsize,trainsetid): 6 | self.q = [] 7 | self.trainsetid = trainsetid 8 | self.batchsize = batchsize 9 | self.alreadyselected = [[] for i in range(batchsize)] 10 | for i in range(batchsize): 11 | self.q.append(kCenterGreedy()) 12 | pass 13 | 14 | def __call__(self, outputs): 15 | ret = [] 16 | for id,row in enumerate(self.alreadyselected): 17 | selected = self.selectOneNode(outputs[id], row, self.q[id]) 18 | ret.append(selected) 19 | ret = np.array(ret) #.reshape(-1,1) 20 | ret = ret.tolist() 21 | self.alreadyselected = [x + ret[id] for id, x in enumerate(self.alreadyselected)] 22 | 23 | selectedtrueid = [] 24 | for i in range(self.batchsize): 25 | print(ret[i]) 26 | selectedtrueid.append(self.trainsetid[i][ret[i][0]]) 27 | 28 | return selectedtrueid 29 | 30 | 31 | def selectOneNode(self, output, pool,q): 32 | selected = q.select_batch(output,pool,1) 33 | return selected 34 | 35 | 36 | def unitTest(): 37 | q = CoreSetQuery(3) 38 | 39 | # pool = np.array([[2,3],[2,3],[4,6]]) 40 | 41 | for i in range(3): 42 | features = np.random.randn(3, 10, 5) 43 | selected = q(features) 44 | print(selected) 45 | 46 | 47 | if __name__ == "__main__": 48 | unitTest() 49 | 50 | 51 | -------------------------------------------------------------------------------- /src/baselines/coreset/compute_distance_mat.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy 3 | import numpy.matlib 4 | import time 5 | import pickle 6 | import bisect 7 | 8 | dat = pickle.load(open('feature_vectors_pickled')) 9 | data = numpy.concatenate((dat['gt_f'],dat['f']), 
axis=0) 10 | budget = 5000 11 | 12 | start = time.clock() 13 | num_images = data.shape[0] 14 | 15 | dist_mat = numpy.matmul(data,data.transpose()) 16 | 17 | sq = numpy.array(dist_mat.diagonal()).reshape(num_images,1) 18 | dist_mat *= -2 19 | dist_mat+=sq 20 | dist_mat+=sq.transpose() 21 | 22 | elapsed = time.clock() - start 23 | print "Time spent in (distance computation) is: ", elapsed 24 | numpy.save('distances.npy', dist_mat) 25 | 26 | -------------------------------------------------------------------------------- /src/baselines/coreset/configure.sh: -------------------------------------------------------------------------------- 1 | export GUROBI_HOME="/opt/gurobi702/linux64" 2 | export PATH="${PATH}:${GUROBI_HOME}/bin" 3 | export LD_LIBRARY_PATH="${LD_LIBRARY_PATH}:${GUROBI_HOME}/lib" 4 | export GRB_LICENSE_FILE="/afs/cs.stanford.edu/u/ozansener/gurobi.lic" 5 | 6 | -------------------------------------------------------------------------------- /src/baselines/coreset/full_solver_gurobi.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | from gurobipy import * 3 | import pickle 4 | import numpy.matlib 5 | import time 6 | import pickle 7 | import bisect 8 | 9 | def solve_fac_loc(xx,yy,subset,n,budget): 10 | model = Model("k-center") 11 | x={} 12 | y={} 13 | z={} 14 | for i in range(n): 15 | # z_i: is a loss 16 | z[i] = model.addVar(obj=1, ub=0.0, vtype="B", name="z_{}".format(i)) 17 | 18 | m = len(xx) 19 | for i in range(m): 20 | _x = xx[i] 21 | _y = yy[i] 22 | # y_i = 1 means i is facility, 0 means it is not 23 | if _y not in y: 24 | if _y in subset: 25 | y[_y] = model.addVar(obj=0, ub=1.0, lb=1.0, vtype="B", name="y_{}".format(_y)) 26 | else: 27 | y[_y] = model.addVar(obj=0,vtype="B", name="y_{}".format(_y)) 28 | #if not _x == _y: 29 | x[_x,_y] = model.addVar(obj=0, vtype="B", name="x_{},{}".format(_x,_y)) 30 | model.update() 31 | 32 | coef = [1 for j in range(n)] 33 | var = [y[j] for j in range(n)] 34 | model.addConstr(LinExpr(coef,var), "=", rhs=budget+len(subset), name="k_center") 35 | 36 | for i in range(m): 37 | _x = xx[i] 38 | _y = yy[i] 39 | #if not _x == _y: 40 | model.addConstr(x[_x,_y], "<", y[_y], name="Strong_{},{}".format(_x,_y)) 41 | 42 | yyy = {} 43 | for v in range(m): 44 | _x = xx[v] 45 | _y = yy[v] 46 | if _x not in yyy: 47 | yyy[_x]=[] 48 | if _y not in yyy[_x]: 49 | yyy[_x].append(_y) 50 | 51 | for _x in yyy: 52 | coef = [] 53 | var = [] 54 | for _y in yyy[_x]: 55 | #if not _x==_y: 56 | coef.append(1) 57 | var.append(x[_x,_y]) 58 | coef.append(1) 59 | var.append(z[_x]) 60 | model.addConstr(LinExpr(coef,var), "=", 1, name="Assign{}".format(_x)) 61 | model.__data = x,y,z 62 | return model 63 | 64 | 65 | data = pickle.load(open('feature_vectors_pickled')) 66 | budget = 10000 67 | 68 | start = time.clock() 69 | num_images = data.shape[0] 70 | dist_mat = numpy.matmul(data,data.transpose()) 71 | 72 | sq = numpy.array(dist_mat.diagonal()).reshape(num_images,1) 73 | dist_mat *= -2 74 | dist_mat+=sq 75 | dist_mat+=sq.transpose() 76 | 77 | elapsed = time.clock() - start 78 | print "Time spent in (distance computation) is: ", elapsed 79 | 80 | num_images = 50000 81 | 82 | # We need to get k centers start with greedy solution 83 | budget = 10000 84 | subset = [i for i in range(1)] 85 | 86 | ub= UB 87 | lb = ub/2.0 88 | max_dist=ub 89 | 90 | _x,_y = numpy.where(dist_mat<=max_dist) 91 | _d = dist_mat[_x,_y] 92 | subset = [i for i in range(1)] 93 | model = solve_fac_loc(_x,_y,subset,num_images,budget) 94 | #model.setParam( 
'OutputFlag', False ) 95 | x,y,z = model.__data 96 | delta=1e-7 97 | while ub-lb>delta: 98 | print "State",ub,lb 99 | cur_r = (ub+lb)/2.0 100 | viol = numpy.where(_d>cur_r) 101 | new_max_d = numpy.min(_d[_d>=cur_r]) 102 | new_min_d = numpy.max(_d[_d<=cur_r]) 103 | print "If it succeeds, new max is:",new_max_d,new_min_d 104 | for v in viol[0]: 105 | x[_x[v],_y[v]].UB = 0 106 | 107 | model.update() 108 | r = model.optimize() 109 | if model.getAttr(GRB.Attr.Status) == GRB.INFEASIBLE: 110 | failed=True 111 | print "Infeasible" 112 | elif sum([z[i].X for i in range(len(z))]) > 0: 113 | failed=True 114 | print "Failed" 115 | else: 116 | failed=False 117 | if failed: 118 | lb = max(cur_r,new_max_d) 119 | #failed so put edges back 120 | for v in viol[0]: 121 | x[_x[v],_y[v]].UB = 1 122 | else: 123 | print "sol founded",cur_r,lb,ub 124 | ub = min(cur_r,new_min_d) 125 | model.write("s_{}_solution_{}.sol".format(budget,cur_r)) 126 | 127 | -------------------------------------------------------------------------------- /src/baselines/coreset/gurobi_solution_parser.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | 5 | import numpy 6 | import pickle 7 | import numpy.matlib 8 | import time 9 | import pickle 10 | import bisect 11 | 12 | gurobi_solution_file = 'solution_2.86083525164.sol' 13 | results = open(gurobi_solution_file).read().split('\n') 14 | results_nodes = filter(lambda x: 'y' in x,filter(lambda x:'#' not in x, results)) 15 | string_to_id = lambda x:(int(x.split(' ')[0].split('_')[1]),int(x.split(' ')[1])) 16 | result_node_ids = map(string_to_id, results_nodes) 17 | 18 | results_as_dict = {v[0]:v[1] for v in result_node_ids} 19 | 20 | centers = [] 21 | for node_result in result_node_ids: 22 | if node_result[1] > 0: 23 | centers.append(node_result[0]) 24 | 25 | pickle.dump(centers,open('centers.bn','wb')) 26 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/bandit_discrete.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
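`gurobi_solution_parser.py` above is written for Python 2 (`filter`/`map` returning lists, `print` statements in the neighbouring scripts). A rough Python 3 equivalent of the same parsing, assuming the `.sol` file contains `y_<node_id> <value>` lines for the facility variables, would be:
```
import pickle

gurobi_solution_file = 'solution_2.86083525164.sol'   # same hypothetical file as above

with open(gurobi_solution_file) as f:
    lines = [l for l in f.read().split('\n') if l and '#' not in l]

# Keep the facility variables y_<node_id> and parse "name value" pairs.
result_node_ids = [(int(l.split(' ')[0].split('_')[1]), int(float(l.split(' ')[1])))
                   for l in lines if l.startswith('y_')]

centers = [node_id for node_id, value in result_node_ids if value > 0]
pickle.dump(centers, open('centers.bn', 'wb'))
```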
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Bandit wrapper around base AL sampling methods. 16 | 17 | Assumes adversarial multi-armed bandit setting where arms correspond to 18 | mixtures of different AL methods. 19 | 20 | Uses EXP3 algorithm to decide which AL method to use to create the next batch. 21 | Similar to Hsu & Lin 2015, Active Learning by Learning. 22 | https://www.csie.ntu.edu.tw/~htlin/paper/doc/aaai15albl.pdf 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import numpy as np 30 | 31 | from src.baselines.sampling_methods.wrapper_sampler_def import AL_MAPPING, WrapperSamplingMethod 32 | 33 | 34 | class BanditDiscreteSampler(WrapperSamplingMethod): 35 | """Wraps EXP3 around mixtures of indicated methods. 36 | 37 | Uses EXP3 mult-armed bandit algorithm to select sampler methods. 38 | """ 39 | def __init__(self,budget,seed=123, 40 | reward_function = lambda AL_acc: AL_acc[-1], 41 | gamma=0.5, 42 | samplers=[{'methods':('margin','uniform'),'weights':(0,1)}, 43 | {'methods':('margin','uniform'),'weights':(1,0)}]): 44 | """Initializes sampler with indicated gamma and arms. 45 | 46 | Args: 47 | X: training data 48 | y: labels, may need to be input into base samplers 49 | seed: seed to use for random sampling 50 | reward_function: reward based on previously observed accuracies. Assumes 51 | that the input is a sequence of observed accuracies. Will ultimately be 52 | a class method and may need access to other class properties. 53 | gamma: weight on uniform mixture. Arm probability updates are a weighted 54 | mixture of uniform and an exponentially weighted distribution. 55 | Lower gamma more aggressively updates based on observed rewards. 
56 | samplers: list of dicts with two fields 57 | 'samplers': list of named samplers 58 | 'weights': percentage of batch to allocate to each sampler 59 | """ 60 | 61 | self.name = 'bandit_discrete' 62 | np.random.seed(seed) 63 | self.seed = seed 64 | # self.initialize_samplers(samplers) 65 | self.samplers = samplers 66 | self.gamma = gamma 67 | self.n_arms = len(samplers) 68 | self.reward_function = reward_function 69 | 70 | self.pull_history = [] 71 | self.acc_history = [] 72 | self.w = np.ones(self.n_arms) 73 | self.x = np.zeros(self.n_arms) 74 | self.p = self.w / (1.0 * self.n_arms) 75 | self.probs = [] 76 | self.selectionhistory = [] 77 | self.num_arm = float(len(self.samplers)) 78 | 79 | self.pmin = np.sqrt(np.ln(self.num_arm)/(self.num_arm*budget)) 80 | 81 | def update_vars_arnmab(self, arm_pulled): 82 | reward = self.reward_function(self.acc_history) 83 | Qkstar = [self.samplers[arm_pulled][self.selectionhistory[-1]] 84 | phistar = sum([self.p[i] * self.samplers[i].valuelist[self.selectionhistory[-1]] for i in range(len(self.samplers))) 85 | rhat = Qkstar / phistar 86 | self.w = self.w*np.exp(self.pmin/2.0*(rhat)) 87 | 88 | 89 | 90 | def update_vars(self, arm_pulled): 91 | reward = self.reward_function(self.acc_history) 92 | 93 | self.x = np.zeros(self.n_arms) 94 | self.x[arm_pulled] = reward / self.p[arm_pulled] 95 | self.w = self.w * np.exp(self.gamma * self.x / self.n_arms) 96 | self.p = ((1.0 - self.gamma) * self.w / sum(self.w) 97 | + self.gamma / self.n_arms) 98 | # print(self.p) 99 | self.probs.append(self.p) 100 | 101 | def select_batch_arnmab(N, eval_acc, wraped_feat) : 102 | self.acc_history.append(eval_acc) 103 | 104 | 105 | if len(self.pull_history) > 0: 106 | self.update_vars(self.pull_history[-1]) 107 | 108 | def select_batch(self, N, eval_acc,wraped_feat): 109 | """Returns batch of datapoints sampled using mixture of AL_methods. 110 | 111 | Assumes that data has already been shuffled. 112 | 113 | Args: 114 | already_selected: index of datapoints already selected 115 | N: batch size 116 | eval_acc: accuracy of model trained after incorporating datapoints from 117 | last recommended batch 118 | 119 | Returns: 120 | indices of points selected to label 121 | """ 122 | 123 | # print("eval_acc {}".format(eval_acc)) 124 | # exit() 125 | # Update observed reward and arm probabilities 126 | self.acc_history.append(eval_acc) 127 | if len(self.pull_history) > 0: 128 | self.update_vars(self.pull_history[-1]) 129 | # Sample an arm 130 | 131 | 132 | 133 | arm = np.random.choice(range(self.n_arms), p=self.p) 134 | self.pull_history.append(arm) 135 | # kwargs['N'] = N 136 | # kwargs['already_selected'] = already_selected 137 | 138 | # print("use arm {}".format(arm)) 139 | sample = self.samplers[arm].select_batch(wraped_feat) 140 | return sample 141 | 142 | def to_dict(self): 143 | output = {} 144 | output['samplers'] = self.base_samplers 145 | output['arm_probs'] = self.probs 146 | output['pull_history'] = self.pull_history 147 | output['rewards'] = self.acc_history 148 | return output 149 | 150 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Controls imports to fill up dictionary of different sampling methods. 16 | """ 17 | 18 | from functools import partial 19 | AL_MAPPING = {} 20 | 21 | 22 | def get_base_AL_mapping(): 23 | from src.baselines.sampling_methods.margin_AL import MarginAL 24 | from src.baselines.sampling_methods.informative_diverse import InformativeClusterDiverseSampler 25 | from src.baselines.sampling_methods.hierarchical_clustering_AL import HierarchicalClusterAL 26 | from src.baselines.sampling_methods.uniform_sampling import UniformSampling 27 | from src.baselines.sampling_methods.represent_cluster_centers import RepresentativeClusterMeanSampling 28 | from src.baselines.sampling_methods.graph_density import GraphDensitySampler 29 | from src.baselines.sampling_methods.kcenter_greedy import kCenterGreedy 30 | AL_MAPPING['margin'] = MarginAL 31 | AL_MAPPING['informative_diverse'] = InformativeClusterDiverseSampler 32 | AL_MAPPING['hierarchical'] = HierarchicalClusterAL 33 | AL_MAPPING['uniform'] = UniformSampling 34 | AL_MAPPING['margin_cluster_mean'] = RepresentativeClusterMeanSampling 35 | AL_MAPPING['graph_density'] = GraphDensitySampler 36 | AL_MAPPING['kcenter'] = kCenterGreedy 37 | 38 | 39 | def get_all_possible_arms(): 40 | from src.baselines.sampling_methods.mixture_of_samplers import MixtureOfSamplers 41 | AL_MAPPING['mixture_of_samplers'] = MixtureOfSamplers 42 | 43 | 44 | def get_wrapper_AL_mapping(): 45 | from src.baselines.sampling_methods.bandit_discrete import BanditDiscreteSampler 46 | from src.baselines.sampling_methods.simulate_batch import SimulateBatchSampler 47 | AL_MAPPING['bandit_mixture'] = partial( 48 | BanditDiscreteSampler, 49 | samplers=[{ 50 | 'methods': ['margin', 'uniform'], 51 | 'weights': [0, 1] 52 | }, { 53 | 'methods': ['margin', 'uniform'], 54 | 'weights': [0.25, 0.75] 55 | }, { 56 | 'methods': ['margin', 'uniform'], 57 | 'weights': [0.5, 0.5] 58 | }, { 59 | 'methods': ['margin', 'uniform'], 60 | 'weights': [0.75, 0.25] 61 | }, { 62 | 'methods': ['margin', 'uniform'], 63 | 'weights': [1, 0] 64 | }]) 65 | AL_MAPPING['bandit_discrete'] = partial( 66 | BanditDiscreteSampler, 67 | samplers=[{ 68 | 'methods': ['margin', 'uniform'], 69 | 'weights': [0, 1] 70 | }, { 71 | 'methods': ['margin', 'uniform'], 72 | 'weights': [1, 0] 73 | }]) 74 | AL_MAPPING['simulate_batch_mixture'] = partial( 75 | SimulateBatchSampler, 76 | samplers=({ 77 | 'methods': ['margin', 'uniform'], 78 | 'weights': [1, 0] 79 | }, { 80 | 'methods': ['margin', 'uniform'], 81 | 'weights': [0.5, 0.5] 82 | }, { 83 | 'methods': ['margin', 'uniform'], 84 | 'weights': [0, 1] 85 | }), 86 | n_sims=5, 87 | train_per_sim=10, 88 | return_best_sim=False) 89 | AL_MAPPING['simulate_batch_best_sim'] = partial( 90 | SimulateBatchSampler, 91 | samplers=[{ 92 | 'methods': ['margin', 'uniform'], 93 | 'weights': [1, 0] 94 | }], 95 | n_sims=10, 96 | train_per_sim=10, 97 | return_type='best_sim') 98 | AL_MAPPING['simulate_batch_frequency'] = partial( 99 | SimulateBatchSampler, 100 | samplers=[{ 101 | 'methods': ['margin', 'uniform'], 102 | 'weights': [1, 0] 103 | }], 104 | n_sims=10, 
105 | train_per_sim=10, 106 | return_type='frequency') 107 | 108 | def get_mixture_of_samplers(name): 109 | assert 'mixture_of_samplers' in name 110 | if 'mixture_of_samplers' not in AL_MAPPING: 111 | raise KeyError('Mixture of Samplers not yet loaded.') 112 | args = name.split('-')[1:] 113 | samplers = args[0::2] 114 | weights = args[1::2] 115 | weights = [float(w) for w in weights] 116 | assert sum(weights) == 1 117 | mixture = {'methods': samplers, 'weights': weights} 118 | print(mixture) 119 | return partial(AL_MAPPING['mixture_of_samplers'], mixture=mixture) 120 | 121 | 122 | def get_AL_sampler(name): 123 | if name in AL_MAPPING and name != 'mixture_of_samplers': 124 | return AL_MAPPING[name] 125 | if 'mixture_of_samplers' in name: 126 | return get_mixture_of_samplers(name) 127 | raise NotImplementedError('The specified sampler is not available.') 128 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/graph_density.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Diversity promoting sampling method that uses graph density to determine 16 | most representative points. 17 | 18 | This is an implementation of the method described in 19 | https://www.mpi-inf.mpg.de/fileadmin/inf/d2/Research_projects_files/EbertCVPR2012.pdf 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import copy 27 | 28 | from sklearn.neighbors import kneighbors_graph 29 | from sklearn.metrics import pairwise_distances 30 | import numpy as np 31 | from src.baselines.sampling_methods.sampling_def import SamplingMethod 32 | 33 | 34 | class GraphDensitySampler(SamplingMethod): 35 | """Diversity promoting sampling method that uses graph density to determine 36 | most representative points. 37 | """ 38 | 39 | def __init__(self, X, y, seed): 40 | self.name = 'graph_density' 41 | self.X = X 42 | self.flat_X = self.flatten_X() 43 | # Set gamma for gaussian kernel to be equal to 1/n_features 44 | self.gamma = 1. / self.X.shape[1] 45 | self.compute_graph_density() 46 | 47 | def compute_graph_density(self, n_neighbor=10): 48 | # kneighbors graph is constructed using k=10 49 | connect = kneighbors_graph(self.flat_X, n_neighbor,p=1) 50 | # Make connectivity matrix symmetric, if a point is a k nearest neighbor of 51 | # another point, make it vice versa 52 | neighbors = connect.nonzero() 53 | inds = zip(neighbors[0],neighbors[1]) 54 | # Graph edges are weighted by applying gaussian kernel to manhattan dist. 55 | # By default, gamma for rbf kernel is equal to 1/n_features but may 56 | # get better results if gamma is tuned. 
57 | for entry in inds: 58 | i = entry[0] 59 | j = entry[1] 60 | distance = pairwise_distances(self.flat_X[[i]],self.flat_X[[j]],metric='manhattan') 61 | distance = distance[0,0] 62 | weight = np.exp(-distance * self.gamma) 63 | connect[i,j] = weight 64 | connect[j,i] = weight 65 | self.connect = connect 66 | # Define graph density for an observation to be sum of weights for all 67 | # edges to the node representing the datapoint. Normalize sum weights 68 | # by total number of neighbors. 69 | self.graph_density = np.zeros(self.X.shape[0]) 70 | for i in np.arange(self.X.shape[0]): 71 | self.graph_density[i] = connect[i,:].sum() / (connect[i,:]>0).sum() 72 | self.starting_density = copy.deepcopy(self.graph_density) 73 | 74 | def select_batch_(self, N, already_selected, **kwargs): 75 | # If a neighbor has already been sampled, reduce the graph density 76 | # for its direct neighbors to promote diversity. 77 | batch = set() 78 | self.graph_density[already_selected] = min(self.graph_density) - 1 79 | while len(batch) < N: 80 | selected = np.argmax(self.graph_density) 81 | neighbors = (self.connect[selected,:] > 0).nonzero()[1] 82 | self.graph_density[neighbors] = self.graph_density[neighbors] - self.graph_density[selected] 83 | batch.add(selected) 84 | self.graph_density[already_selected] = min(self.graph_density) - 1 85 | self.graph_density[list(batch)] = min(self.graph_density) - 1 86 | return list(batch) 87 | 88 | def to_dict(self): 89 | output = {} 90 | output['connectivity'] = self.connect 91 | output['graph_density'] = self.starting_density 92 | return output -------------------------------------------------------------------------------- /src/baselines/sampling_methods/informative_diverse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Informative and diverse batch sampler that samples points with small margin 16 | while maintaining same distribution over clusters as entire training data. 17 | 18 | Batch is created by sorting datapoints by increasing margin and then growing 19 | the batch greedily. A point is added to the batch if the result batch still 20 | respects the constraint that the cluster distribution of the batch will 21 | match the cluster distribution of the entire training set. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | from sklearn.cluster import MiniBatchKMeans 29 | import numpy as np 30 | from src.baselines.sampling_methods.sampling_def import SamplingMethod 31 | 32 | 33 | class InformativeClusterDiverseSampler(SamplingMethod): 34 | """Selects batch based on informative and diverse criteria. 35 | 36 | Returns highest uncertainty lowest margin points while maintaining 37 | same distribution over clusters as entire dataset. 
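Both this sampler and `MarginAL` further below rank candidates by the gap between the two highest class scores. A self-contained sketch of that margin computation on a hypothetical probability matrix:
```
import numpy as np

# Hypothetical class probabilities for 4 points and 3 classes.
distances = np.array([[0.70, 0.20, 0.10],
                      [0.40, 0.35, 0.25],
                      [0.34, 0.33, 0.33],
                      [0.90, 0.05, 0.05]])

sort_distances = np.sort(distances, 1)[:, -2:]            # two largest scores per row
min_margin = sort_distances[:, 1] - sort_distances[:, 0]
rank_ind = np.argsort(min_margin)                         # most ambiguous points first
print(rank_ind)   # -> [2 1 0 3]
```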
38 | """ 39 | 40 | def __init__(self, X, y, seed): 41 | self.name = 'informative_and_diverse' 42 | self.X = X 43 | self.flat_X = self.flatten_X() 44 | # y only used for determining how many clusters there should be 45 | # probably not practical to assume we know # of classes before hand 46 | # should also probably scale with dimensionality of data 47 | self.y = y 48 | self.n_clusters = len(list(set(y))) 49 | self.cluster_model = MiniBatchKMeans(n_clusters=self.n_clusters) 50 | self.cluster_data() 51 | 52 | def cluster_data(self): 53 | # Probably okay to always use MiniBatchKMeans 54 | # Should standardize data before clustering 55 | # Can cluster on standardized data but train on raw features if desired 56 | self.cluster_model.fit(self.flat_X) 57 | unique, counts = np.unique(self.cluster_model.labels_, return_counts=True) 58 | self.cluster_prob = counts/sum(counts) 59 | self.cluster_labels = self.cluster_model.labels_ 60 | 61 | def select_batch_(self, model, already_selected, N, **kwargs): 62 | """Returns a batch of size N using informative and diverse selection. 63 | 64 | Args: 65 | model: scikit learn model with decision_function implemented 66 | already_selected: index of datapoints already selected 67 | N: batch size 68 | 69 | Returns: 70 | indices of points selected to add using margin active learner 71 | """ 72 | # TODO(lishal): have MarginSampler and this share margin function 73 | try: 74 | distances = model.decision_function(self.X) 75 | except: 76 | distances = model.predict_proba(self.X) 77 | if len(distances.shape) < 2: 78 | min_margin = abs(distances) 79 | else: 80 | sort_distances = np.sort(distances, 1)[:, -2:] 81 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 82 | rank_ind = np.argsort(min_margin) 83 | rank_ind = [i for i in rank_ind if i not in already_selected] 84 | new_batch_cluster_counts = [0 for _ in range(self.n_clusters)] 85 | new_batch = [] 86 | for i in rank_ind: 87 | if len(new_batch) == N: 88 | break 89 | label = self.cluster_labels[i] 90 | if new_batch_cluster_counts[label] / N < self.cluster_prob[label]: 91 | new_batch.append(i) 92 | new_batch_cluster_counts[label] += 1 93 | n_slot_remaining = N - len(new_batch) 94 | batch_filler = list(set(rank_ind) - set(already_selected) - set(new_batch)) 95 | new_batch.extend(batch_filler[0:n_slot_remaining]) 96 | return new_batch 97 | 98 | def to_dict(self): 99 | output = {} 100 | output['cluster_membership'] = self.cluster_labels 101 | return output 102 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/kcenter_greedy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Returns points that minimizes the maximum distance of any point to a center. 16 | 17 | Implements the k-Center-Greedy method in 18 | Ozan Sener and Silvio Savarese. 
A Geometric Approach to Active Learning for 19 | Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017 20 | 21 | Distance metric defaults to l2 distance. Features used to calculate distance 22 | are either raw features or if a model has transform method then uses the output 23 | of model.transform(X). 24 | 25 | Can be extended to a robust k centers algorithm that ignores a certain number of 26 | outlier datapoints. Resulting centers are solution to multiple integer program. 27 | """ 28 | 29 | from __future__ import absolute_import 30 | from __future__ import division 31 | from __future__ import print_function 32 | 33 | import numpy as np 34 | from sklearn.metrics import pairwise_distances 35 | from src.baselines.sampling_methods.sampling_def import SamplingMethod 36 | 37 | 38 | class kCenterGreedy(object): 39 | 40 | def __init__(self, metric='euclidean'): 41 | # self.X = X 42 | # self.y = y 43 | # self.flat_X = self.flatten_X() 44 | self.name = 'kcenter' 45 | # self.features = self.flat_X 46 | self.metric = metric 47 | self.min_distances = None 48 | self.n_obs = 0#self.X.shape[0] 49 | self.already_selected = [] 50 | 51 | def update_distances(self, cluster_centers, only_new=True, reset_dist=False): 52 | """Update min distances given cluster centers. 53 | 54 | Args: 55 | cluster_centers: indices of cluster centers 56 | only_new: only calculate distance for newly selected points and update 57 | min_distances. 58 | rest_dist: whether to reset min_distances. 59 | """ 60 | 61 | if reset_dist: 62 | self.min_distances = None 63 | if only_new: 64 | cluster_centers = [d for d in cluster_centers 65 | if d not in self.already_selected] 66 | if len(cluster_centers)>0: 67 | # Update min_distances for all examples given new cluster center. 68 | x = self.features[cluster_centers] 69 | dist = pairwise_distances(self.features, x, metric=self.metric) 70 | 71 | if self.min_distances is None: 72 | self.min_distances = np.min(dist, axis=1).reshape(-1, 1) 73 | else: 74 | self.min_distances = np.minimum(self.min_distances, dist) 75 | 76 | 77 | 78 | 79 | def select_batch(self, features, already_selected, N, **kwargs): 80 | """ 81 | Diversity promoting active learning method that greedily forms a batch 82 | to minimize the maximum distance to a cluster center among all unlabeled 83 | datapoints. 84 | 85 | Args: 86 | model: model with scikit-like API with decision_function implemented 87 | already_selected: index of datapoints already selected 88 | N: batch size 89 | 90 | Returns: 91 | indices of points selected to minimize distance to cluster centers 92 | """ 93 | 94 | try: 95 | # Assumes that the transform function takes in original data and not 96 | # flattened data. 97 | # print('Getting transformed features...') 98 | self.features = features 99 | # print('Calculating distances...') 100 | self.update_distances(already_selected, only_new=False, reset_dist=True) 101 | except: 102 | # print('Using flat_X as features.') 103 | self.update_distances(already_selected, only_new=True, reset_dist=False) 104 | 105 | new_batch = [] 106 | 107 | for _ in range(N): 108 | if self.already_selected is None: 109 | # Initialize centers with a randomly selected datapoint 110 | ind = np.random.choice(np.arange(self.n_obs)) 111 | else: 112 | ind = np.argmax(self.min_distances) 113 | 114 | # New examples should not be in already selected since those points 115 | # should have min_distance of zero to a cluster center. 
116 | assert ind not in already_selected 117 | 118 | self.update_distances([ind], only_new=True, reset_dist=False) 119 | new_batch.append(ind) 120 | # print('Maximum distance from cluster centers is %0.2f' 121 | # % max(self.min_distances)) 122 | 123 | 124 | self.already_selected = already_selected 125 | # print("already selected3 {}".format(self.already_selected)) 126 | 127 | 128 | return new_batch 129 | 130 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/margin_AL.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Margin based AL method. 16 | 17 | Samples in batches based on margin scores. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import numpy as np 25 | from src.baselines.sampling_methods.sampling_def import SamplingMethod 26 | 27 | 28 | class MarginAL(SamplingMethod): 29 | def __init__(self, X, y, seed): 30 | self.X = X 31 | self.y = y 32 | self.name = 'margin' 33 | 34 | def select_batch_(self, model, already_selected, N, **kwargs): 35 | """Returns batch of datapoints with smallest margin/highest uncertainty. 36 | 37 | For binary classification, can just take the absolute distance to decision 38 | boundary for each point. 39 | For multiclass classification, must consider the margin between distance for 40 | top two most likely classes. 41 | 42 | Args: 43 | model: scikit learn model with decision_function implemented 44 | already_selected: index of datapoints already selected 45 | N: batch size 46 | 47 | Returns: 48 | indices of points selected to add using margin active learner 49 | """ 50 | 51 | try: 52 | distances = model.decision_function(self.X) 53 | except: 54 | distances = model.predict_proba(self.X) 55 | if len(distances.shape) < 2: 56 | min_margin = abs(distances) 57 | else: 58 | sort_distances = np.sort(distances, 1)[:, -2:] 59 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 60 | rank_ind = np.argsort(min_margin) 61 | rank_ind = [i for i in rank_ind if i not in already_selected] 62 | active_samples = rank_ind[0:N] 63 | return active_samples 64 | 65 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/mixture_of_samplers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
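The `MarginAL` sampler above expects a fitted scikit-learn-style model exposing `decision_function` or `predict_proba`. A hypothetical usage sketch (the toy data and `LogisticRegression` are illustrative assumptions, not part of this repository):
```
import numpy as np
from sklearn.linear_model import LogisticRegression
from src.baselines.sampling_methods.margin_AL import MarginAL

rng = np.random.RandomState(0)
X = rng.randn(100, 8)                       # toy features
y = rng.randint(0, 3, size=100)             # toy labels, 3 classes

already_selected = list(range(30))          # hypothetical seed set
model = LogisticRegression(max_iter=200).fit(X[already_selected], y[already_selected])

sampler = MarginAL(X, y, seed=0)
batch = sampler.select_batch(model=model, already_selected=already_selected, N=5)
print(batch)                                # indices of the 5 smallest-margin points
```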
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Mixture of base sampling strategies 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import copy 24 | 25 | from src.baselines.sampling_methods.sampling_def import SamplingMethod 26 | from src.baselines.sampling_methods.constants import AL_MAPPING, get_base_AL_mapping 27 | 28 | get_base_AL_mapping() 29 | 30 | 31 | class MixtureOfSamplers(SamplingMethod): 32 | """Samples according to mixture of base sampling methods. 33 | 34 | If duplicate points are selected by the mixed strategies when forming the batch 35 | then the remaining slots are divided according to mixture weights and 36 | another partial batch is requested until the batch is full. 37 | """ 38 | def __init__(self, 39 | X, 40 | y, 41 | seed, 42 | mixture={'methods': ('margin', 'uniform'), 43 | 'weight': (0.5, 0.5)}, 44 | samplers=None): 45 | self.X = X 46 | self.y = y 47 | self.name = 'mixture_of_samplers' 48 | self.sampling_methods = mixture['methods'] 49 | self.sampling_weights = dict(zip(mixture['methods'], mixture['weights'])) 50 | self.seed = seed 51 | # A list of initialized samplers is allowed as an input because 52 | # for AL_methods that search over different mixtures, may want mixtures to 53 | # have shared AL_methods so that initialization is only performed once for 54 | # computation intensive methods like HierarchicalClusteringAL and 55 | # states are shared between mixtures. 56 | # If initialized samplers are not provided, initialize them ourselves. 57 | if samplers is None: 58 | self.samplers = {} 59 | self.initialize(self.sampling_methods) 60 | else: 61 | self.samplers = samplers 62 | self.history = [] 63 | 64 | def initialize(self, samplers): 65 | self.samplers = {} 66 | for s in samplers: 67 | self.samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed) 68 | 69 | def select_batch_(self, already_selected, N, **kwargs): 70 | """Returns batch of datapoints selected according to mixture weights. 
71 | 72 | Args: 73 | already_included: index of datapoints already selected 74 | N: batch size 75 | 76 | Returns: 77 | indices of points selected to add using margin active learner 78 | """ 79 | kwargs['already_selected'] = copy.copy(already_selected) 80 | inds = set() 81 | self.selected_by_sampler = {} 82 | for s in self.sampling_methods: 83 | self.selected_by_sampler[s] = [] 84 | effective_N = 0 85 | while len(inds) < N: 86 | effective_N += N - len(inds) 87 | for s in self.sampling_methods: 88 | if len(inds) < N: 89 | batch_size = min(max(int(self.sampling_weights[s] * effective_N), 1), N) 90 | sampler = self.samplers[s] 91 | kwargs['N'] = batch_size 92 | s_inds = sampler.select_batch(**kwargs) 93 | for ind in s_inds: 94 | if ind not in self.selected_by_sampler[s]: 95 | self.selected_by_sampler[s].append(ind) 96 | s_inds = [d for d in s_inds if d not in inds] 97 | s_inds = s_inds[0 : min(len(s_inds), N-len(inds))] 98 | inds.update(s_inds) 99 | self.history.append(copy.deepcopy(self.selected_by_sampler)) 100 | return list(inds) 101 | 102 | def to_dict(self): 103 | output = {} 104 | output['history'] = self.history 105 | output['samplers'] = self.sampling_methods 106 | output['mixture_weights'] = self.sampling_weights 107 | for s in self.samplers: 108 | s_output = self.samplers[s].to_dict() 109 | output[s] = s_output 110 | return output 111 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/represent_cluster_centers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Another informative and diverse sampler that mirrors the algorithm described 16 | in Xu, et. al., Representative Sampling for Text Classification Using 17 | Support Vector Machines, 2003 18 | 19 | Batch is created by clustering points within the margin of the classifier and 20 | choosing points closest to the k centroids. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | from sklearn.cluster import MiniBatchKMeans 28 | import numpy as np 29 | from src.baselines.sampling_methods.sampling_def import SamplingMethod 30 | 31 | 32 | class RepresentativeClusterMeanSampling(SamplingMethod): 33 | """Selects batch based on informative and diverse criteria. 34 | 35 | Returns points within the margin of the classifier that are closest to the 36 | k-means centers of those points. 
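A condensed, self-contained sketch of the same idea — cluster the most uncertain points with `MiniBatchKMeans` and query the point closest to each centroid — using hypothetical toy arrays in place of a real classifier's margins:
```
import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import pairwise_distances

rng = np.random.RandomState(0)
X = rng.randn(200, 5)                        # toy features
margin = np.abs(rng.randn(200))              # hypothetical per-point margins
in_margin = np.argsort(margin)[:50]          # 50 most uncertain points

k = 5                                        # batch size
km = MiniBatchKMeans(n_clusters=k, n_init=3, random_state=0).fit(X[in_margin])
dist = pairwise_distances(X[in_margin], km.cluster_centers_)

# For each centroid, take the in-margin point nearest to it.
batch = [int(in_margin[np.argmin(dist[:, j])]) for j in range(k)]
print(batch)
```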
37 | """ 38 | 39 | def __init__(self, X, y, seed): 40 | self.name = 'cluster_mean' 41 | self.X = X 42 | self.flat_X = self.flatten_X() 43 | self.y = y 44 | self.seed = seed 45 | 46 | def select_batch_(self, model, N, already_selected, **kwargs): 47 | # Probably okay to always use MiniBatchKMeans 48 | # Should standardize data before clustering 49 | # Can cluster on standardized data but train on raw features if desired 50 | try: 51 | distances = model.decision_function(self.X) 52 | except: 53 | distances = model.predict_proba(self.X) 54 | if len(distances.shape) < 2: 55 | min_margin = abs(distances) 56 | else: 57 | sort_distances = np.sort(distances, 1)[:, -2:] 58 | min_margin = sort_distances[:, 1] - sort_distances[:, 0] 59 | rank_ind = np.argsort(min_margin) 60 | rank_ind = [i for i in rank_ind if i not in already_selected] 61 | 62 | distances = abs(model.decision_function(self.X)) 63 | min_margin_by_class = np.min(abs(distances[already_selected]),axis=0) 64 | unlabeled_in_margin = np.array([i for i in range(len(self.y)) 65 | if i not in already_selected and 66 | any(distances[i] 2: 42 | flat_X = np.reshape(self.X, (shape[0],np.product(shape[1:]))) 43 | return flat_X 44 | 45 | 46 | @abc.abstractmethod 47 | def select_batch_(self): 48 | return 49 | 50 | def select_batch(self, **kwargs): 51 | return self.select_batch_(**kwargs) 52 | 53 | def to_dict(self): 54 | return None -------------------------------------------------------------------------------- /src/baselines/sampling_methods/uniform_sampling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Uniform sampling method. 16 | 17 | Samples in batches. 18 | """ 19 | 20 | from __future__ import absolute_import 21 | from __future__ import division 22 | from __future__ import print_function 23 | 24 | import numpy as np 25 | 26 | from src.baselines.sampling_methods.sampling_def import SamplingMethod 27 | 28 | 29 | class UniformSampling(SamplingMethod): 30 | 31 | def __init__(self, X, y, seed): 32 | self.X = X 33 | self.y = y 34 | self.name = 'uniform' 35 | np.random.seed(seed) 36 | 37 | def select_batch_(self, already_selected, N, **kwargs): 38 | """Returns batch of randomly sampled datapoints. 39 | 40 | Assumes that data has already been shuffled. 41 | 42 | Args: 43 | already_selected: index of datapoints already selected 44 | N: batch size 45 | 46 | Returns: 47 | indices of points selected to label 48 | """ 49 | 50 | # This is uniform given the remaining pool but biased wrt the entire pool. 51 | sample = [i for i in range(self.X.shape[0]) if i not in already_selected] 52 | return sample[0:N] 53 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/utils/tree.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Node and Tree class to support hierarchical clustering AL method. 16 | 17 | Assumed to be binary tree. 18 | 19 | Node class is used to represent each node in a hierarchical clustering. 20 | Each node has certain properties that are used in the AL method. 21 | 22 | Tree class is used to traverse a hierarchical clustering. 23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import copy 30 | 31 | 32 | class Node(object): 33 | """Node class for hierarchical clustering. 34 | 35 | Initialized with name and left right children. 36 | """ 37 | 38 | def __init__(self, name, left=None, right=None): 39 | self.name = name 40 | self.left = left 41 | self.right = right 42 | self.is_leaf = left is None and right is None 43 | self.parent = None 44 | # Fields for hierarchical clustering AL 45 | self.score = 1.0 46 | self.split = False 47 | self.best_label = None 48 | self.weight = None 49 | 50 | def set_parent(self, parent): 51 | self.parent = parent 52 | 53 | 54 | class Tree(object): 55 | """Tree object for traversing a binary tree. 56 | 57 | Most methods apply to trees in general with the exception of get_pruning 58 | which is specific to the hierarchical clustering AL method. 59 | """ 60 | 61 | def __init__(self, root, node_dict): 62 | """Initializes tree and creates all nodes in node_dict. 63 | 64 | Args: 65 | root: id of the root node 66 | node_dict: dictionary with node_id as keys and entries indicating 67 | left and right child of node respectively. 
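As a concrete illustration (mirroring the setup used in `tree_test.py` below), `node_dict` maps every internal node id to its `(left, right)` children and every leaf to `(None, None)`:
```
from src.baselines.sampling_methods.utils import tree

# A balanced 7-node binary tree: 1 is the root, 4-7 are leaves.
node_dict = {1: (2, 3), 2: (4, 5), 3: (6, 7),
             4: (None, None), 5: (None, None),
             6: (None, None), 7: (None, None)}
t = tree.Tree(1, node_dict)
t.create_child_leaves_mapping([4, 5, 6, 7])

print(t.get_ancestor(5))       # [2, 1]
print(t.get_child_leaves(3))   # [6, 7]
print(t.get_node(3).weight)    # 0.5, i.e. 2 of the 4 leaves
```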
68 | """ 69 | self.node_dict = node_dict 70 | self.root = self.make_tree(root) 71 | self.nodes = {} 72 | self.leaves_mapping = {} 73 | self.fill_parents() 74 | self.n_leaves = None 75 | 76 | def print_tree(self, node, max_depth): 77 | """Helper function to print out tree for debugging.""" 78 | node_list = [node] 79 | output = "" 80 | level = 0 81 | while level < max_depth and len(node_list): 82 | children = set() 83 | for n in node_list: 84 | node = self.get_node(n) 85 | output += ("\t"*level+"node %d: score %.2f, weight %.2f" % 86 | (node.name, node.score, node.weight)+"\n") 87 | if node.left: 88 | children.add(node.left.name) 89 | if node.right: 90 | children.add(node.right.name) 91 | level += 1 92 | node_list = children 93 | return print(output) 94 | 95 | def make_tree(self, node_id): 96 | if node_id is not None: 97 | return Node(node_id, 98 | self.make_tree(self.node_dict[node_id][0]), 99 | self.make_tree(self.node_dict[node_id][1])) 100 | 101 | def fill_parents(self): 102 | # Setting parent and storing nodes in dict for fast access 103 | def rec(pointer, parent): 104 | if pointer is not None: 105 | self.nodes[pointer.name] = pointer 106 | pointer.set_parent(parent) 107 | rec(pointer.left, pointer) 108 | rec(pointer.right, pointer) 109 | rec(self.root, None) 110 | 111 | def get_node(self, node_id): 112 | return self.nodes[node_id] 113 | 114 | def get_ancestor(self, node): 115 | ancestors = [] 116 | if isinstance(node, int): 117 | node = self.get_node(node) 118 | while node.name != self.root.name: 119 | node = node.parent 120 | ancestors.append(node.name) 121 | return ancestors 122 | 123 | def fill_weights(self): 124 | for v in self.node_dict: 125 | node = self.get_node(v) 126 | node.weight = len(self.leaves_mapping[v]) / (1.0 * self.n_leaves) 127 | 128 | def create_child_leaves_mapping(self, leaves): 129 | """DP for creating child leaves mapping. 130 | 131 | Storing in dict to save recompute. 132 | """ 133 | self.n_leaves = len(leaves) 134 | for v in leaves: 135 | self.leaves_mapping[v] = [v] 136 | node_list = set([self.get_node(v).parent for v in leaves]) 137 | while node_list: 138 | to_fill = copy.copy(node_list) 139 | for v in node_list: 140 | if (v.left.name in self.leaves_mapping 141 | and v.right.name in self.leaves_mapping): 142 | to_fill.remove(v) 143 | self.leaves_mapping[v.name] = (self.leaves_mapping[v.left.name] + 144 | self.leaves_mapping[v.right.name]) 145 | if v.parent is not None: 146 | to_fill.add(v.parent) 147 | node_list = to_fill 148 | self.fill_weights() 149 | 150 | def get_child_leaves(self, node): 151 | return self.leaves_mapping[node] 152 | 153 | def get_pruning(self, node): 154 | if node.split: 155 | return self.get_pruning(node.left) + self.get_pruning(node.right) 156 | else: 157 | return [node.name] 158 | 159 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/utils/tree_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sampling_methods.utils.tree.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import unittest 22 | from sampling_methods.utils import tree 23 | 24 | 25 | class TreeTest(unittest.TestCase): 26 | 27 | def setUp(self): 28 | node_dict = { 29 | 1: (2, 3), 30 | 2: (4, 5), 31 | 3: (6, 7), 32 | 4: [None, None], 33 | 5: [None, None], 34 | 6: [None, None], 35 | 7: [None, None] 36 | } 37 | self.tree = tree.Tree(1, node_dict) 38 | self.tree.create_child_leaves_mapping([4, 5, 6, 7]) 39 | node = self.tree.get_node(1) 40 | node.split = True 41 | node = self.tree.get_node(2) 42 | node.split = True 43 | 44 | def assertNode(self, node, name, left, right): 45 | self.assertEqual(node.name, name) 46 | self.assertEqual(node.left.name, left) 47 | self.assertEqual(node.right.name, right) 48 | 49 | def testTreeRootSetCorrectly(self): 50 | self.assertNode(self.tree.root, 1, 2, 3) 51 | 52 | def testGetNode(self): 53 | node = self.tree.get_node(1) 54 | assert isinstance(node, tree.Node) 55 | self.assertEqual(node.name, 1) 56 | 57 | def testFillParent(self): 58 | node = self.tree.get_node(3) 59 | self.assertEqual(node.parent.name, 1) 60 | 61 | def testGetAncestors(self): 62 | ancestors = self.tree.get_ancestor(5) 63 | self.assertTrue(all([a in ancestors for a in [1, 2]])) 64 | 65 | def testChildLeaves(self): 66 | leaves = self.tree.get_child_leaves(3) 67 | self.assertTrue(all([c in leaves for c in [6, 7]])) 68 | 69 | def testFillWeights(self): 70 | node = self.tree.get_node(3) 71 | self.assertEqual(node.weight, 0.5) 72 | 73 | def testGetPruning(self): 74 | node = self.tree.get_node(1) 75 | pruning = self.tree.get_pruning(node) 76 | self.assertTrue(all([n in pruning for n in [3, 4, 5]])) 77 | 78 | if __name__ == '__main__': 79 | unittest.main() 80 | -------------------------------------------------------------------------------- /src/baselines/sampling_methods/wrapper_sampler_def.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 Google Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Abstract class for wrapper sampling methods that call base sampling methods. 16 | 17 | Provides interface to sampling methods that allow same signature 18 | for select_batch. Each subclass implements select_batch_ with the desired 19 | signature for readability. 
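For reference, the `mixtures` argument consumed by `initialize_samplers` below is a list of dicts of the same shape as those registered in `constants.py`, e.g.:
```
# Two arms, each a weighted mixture of the registered 'margin' and 'uniform' samplers.
mixtures = [
    {'methods': ['margin', 'uniform'], 'weights': [1, 0]},
    {'methods': ['margin', 'uniform'], 'weights': [0.5, 0.5]},
]
```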
20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import abc 27 | 28 | from src.baselines.sampling_methods.constants import AL_MAPPING 29 | from src.baselines.sampling_methods.constants import get_all_possible_arms 30 | from src.baselines.sampling_methods.sampling_def import SamplingMethod 31 | 32 | get_all_possible_arms() 33 | 34 | 35 | class WrapperSamplingMethod(SamplingMethod): 36 | __metaclass__ = abc.ABCMeta 37 | 38 | def initialize_samplers(self, mixtures): 39 | methods = [] 40 | for m in mixtures: 41 | methods += m['methods'] 42 | methods = set(methods) 43 | self.base_samplers = {} 44 | for s in methods: 45 | self.base_samplers[s] = AL_MAPPING[s](self.X, self.y, self.seed) 46 | self.samplers = [] 47 | for m in mixtures: 48 | self.samplers.append( 49 | AL_MAPPING['mixture_of_samplers'](self.X, self.y, self.seed, m, 50 | self.base_samplers)) 51 | -------------------------------------------------------------------------------- /src/datasetcollecting/biggraph.py: -------------------------------------------------------------------------------- 1 | from ogb.nodeproppred.dataset_pyg import PygNodePropPredDataset 2 | import argparse 3 | import torch 4 | from src.utils.dataloader import GraphLoader 5 | import pickle as pkl 6 | def selectclass(labels): 7 | thres=4 8 | numof1or0 = labels.sum(dim=0)/labels.size(0) 9 | idxinclude = torch.where((numof1or0>0.75+(numof1or0<0.25))==0)[0] 10 | # print(len(idxinclude)) 11 | labels = labels[:,idxinclude] 12 | equalsset = set() 13 | useset =set() 14 | for i in range(len(idxinclude)): 15 | if i in equalsset: 16 | continue 17 | base = labels[:,i:i+1] 18 | rate = (labels*base).sum(dim=0)/labels.sum(dim=0) # labels 1 is inside base's 1 19 | rate2 = ((1-labels)*(1-base)).sum(dim=0)/(1-labels).sum(dim=0) # labels 0 is inside base's 0, 20 | #if many large rate, base doesn't have many 1 (at hierachy's bottom) 21 | largeloc = torch.where(rate>0.9) [0].numpy().tolist() 22 | largeloc2 = torch.where(rate2>0.9) [0].numpy().tolist() # if both high, two classes are equal 23 | 24 | 25 | 26 | 27 | 28 | # print("i:{} {}; {}".format(i,largeloc,largeloc2)) 29 | if len(largeloc)jip",input,self.weight) 35 | support = torch.reshape(support,[support.size(0),-1]) 36 | # support = torch.mm(input, self.weight) 37 | if self.bias is not None: 38 | support = support + self.bias 39 | output = torch.spmm(adj, support) 40 | output = torch.reshape(output,[output.size(0),self.batchsize,-1]) 41 | return output 42 | 43 | def __repr__(self): 44 | return self.__class__.__name__ + ' (' \ 45 | + str(self.in_features) + ' -> ' \ 46 | + str(self.out_features) + ')' 47 | 48 | 49 | class GCN(nn.Module): 50 | 51 | def __init__(self, nfeat, nhid, nclass, batchsize, dropout=0.5, softmax=True, bias=False): 52 | super(GCN, self).__init__() 53 | self.dropout_rate = dropout 54 | self.softmax = softmax 55 | self.batchsize = batchsize 56 | self.gc1 = GraphConvolution(nfeat, nhid, batchsize, bias=bias) 57 | self.gc2 = GraphConvolution(nhid, nclass, batchsize, bias=bias) 58 | self.dropout = nn.Dropout(p=dropout,inplace=False) 59 | 60 | def forward(self, x, adj): 61 | x = x.expand([self.batchsize]+list(x.size())).transpose(0,1) 62 | x = self.dropout(x) 63 | x = F.relu(self.gc1(x, adj)) 64 | x = self.dropout(x) 65 | x = self.gc2(x, adj) 66 | x = x.transpose(0,1).transpose(1,2) 67 | if self.softmax: 68 | return F.log_softmax(x, dim=1) 69 | else: 70 | return x.squeeze() 71 | 72 | def reset(self): 73 | 
self.gc1.reset_parameters() 74 | self.gc2.reset_parameters() -------------------------------------------------------------------------------- /src/utils/common.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | 4 | 5 | def ConfigRootLogger(name='', version=None, level="info"): 6 | logger = logging.getLogger(name) 7 | if level == "debug": 8 | logger.setLevel(logging.DEBUG) 9 | elif level == "info": 10 | logger.setLevel(logging.INFO) 11 | elif level == "warning": 12 | logger.setLevel(logging.WARNING) 13 | format = logging.Formatter("T[%(asctime)s]-V[{}]-POS[%(module)s." 14 | "%(funcName)s(line %(lineno)s)]-PID[%(process)d] %(levelname)s" 15 | ">>> %(message)s ".format(version),"%H:%M:%S") 16 | stdout_format = logging.Formatter("T[%(asctime)s]-V[{}]-POS[%(module)s." 17 | "%(funcName)s(line %(lineno)s)]-PID[%(process)d] %(levelname)s" 18 | ">>> %(message)s ".format(version),"%H:%M:%S") 19 | 20 | file_handler = logging.FileHandler("log.txt") 21 | file_handler.setFormatter(format) 22 | stream_handler = logging.StreamHandler(sys.stdout) 23 | stream_handler.setFormatter(stdout_format) 24 | logger.addHandler(file_handler) 25 | logger.addHandler(stream_handler) 26 | return logger 27 | 28 | 29 | ConfigRootLogger("main","1",level="info") 30 | logger = logging.getLogger("main") 31 | 32 | 33 | 34 | def logargs(args,tablename="",width=120 ): 35 | length = 1 36 | L=[] 37 | l= "|" 38 | for id,arg in enumerate(vars(args)): 39 | name,value = arg, str(getattr(args, arg)) 40 | nv = name+":"+value 41 | if length +(len(nv)+2)>width: 42 | L.append(l) 43 | l = "|" 44 | length = 1 45 | l += nv + " |" 46 | length += (len(nv)+2) 47 | if id+1 == len(vars(args)): 48 | L.append(l) 49 | printstr = niceprint(L) 50 | logger.info("{}:\n{}".format(tablename,printstr)) 51 | 52 | def niceprint(L,mark="-"): 53 | printstr = [] 54 | printstr.append("-"*len(L[0])) 55 | printstr.append(L[0]) 56 | for id in range(1,len(L)): 57 | printstr.append("-"*max(len(L[id-1]),len(L[id]))) 58 | printstr.append(L[id]) 59 | printstr.append("-"*len(L[-1])) 60 | printstr = "\n".join(printstr) 61 | return printstr 62 | 63 | def logdicts(dic,tablename="",width=120 ): 64 | length = 1 65 | L=[] 66 | l= "|" 67 | tup = dic.items() 68 | for id,arg in enumerate(tup): 69 | name,value = arg 70 | nv = name+":"+str(value) 71 | if length +(len(nv)+2)>width: 72 | L.append(l) 73 | l = "|" 74 | length = 1 75 | l += nv + " |" 76 | length += (len(nv)+2) 77 | if id+1 == len(tup): 78 | L.append(l) 79 | printstr = niceprint(L) 80 | 81 | logger.info("{}:\n{}".format(tablename,printstr)) 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /src/utils/const.py: -------------------------------------------------------------------------------- 1 | MIN_EPSILON = 1e-10 2 | MAX_EXP = 70 3 | -------------------------------------------------------------------------------- /src/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import networkx as nx 3 | import pickle as pkl 4 | import torch 5 | import numpy as np 6 | from collections import OrderedDict 7 | import time 8 | from src.utils.common import * 9 | from src.utils.utils import * 10 | 11 | 12 | class GraphLoader(object): 13 | 14 | def __init__(self,name,root = "./data",undirected=True, hasX=True,hasY=True,header=True,sparse=True,multigraphindex=None,args=None): 15 | 16 | self.name = name 17 | self.undirected = undirected 18 | 
self.hasX = hasX 19 | self.hasY = hasY 20 | self.header = header 21 | self.sparse = sparse 22 | self.dirname = os.path.join(root,name) 23 | if name == "reddit1401": 24 | self.prefix = os.path.join(root, name,multigraphindex,multigraphindex) 25 | else: 26 | self.prefix = os.path.join(root,name,name) 27 | self._load() 28 | self._registerStat() 29 | self.printStat() 30 | 31 | 32 | def _loadConfig(self): 33 | file_name = os.path.join(self.dirname,"bestconfig.txt") 34 | f = open(file_name,'r') 35 | L = f.readlines() 36 | L = [x.strip().split() for x in L] 37 | self.bestconfig = {x[0]:x[1] for x in L if len(x)!=0} 38 | 39 | 40 | def _loadGraph(self, header = True): 41 | """ 42 | load file in form: 43 | -------------------- 44 | NUM_Of_NODE\n 45 | v1 v2\n 46 | v3 v4\n 47 | -------------------- 48 | """ 49 | file_name = self.prefix+".edgelist" 50 | if not header: 51 | logger.warning("You are reading an edgelist with no explicit number of nodes") 52 | if self.undirected: 53 | G = nx.Graph() 54 | else: 55 | G = nx.DiGraph() 56 | with open(file_name) as f: 57 | L = f.readlines() 58 | if header: 59 | num_node = int(L[0].strip()) 60 | L = L[1:] 61 | edge_list = [[int(x) for x in e.strip().split()] for e in L] 62 | nodeset = set([x for e in edge_list for x in e]) 63 | # if header: 64 | # assert min(nodeset) == 0 and max(nodeset) + 1 == num_node, "input standard violated {} ,{}".format(num_node,max(nodeset)) 65 | 66 | if header: 67 | G.add_nodes_from([x for x in range(num_node)]) 68 | else: 69 | G.add_nodes_from([x for x in range(max(nodeset)+1)]) 70 | G.add_edges_from(edge_list) 71 | self.G = G 72 | 73 | def _loadX(self): 74 | self.X = pkl.load(open(self.prefix + ".x.pkl", 'rb')) 75 | self.X = self.X.astype(np.float32) 76 | if self.name in ["coauthor_phy","corafull"]: 77 | self.X = self.X[:,:2000] # the coauthor_phy's feature is too large to fit in the memory. 
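    # Data-file convention, as consumed by _loadX/_loadY:
    # <name>.x.pkl is a dense numpy feature matrix of shape [nnode, nfeat];
    # <name>.y.pkl holds the labels (an integer class vector for the
    # single-label graphs shipped under ./data).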
78 | 79 | def _loadY(self): 80 | self.Y = pkl.load(open(self.prefix+".y.pkl",'rb'))#.astype(np.float32) 81 | 82 | def _getAdj(self): 83 | self.adj = nx.adjacency_matrix(self.G).astype(np.float32) 84 | 85 | def _toTensor(self,device=None): 86 | if device is None: 87 | if self.sparse: 88 | self.adj = sparse_mx_to_torch_sparse_tensor(self.adj).cuda() 89 | self.normadj = sparse_mx_to_torch_sparse_tensor(self.normadj).cuda() 90 | else: 91 | self.adj = torch.from_numpy(self.adj).cuda() 92 | self.normadj = torch.from_numpy(self.normadj).cuda() 93 | self.X = torch.from_numpy(self.X).cuda() 94 | self.Y = torch.from_numpy(self.Y).cuda() 95 | 96 | def _load(self): 97 | self._loadGraph(header=self.header) 98 | self._loadConfig() 99 | if self.hasX: 100 | self._loadX() 101 | if self.hasY: 102 | self._loadY() 103 | self._getAdj() 104 | 105 | def _registerStat(self): 106 | L=OrderedDict() 107 | L["name"] = self.name 108 | L["nnode"] = self.G.number_of_nodes() 109 | L["nedge"] = self.G.number_of_edges() 110 | L["nfeat"] = self.X.shape[1] 111 | L["nclass"] = self.Y.max() + 1 112 | L["sparse"] = self.sparse 113 | L["multilabel"] = False 114 | L.update(self.bestconfig) 115 | self.stat = L 116 | 117 | def process(self): 118 | if int(self.bestconfig['feature_normalize']): 119 | self.X = column_normalize(preprocess_features(self.X)) # take some time 120 | 121 | # self.X = self.X - self.X.min(axis=0) 122 | # print(np.where(self.X)) 123 | # exit() 124 | # print(self.X[:3,:].tolist()) 125 | self.normadj = preprocess_adj(self.adj) 126 | if not self.sparse: 127 | self.adj = self.adj.todense() 128 | self.normadj = self.normadj.todense() 129 | self._toTensor() 130 | 131 | self.normdeg = self._getNormDeg() 132 | 133 | 134 | def printStat(self): 135 | logdicts(self.stat,tablename="dataset stat") 136 | 137 | def _getNormDeg(self): 138 | self.deg = torch.sparse.sum(self.adj, dim=1).to_dense() 139 | normdeg =self.deg/ self.deg.max() 140 | return normdeg -------------------------------------------------------------------------------- /src/utils/env.py: -------------------------------------------------------------------------------- 1 | import torch.multiprocessing as mp 2 | import time 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from src.utils.player import Player 8 | from src.utils.utils import * 9 | 10 | 11 | def logprob2Prob(logprobs,multilabel=False): 12 | if multilabel: 13 | probs = torch.sigmoid(logprobs) 14 | else: 15 | probs = F.softmax(logprobs, dim=2) 16 | return probs 17 | 18 | def normalizeEntropy(entro,classnum): #this is needed because different number of classes will have different entropy 19 | maxentro = np.log(float(classnum)) 20 | entro = entro/maxentro 21 | return entro 22 | 23 | def prob2Logprob(probs,multilabel=False): 24 | if multilabel: 25 | raise NotImplementedError("multilabel for prob2Logprob is not implemented") 26 | else: 27 | logprobs = torch.log(probs) 28 | return logprobs 29 | 30 | def perc(input): 31 | # the biger valueis the biger result is 32 | numnode = input.size(-2) 33 | res = torch.argsort(torch.argsort(input, dim=-2), dim=-2) / float(numnode) 34 | return res 35 | 36 | def degprocess(deg): 37 | # deg = torch.log(1+deg) 38 | #return deg/20. 39 | return torch.clamp_max(deg / 20., 1.) 
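# localdiversity (below) produces the two "local diversity" state features:
# for every edge (head, tail) it computes KL(tail || head) and KL(head || tail)
# between the endpoints' predicted class distributions, averages each quantity
# over a node's incident edges (division by the degree), and finally
# max-normalizes both features across nodes, per batch element.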
40 | 41 | def localdiversity(probs,adj,deg): 42 | indices = adj.coalesce().indices() 43 | N =adj.size()[0] 44 | classnum = probs.size()[-1] 45 | maxentro = np.log(float(classnum)) 46 | edgeprobs = probs[:,indices.transpose(0,1),:] 47 | headprobs = edgeprobs[:,:,0,:] 48 | tailprobs = edgeprobs[:,:,1,:] 49 | kl_ht = (torch.sum(torch.log(torch.clamp_min(tailprobs,1e-10))*tailprobs,dim=-1) - \ 50 | torch.sum(torch.log(torch.clamp_min(headprobs,1e-10))*tailprobs,dim=-1)).transpose(0,1) 51 | kl_th = (torch.sum(torch.log(torch.clamp_min(headprobs,1e-10))*headprobs,dim=-1) - \ 52 | torch.sum(torch.log(torch.clamp_min(tailprobs,1e-10))*headprobs,dim=-1)).transpose(0,1) 53 | sparse_output_kl_ht = torch.sparse.FloatTensor(indices,kl_ht,size=torch.Size([N,N,kl_ht.size(-1)])) 54 | sparse_output_kl_th = torch.sparse.FloatTensor(indices,kl_th,size=torch.Size([N,N,kl_th.size(-1)])) 55 | sum_kl_ht = torch.sparse.sum(sparse_output_kl_ht,dim=1).to_dense().transpose(0,1) 56 | sum_kl_th = torch.sparse.sum(sparse_output_kl_th,dim=1).to_dense().transpose(0,1) 57 | mean_kl_ht = sum_kl_ht/(deg+1e-10) 58 | mean_kl_th = sum_kl_th/(deg+1e-10) 59 | # normalize 60 | mean_kl_ht = mean_kl_ht / mean_kl_ht.max(dim=1, keepdim=True).values 61 | mean_kl_th = mean_kl_th / mean_kl_th.max(dim=1, keepdim=True).values 62 | return mean_kl_ht,mean_kl_th 63 | 64 | 65 | class Env(object): 66 | ## an environment for multiple players testing the policy at the same time 67 | def __init__(self,players,args): 68 | ''' 69 | players: a list containing main player (many task) (or only one task 70 | ''' 71 | self.players = players 72 | self.args = args 73 | self.nplayer = len(self.players) 74 | self.graphs = [p.G for p in self.players] 75 | featdim =-1 76 | self.statedim = self.getState(0).size(featdim) 77 | 78 | 79 | def step(self,actions,playerid=0): 80 | p = self.players[playerid] 81 | p.query(actions) 82 | p.trainOnce() 83 | reward = p.validation(test=False, rerun=False) 84 | return reward 85 | 86 | 87 | def getState(self,playerid=0): 88 | p = self.players[playerid] 89 | output = logprob2Prob(p.allnodes_output.transpose(1,2),multilabel=p.G.stat["multilabel"]) 90 | state = self.makeState(output,p.trainmask,p.G.deg,playerid) 91 | return state 92 | 93 | 94 | def reset(self,playerid=0): 95 | self.players[playerid].reset(fix_test=False) 96 | 97 | 98 | def makeState(self, probs, selected, deg,playerid, adj=None, multilabel=False ): 99 | entro = entropy(probs, multilabel=multilabel) 100 | entro = normalizeEntropy(entro,probs.size(-1)) ## in order to transfer 101 | deg = degprocess(deg.expand([probs.size(0)]+list(deg.size()))) 102 | 103 | features = [] 104 | if self.args.use_entropy: 105 | features.append(entro) 106 | if self.args.use_degree: 107 | features.append(deg) 108 | if self.args.use_local_diversity: 109 | mean_kl_ht,mean_kl_th = localdiversity(probs,self.players[playerid].G.adj,self.players[playerid].G.deg) 110 | features.extend([mean_kl_ht, mean_kl_th]) 111 | if self.args.use_select: 112 | features.append(selected) 113 | state = torch.stack(features, dim=-1) 114 | 115 | return state -------------------------------------------------------------------------------- /src/utils/player.py: -------------------------------------------------------------------------------- 1 | # individual player who takes the action and evaluates the effect 2 | import torch 3 | import torch.nn as nn 4 | import time 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | from src.utils.common import * 9 | from src.utils.classificationnet import GCN 10 | 
from src.utils.utils import * 11 | 12 | 13 | class Player(nn.Module): 14 | 15 | def __init__(self,G,args,rank=0): 16 | 17 | super(Player,self).__init__() 18 | self.G = G 19 | self.args = args 20 | self.rank = rank 21 | self.batchsize = args.batchsize 22 | 23 | if self.G.stat['multilabel']: 24 | self.net = GCN(self.G.stat['nfeat'],args.nhid,self.G.stat['nclass'],args.batchsize,args.dropout,False,bias=True).cuda() 25 | self.loss_func=F.binary_cross_entropy_with_logits 26 | else: 27 | self.net = GCN(self.G.stat['nfeat'],args.nhid,self.G.stat['nclass'],args.batchsize,args.dropout,True).cuda() 28 | self.loss_func=F.nll_loss 29 | 30 | self.fulllabel = self.G.Y.expand([self.batchsize]+list(self.G.Y.size())) 31 | 32 | self.reset(fix_test=False) #initialize 33 | self.count = 0 34 | 35 | 36 | def makeValTestMask(self, fix_test=True): 37 | #if fix_test: 38 | # assert False 39 | valmask=torch.zeros((self.batchsize,self.G.stat['nnode'])).to(torch.float).cuda() 40 | testmask = torch.zeros((self.batchsize,self.G.stat['nnode'])).to(torch.float).cuda() 41 | valid = [] 42 | testid = [] 43 | vallabel = [] 44 | testlabel = [] 45 | for i in range(self.batchsize): 46 | base = np.array([x for x in range(self.G.stat["nnode"])]) 47 | if fix_test: 48 | testid_=[x for x in range(self.G.stat["nnode"] - self.args.ntest,self.G.stat["nnode"])] 49 | else: 50 | testid_ = np.sort(np.random.choice(base, size=self.args.ntest, replace=False)).tolist() 51 | testmask[i, testid_] = 1. 52 | testid.append(testid_) 53 | testlabel.append(self.G.Y[testid_]) 54 | s = set(testid_) 55 | base= [x for x in range(self.G.stat["nnode"]) if x not in s ] 56 | valid_ = np.sort(np.random.choice(base, size=self.args.nval, replace=False)).tolist() 57 | valmask[i,valid_]=1. 58 | valid.append(valid_) 59 | vallabel.append(self.G.Y[valid_]) 60 | self.valid = torch.tensor(valid).cuda() 61 | self.testid = torch.tensor(testid).cuda() 62 | self.vallabel = torch.stack(vallabel).cuda() 63 | self.testlabel = torch.stack(testlabel).cuda() 64 | self.valmask=valmask 65 | self.testmask=testmask 66 | 67 | 68 | def lossWeighting(self,epoch): 69 | return min(epoch,10.)/10. 70 | 71 | 72 | def query(self,nodes): 73 | self.trainmask[[x for x in range(self.batchsize)],nodes] = 1. 
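    # getPool (below) returns the current candidate pool: the nodes that are
    # neither labelled for training nor reserved for validation/test. With
    # reduce=True it yields one numpy array of node ids per batch element,
    # otherwise the raw (row, col) index tensors.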
74 | 75 | 76 | def getPool(self,reduce=True): 77 | mask = self.testmask+self.valmask+self.trainmask 78 | row,col = torch.where(mask<0.1) 79 | if reduce: 80 | row, col = row.cpu().numpy(),col.cpu().numpy() 81 | pool = [] 82 | for i in range(self.batchsize): 83 | pool.append(col[row==i]) 84 | return pool 85 | else: 86 | return row,col 87 | 88 | 89 | def trainOnce(self,log=False): 90 | nlabeled = torch.sum(self.trainmask)/self.batchsize 91 | self.net.train() 92 | self.opt.zero_grad() 93 | output = self.net(self.G.X,self.G.normadj) 94 | # print(output.size()) 95 | # exit() 96 | if self.G.stat["multilabel"]: 97 | output_trans = output.transpose(1,2) 98 | # print(output_trans[:,-20:,:]) 99 | 100 | losses = self.loss_func(output_trans,self.fulllabel,reduction="none").sum(dim=2) 101 | else: 102 | losses = self.loss_func(output,self.fulllabel,reduction="none") 103 | loss = torch.sum(losses*self.trainmask)/nlabeled*self.lossWeighting(float(nlabeled.cpu())) 104 | loss.backward() 105 | self.opt.step() 106 | #if log: 107 | #logger.info("nnodes selected:{},loss:{}".format(nlabeled,loss.detach().cpu().numpy())) 108 | self.allnodes_output=output.detach() 109 | return output 110 | 111 | 112 | def validation(self,test=False,rerun=True): 113 | if test: 114 | mask = self.testmask 115 | labels= self.testlabel 116 | index = self.testid 117 | else: 118 | mask = self.valmask 119 | labels = self.vallabel 120 | index = self.valid 121 | if rerun: 122 | self.net.eval() 123 | output = self.net(self.G.X,self.G.normadj) 124 | else: 125 | output = self.allnodes_output 126 | if self.G.stat["multilabel"]: 127 | # logger.info("output of classification {}".format(output)) 128 | output_trans = output.transpose(1,2) 129 | losses_val = self.loss_func(output_trans,self.fulllabel,reduction="none").mean(dim=2) 130 | else: 131 | losses_val = self.loss_func(output,self.fulllabel,reduction="none") 132 | loss_val = torch.sum(losses_val*mask,dim =1,keepdim=True)/torch.sum(mask,dim =1,keepdim=True) 133 | acc= [] 134 | for i in range(self.batchsize): 135 | pred_val = (output[i][:,index[i]]).transpose(0,1) 136 | # logger.info("pred_val {}".format(pred_val)) 137 | acc.append(accuracy(pred_val,labels[i])) 138 | 139 | # logger.info("validation acc {}".format(acc)) 140 | return list(zip(*acc)) 141 | 142 | 143 | def trainRemain(self): 144 | for i in range(self.args.remain_epoch): 145 | self.trainOnce() 146 | 147 | 148 | def reset(self,resplit=True,fix_test=True): 149 | if resplit: 150 | self.makeValTestMask(fix_test=fix_test) 151 | self.trainmask = torch.zeros((self.batchsize,self.G.stat['nnode'])).to(torch.float).cuda() 152 | self.net.reset() 153 | self.opt = torch.optim.Adam(self.net.parameters(),lr=self.args.lr,weight_decay=5e-4) 154 | self.allnodes_output = self.net(self.G.X,self.G.normadj).detach() 155 | 156 | 157 | 158 | import argparse 159 | def parse_args(): 160 | parser = argparse.ArgumentParser() 161 | parser.add_argument("--dropout",type=float,default=0.5) 162 | parser.add_argument("--ntest",type=int,default=1000) 163 | parser.add_argument("--nval",type=int,default=500) 164 | parser.add_argument("--nhid", type=int, default=64) 165 | parser.add_argument("--lr", type=float, default=3e-2) 166 | parser.add_argument("--batchsize", type=int, default=2) 167 | parser.add_argument("--budget", type=int, default=20, help="budget per class") 168 | parser.add_argument("--dataset", type=str, default="cora") 169 | parser.add_argument("--remain_epoch", type=int, default=35, help="continues training $remain_epoch") 170 | 171 | args = 
parser.parse_args() 172 | return args 173 | 174 | if __name__=="__main__": 175 | from src.utils.dataloader import GraphLoader 176 | args = parse_args() 177 | G = GraphLoader("cora") 178 | G.process() 179 | p = Player(G,args) 180 | p.query([2,3]) 181 | p.query([4,6]) 182 | 183 | p.trainOnce() 184 | 185 | print(p.trainmask[:,:10]) 186 | print(p.allnodes_output[0].size()) -------------------------------------------------------------------------------- /src/utils/query.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.distributions import Categorical 5 | import torch.nn.functional as F 6 | 7 | 8 | def choose(p,pool): 9 | return np.random.choice(pool,size=1,replace=False,p=p) 10 | 11 | 12 | class RandomQuery(object): 13 | def __init__(self): 14 | pass 15 | 16 | def __call__(self,pool): 17 | ret = [] 18 | for row in pool: 19 | p = np.ones(len(row)) 20 | p/=p.sum() 21 | ret .append(choose(p,row)) 22 | ret = np.concatenate(ret) 23 | return ret 24 | 25 | 26 | class ProbQuery(object): 27 | def __init__(self,type="soft"): 28 | self.type = type 29 | 30 | def __call__(self,probs,pool): 31 | if self.type == "soft": 32 | return self.softquery(probs,pool) 33 | elif self.type == "hard": 34 | return self.hardquery(probs,pool) 35 | 36 | def softquery(self,logits,pool): 37 | batchsize = logits.size(0) 38 | valid_logits = logits[pool].reshape(batchsize,-1) 39 | max_logits = torch.max(valid_logits,dim=1,keepdim=True)[0].detach() 40 | valid_logits = valid_logits - max_logits #torch.clamp(valid_logits,max = MAX_EXP) 41 | valid_probs = F.softmax(valid_logits, dim=1) 42 | pool = pool[1].reshape(batchsize, -1) 43 | assert pool.size() == valid_probs.size() 44 | m = Categorical(valid_probs) 45 | action_inpool = m.sample() 46 | action = pool[[x for x in range(batchsize)], action_inpool] 47 | return action 48 | 49 | def hardquery(self,logits,pool): 50 | batchsize = logits.size(0) 51 | valid_logits = logits[pool].reshape(batchsize,-1) 52 | max_logits = torch.max(valid_logits,dim=1,keepdim=True)[0].detach() 53 | valid_logits = valid_logits - max_logits #torch.clamp(valid_logits,max = MAX_EXP) 54 | valid_probs = F.softmax(valid_logits, dim=1) 55 | 56 | pool = pool[1].reshape(batchsize, -1) 57 | action_inpool = torch.argmax(valid_probs,dim=1) 58 | action = pool[[x for x in range(batchsize)], action_inpool] 59 | return action 60 | 61 | 62 | def unitTestProbQuery(): 63 | probs = F.softmax(torch.randn(4,7)*3,dim=1) 64 | mask = torch.zeros_like(probs) 65 | mask[:,1] = 1 66 | pool = torch.where(mask==0) 67 | print(probs,pool) 68 | 69 | q = ProbQuery(type = "soft") 70 | action = q(probs,pool) 71 | print(action) 72 | 73 | q.type = "hard" 74 | action = q(probs,pool) 75 | print(action) 76 | 77 | 78 | def selectActions(self,logits,pool): 79 | valid_logits = logits[pool].reshape(self.args.batchsize,-1) 80 | max_logits = torch.max(valid_logits,dim=1,keepdim=True)[0].detach() 81 | if self.globel_number %10==0: 82 | logger.info(max_logits) 83 | valid_logits = valid_logits - max_logits #torch.clamp(valid_logits,max = MAX_EXP) 84 | 85 | # valid_logprobs = F.log_softmax(valid_logits,dim=1) 86 | valid_probs = F.softmax(valid_logits, dim=1) 87 | pool = pool[1].reshape(self.args.batchsize,-1) 88 | assert pool.size()==valid_probs.size() 89 | 90 | m = Categorical(valid_probs) 91 | action_inpool = m.sample() 92 | logprob = m.log_prob(action_inpool) 93 | action = pool[[x for x in range(self.args.batchsize)],action_inpool] 94 | 95 | return 
action,logprob 96 | if __name__=="__main__": 97 | unitTestProbQuery() 98 | -------------------------------------------------------------------------------- /src/utils/rewardshaper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from src.utils.common import * 3 | 4 | 5 | class RewardShaper(object): 6 | def __init__(self,args): 7 | self.metric = args.metric 8 | self.shaping = args.shaping 9 | self.entcoef = args.entcoef 10 | self.gamma = 0.5 11 | self.rate = 0.05 12 | self.alpha = args.frweight # to make tune the ratio of finalreward and midreward 13 | if "0" in self.shaping and "1" in self.shaping: 14 | raise ValueError("arguments invalid") 15 | self.hashistorymean,self.hashistoryvar = False,False 16 | if "3" in self.shaping: 17 | self.historymean = np.zeros((1000,args.batchsize)) 18 | if "4" in self.shaping: 19 | self.histrvar = np.zeros((1000,args.batchsize)) 20 | pass 21 | 22 | def reshape(self,rewards_all,finalrewards_all,logprobs): 23 | rewards_sub, finalrewards = self._roughProcess(rewards_all,finalrewards_all) 24 | self.componentRatio(rewards_sub,finalrewards,logprobs) 25 | rewards = np.zeros_like(rewards_sub) 26 | if "0" in self.shaping: 27 | rewards += rewards_sub 28 | if "1" in self.shaping: 29 | for i in range(rewards_sub.shape[0]-1,0,-1): 30 | rewards_sub[i-1] += self.gamma*rewards_sub[i] 31 | rewards += rewards_sub 32 | if "2" in self.shaping: #2 33 | rewards += finalrewards*self.alpha 34 | 35 | if "3" in self.shaping: 36 | if not self.hashistorymean: 37 | self.historymean[:rewards_sub.shape[0], :] += rewards.mean(1,keepdims=True) 38 | self.hashistorymean = True 39 | else: 40 | self.historymean[:rewards_sub.shape[0], :] = self.historymean[:rewards_sub.shape[0], :] * (1 - self.rate) + self.rate * rewards.mean(1,keepdims=True) 41 | rewards = rewards - self.historymean[:rewards.shape[0], :] 42 | 43 | if "4" in self.shaping: 44 | if not self.hashistoryvar: 45 | self.histrvar[:rewards_sub.shape[0], :] += (rewards**2).mean(1,keepdims=True) 46 | self.hashistoryvar = True 47 | else: 48 | self.histrvar[:rewards_sub.shape[0], :] = self.histrvar[:rewards.shape[0], :] * ( 49 | 1 - self.rate) + self.rate * (rewards**2).mean(1,keepdims=True) 50 | rewards = rewards/np.power(self.histrvar[:rewards.shape[0], :],0.5) 51 | return rewards 52 | 53 | def _roughProcess(self,rewards_all,finalrewards_all): 54 | mic,mac = [np.array(x) for x in list(zip(*rewards_all))] 55 | finalmic,finalmac = [np.array(x) for x in finalrewards_all] 56 | if self.metric == "microf1": 57 | rewards = mic 58 | finalrewards = finalmic 59 | elif self.metric == "macrof1": 60 | rewards = mac 61 | finalrewards = finalmac 62 | elif self.metric == "mix": 63 | rewards = (mic+mac)/2 64 | finalrewards = (finalmic+finalmac)/2 65 | else: 66 | raise NotImplementedError("metric <{}> is not implemented".format(self.metric)) 67 | rewards_sub = rewards[1:,:]-rewards[:-1,:] 68 | return rewards_sub,finalrewards 69 | 70 | def componentRatio(self,rewards_sub,finalrewards_all,logprobs): 71 | r_mean = np.mean(np.abs(rewards_sub)) 72 | f_mean = np.mean(finalrewards_all) 73 | lp_mean = np.mean(np.abs(logprobs)) 74 | f_ratio = f_mean/r_mean*self.alpha 75 | lp_ratio = lp_mean/r_mean*self.entcoef 76 | logger.debug("rmean {:.4f},fratio {:.2f}x{:.4f}={:.3f}, " 77 | "lpratio {:.1f}x{:.5f}={:.3f}".format(r_mean, 78 | f_mean/r_mean,self.alpha,f_ratio, 79 | lp_mean/r_mean,self.entcoef,lp_ratio)) -------------------------------------------------------------------------------- /src/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.sparse as sp 3 | import warnings 4 | from sklearn.metrics import f1_score 5 | import torch 6 | import math 7 | 8 | from src.utils.common import * 9 | 10 | 11 | def preprocess_features(features): 12 | """Row-normalize feature matrix and convert to tuple representation""" 13 | rowsum = np.array(features.sum(1)) 14 | r_inv = np.power(rowsum, -1).flatten() 15 | r_inv[np.isinf(r_inv)] = 0. 16 | r_mat_inv = np.diag(r_inv) 17 | features = r_mat_inv.dot(features) 18 | return features 19 | 20 | def column_normalize(tens): 21 | ret = tens - tens.mean(axis=0) 22 | return ret 23 | 24 | def normalize_adj(adj): 25 | """Symmetrically normalize adjacency matrix.""" 26 | adj = sp.coo_matrix(adj) 27 | rowsum = np.array(adj.sum(1)) 28 | d_inv_sqrt = np.power(rowsum, -0.5).flatten() 29 | d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0. 30 | d_mat_inv_sqrt = sp.diags(d_inv_sqrt) 31 | return adj.dot(d_mat_inv_sqrt).transpose().dot(d_mat_inv_sqrt).tocoo() 32 | 33 | def preprocess_adj(adj): 34 | """Preprocessing of adjacency matrix for simple GCN model and conversion to tuple representation.""" 35 | adj_add_diag=adj + sp.eye(adj.shape[0]) 36 | adj_normalized = normalize_adj(adj_add_diag) 37 | return adj_normalized.astype(np.float32) #sp.coo_matrix(adj_unnorm) 38 | 39 | def sparse_mx_to_torch_sparse_tensor(sparse_mx): 40 | """Convert a scipy sparse matrix to a torch sparse tensor.""" 41 | sparse_mx = sparse_mx.tocoo().astype(np.float32) 42 | indices = torch.from_numpy( 43 | np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) 44 | values = torch.from_numpy(sparse_mx.data) 45 | shape = torch.Size(sparse_mx.shape) 46 | return torch.sparse.FloatTensor(indices, values, shape) 47 | 48 | ##========================================================================= 49 | 50 | def accuracy(y_pred, labels): 51 | if len(labels.size())==1: 52 | y_pred = y_pred.max(1)[1].type_as(labels) 53 | y_pred=y_pred.cpu().detach().numpy() 54 | labels=labels.cpu().numpy() 55 | 56 | 57 | elif len(labels.size())==2: 58 | # print("rawy_pred",y_pred) 59 | y_pred=(y_pred > 0.).cpu().detach().numpy() 60 | labels=labels.cpu().numpy() 61 | 62 | # y_pred = np.zeros_like(y_pred) 63 | 64 | # print("y_pred",y_pred[:10,:]) 65 | # print("labels",labels[:10,:]) 66 | # exit() 67 | 68 | 69 | 70 | with warnings.catch_warnings(): 71 | warnings.simplefilter("ignore") 72 | mic,mac=f1_score(labels, y_pred, average="micro"), f1_score(labels, y_pred, average="macro") 73 | return mic,mac 74 | 75 | def mean_std(L): 76 | if type(L)==np.ndarray: 77 | L=L.tolist() 78 | m=sum(L)/float(len(L)) 79 | bias=[(x-m)**2 for x in L] 80 | std=math.sqrt(sum(bias)/float(len(L)-1)) 81 | return [float(m)*100.,float(std)*100.] 
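# mean_std (above) returns [mean, sample standard deviation] scaled by 100,
# i.e. percentage points, which is convenient when aggregating the micro/macro
# F1 values produced by accuracy() over repeated runs.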
82 | 83 | ##========================================================================== 84 | 85 | def entropy(tens,multilabel=False): 86 | if multilabel:#Todo 87 | reverse=1-tens 88 | ent_1= -torch.log(torch.clamp(tens, min=1e-7)) * tens 89 | ent_2= -torch.log(torch.clamp(reverse,min=1e-7))*reverse 90 | ent=ent_1+ent_2 91 | entropy=torch.mean(ent,dim=1) 92 | else: 93 | assert type(tens)==torch.Tensor and len(tens.size())==3,"calculating entropy of wrong size" 94 | entropy = - torch.log(torch.clamp(tens, min=1e-7)) * tens 95 | entropy = torch.sum(entropy, dim=2) 96 | return entropy 97 | 98 | 99 | ##========================================================================== 100 | 101 | 102 | class AverageMeter(object): 103 | def __init__(self,name='',ave_step=10): 104 | self.name = name 105 | self.ave_step = ave_step 106 | self.history =[] 107 | self.history_extrem = None 108 | self.S=5 109 | 110 | def update(self,data): 111 | if data is not None: 112 | self.history.append(data) 113 | 114 | def __call__(self): 115 | if len(self.history) == 0: 116 | value = None 117 | else: 118 | cal=self.history[-self.ave_step:] 119 | value = sum(cal)/float(len(cal)) 120 | return value 121 | 122 | def should_save(self): 123 | if len(self.history)>self.S*2 and sum(self.history[-self.S:])/float(self.S)> sum(self.history[-self.S*2:])/float(self.S*2): 124 | if self.history_extrem is None : 125 | self.history_extrem =sum(self.history[-self.S:])/float(self.S) 126 | return False 127 | else: 128 | if self.history_extrem < sum(self.history[-self.S:])/float(self.S): 129 | self.history_extrem = sum(self.history[-self.S:])/float(self.S) 130 | return True 131 | else: 132 | return False 133 | else: 134 | return False 135 | 136 | 137 | #=========================================================== 138 | 139 | def inspect_grad(model): 140 | name_grad = [(x[0], x[1].grad) for x in model.named_parameters() if x[1].grad is not None] 141 | name, grad = zip(*name_grad) 142 | assert not len(grad) == 0, "no layer requires grad" 143 | mean_grad = [torch.mean(x) for x in grad] 144 | max_grad = [torch.max(x) for x in grad] 145 | min_grad = [torch.min(x) for x in grad] 146 | logger.info("name {}, mean_max min {}".format(name,list(zip(mean_grad, max_grad, min_grad)))) 147 | 148 | def inspect_weight(model): 149 | name_weight = [x[1] for x in model.named_parameters() if x[1].grad is not None] 150 | print("network_weight:{}".format(name_weight)) 151 | 152 | 153 | #============================================================== 154 | 155 | def common_rate(counts,prediction,seq): 156 | summation = counts.sum(dim=1, keepdim=True) 157 | squaresum = (counts ** 2).sum(dim=1, keepdim=True) 158 | ret = (summation ** 2 - squaresum) / (summation * (summation - 1)+1) 159 | # print("here1") 160 | equal_rate=counts[seq,prediction].reshape(-1,1)/(summation+1) 161 | # print(ret,equal_rate) 162 | return ret,equal_rate 163 | 164 | --------------------------------------------------------------------------------
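As a quick orientation for the utilities above, here is a minimal usage sketch (not part of the repository) showing how GraphLoader, Player, and Env fit together for a single query step. It assumes a CUDA-capable machine with the Cora data available under ./data; the argument namespace below only mirrors the defaults from parse_args() in player.py plus the four state-feature flags read by Env.makeState.
```
from argparse import Namespace

from src.utils.dataloader import GraphLoader
from src.utils.env import Env
from src.utils.player import Player

# Illustrative arguments: defaults copied from player.parse_args(), plus the
# state-feature switches consumed by Env.makeState.
args = Namespace(batchsize=2, ntest=1000, nval=500, nhid=64, lr=3e-2,
                 dropout=0.5, remain_epoch=35,
                 use_entropy=True, use_degree=True,
                 use_local_diversity=True, use_select=True)

G = GraphLoader("cora")       # reads ./data/cora/cora.edgelist, .x.pkl, .y.pkl
G.process()                   # normalize features/adjacency, move tensors to the GPU
player = Player(G, args)      # GCN classifier plus train/val/test masks
env = Env([player], args)     # exposes the state and the query/train/reward step

state = env.getState(0)       # [batchsize, nnode, n_state_features] policy input
pool = player.getPool()       # one array of still-unlabelled node ids per batch element
actions = [int(p[0]) for p in pool]   # placeholder: take the first candidate of each pool
micro_macro = env.step(actions)       # label the nodes, train once, return val (micro, macro) F1
```
In the full system the placeholder selection above is replaced by ProbQuery from src/utils/query.py, which samples (or argmaxes) an action for each batch element from the policy's logits restricted to the pool.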