├── submission ├── example_submission │ ├── 0.csv │ ├── 1.csv │ ├── 2.csv │ ├── 3.csv │ └── 4.csv └── bilge │ ├── 0.csv │ ├── 1.csv │ ├── 2.csv │ ├── 3.csv │ └── 4.csv ├── requirements.txt ├── .DS_Store ├── overview.png ├── result └── .DS_Store ├── config └── bilge20230301.json ├── src ├── price.txt ├── utils.py ├── example.py ├── pricefunction.py ├── buyer.py ├── seller.py ├── dam.py ├── evaluator.py ├── marketengine.py ├── marketengine_demo.ipynb ├── evaluator_submission.py ├── evaluator_acc_cost.py ├── visualize_acc_cost.py ├── helper.py └── visualizetools.py ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── README.md └── LICENSE /submission/example_submission/0.csv: -------------------------------------------------------------------------------- 1 | 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,10,10,10,10,15 -------------------------------------------------------------------------------- /submission/example_submission/1.csv: -------------------------------------------------------------------------------- 1 | 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,10,10,10,10,15 -------------------------------------------------------------------------------- /submission/example_submission/2.csv: -------------------------------------------------------------------------------- 1 | 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,10,10,10,10,15 -------------------------------------------------------------------------------- /submission/example_submission/3.csv: -------------------------------------------------------------------------------- 1 | 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,10,10,10,10,15 -------------------------------------------------------------------------------- /submission/example_submission/4.csv: -------------------------------------------------------------------------------- 1 | 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,10,10,10,10,15 -------------------------------------------------------------------------------- /submission/bilge/0.csv: 
-------------------------------------------------------------------------------- 1 | 32, 26, 4, 23, 2, 30, 24, 7, 12, 8, 19, 8, 17, 6, 12, 6, 14, 20, 10, 20 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.0 2 | pandas==1.5.3 3 | scikit_learn==1.2.0 4 | pyarrow==10.0.1 5 | 6 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/Data_Acquisition_for_ML_Benchmark/HEAD/.DS_Store -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/Data_Acquisition_for_ML_Benchmark/HEAD/overview.png -------------------------------------------------------------------------------- /result/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/Data_Acquisition_for_ML_Benchmark/HEAD/result/.DS_Store -------------------------------------------------------------------------------- /submission/bilge/1.csv: -------------------------------------------------------------------------------- 1 | 228, 96, 216, 348, 48, 60, 108, 0, 108, 0, 180, 204, 252, 288, 108, 204, 204, 168, 120, 60 2 | -------------------------------------------------------------------------------- /submission/bilge/2.csv: -------------------------------------------------------------------------------- 1 | 408, 108, 324, 348, 24, 48, 120, 12, 0, 36, 156, 192, 300, 240, 72, 276, 0, 156, 72, 108 2 | -------------------------------------------------------------------------------- /submission/bilge/3.csv: 
-------------------------------------------------------------------------------- 1 | 408, 108, 324, 348, 24, 48, 120, 12, 0, 36, 156, 192, 300, 240, 72, 276, 0, 156, 72, 108 2 | -------------------------------------------------------------------------------- /submission/bilge/4.csv: -------------------------------------------------------------------------------- 1 | 420, 48, 204, 348, 36, 144, 192, 288, 24, 48, 240, 132, 132, 48, 168, 168, 0, 300, 12, 48 2 | -------------------------------------------------------------------------------- /config/bilge20230301.json: -------------------------------------------------------------------------------- 1 | { 2 | "instance_ids":["0","1","2","3","4"], 3 | "submission_path":"../submission/bilge/", 4 | "model_name":"knn", 5 | "save_path":"../result/bilge20230301_knn.json" 6 | } 7 | -------------------------------------------------------------------------------- /src/price.txt: -------------------------------------------------------------------------------- 1 | Lin,100 2 | Lin,100 3 | Lin,100 4 | Lin,100 5 | Lin,100 6 | Lin,100 7 | Lin,100 8 | Lin,100 9 | Lin,100 10 | Lin,100 11 | Lin,100 12 | Lin,100 13 | Lin,100 14 | Lin,100 15 | Lin,100 16 | Lin,100 17 | Lin,100 18 | Lin,100 19 | Lin,100 20 | Lin,100 21 | Lin,100 -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:39:21 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | def pricefunc_lin(frac = 1, 26 | max_p = 100): 27 | p1 = max_p * frac 28 | return p1 29 | 30 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to this repository 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to this repository, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /src/example.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | from dam import Dam 18 | print("Loading Dataset...") 19 | instance=2 # instance id, can be 0,1,2,3,4 20 | MyDam = Dam(instance=instance) 21 | print("Dataset loaded!") 22 | budget = MyDam.getbudget() # get budget 23 | print("budget is:",budget) 24 | # 3. Display seller_data 25 | buyer_data = MyDam.getbuyerdata() # get buyer data 26 | print("buyer data is:",buyer_data) 27 | 28 | 29 | mlmodel = MyDam.getmlmodel() # get ml model 30 | print("mlmodel is",mlmodel) 31 | 32 | sellers_id = MyDam.getsellerid() # seller ids 33 | print("seller ids are", sellers_id) 34 | for i in sellers_id: 35 | seller_i_price, seller_i_summary, seller_i_samples = MyDam.getsellerinfo(seller_id=int(i)) 36 | print("seller ", i, " price: ", seller_i_price.get_price_samplesize(100)) 37 | print("seller ", i, " summary: ", seller_i_summary) 38 | print("seller ", i, " samples: ", seller_i_samples) 39 | 40 | -------------------------------------------------------------------------------- /src/pricefunction.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 
3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:35:29 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | 26 | 27 | class PriceFunction(object): 28 | 29 | def __init__(self): 30 | return 31 | 32 | def setup(self, max_p = 100, method="lin", 33 | data_size=1): 34 | self.max_p = max_p 35 | self.method = "lin" 36 | self.data_size = data_size 37 | 38 | def get_price(self, 39 | frac=1, 40 | ): 41 | if(frac<0 or frac>1): 42 | raise ValueError("The fraction of samples must be within [0,1]!") 43 | max_p = self.max_p 44 | if(self.method=="lin"): 45 | p1 = max_p * frac 46 | return p1 47 | 48 | return 49 | 50 | def get_price_samplesize(self, 51 | samplesize=10, 52 | ): 53 | frac = samplesize/self.data_size 54 | #print("frac is",frac) 55 | return self.get_price(frac) 56 | 57 | 58 | 59 | 60 | def main(): 61 | print("test of the price func") 62 | 63 | 64 | if __name__ == '__main__': 65 | main() 66 | 67 | -------------------------------------------------------------------------------- /src/buyer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:35:29 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | from sklearn.linear_model import LogisticRegression 26 | 27 | import numpy 28 | 29 | class Buyer(object): 30 | 31 | def __init__(self): 32 | return 33 | 34 | def loaddata(self, 35 | data=None, 36 | datapath=None,): 37 | if(not (data is None)): 38 | self.data = data 39 | return 40 | if(datapath != None): 41 | self.data = numpy.loadtxt(open(datapath, "rb"), 42 | delimiter=",", 43 | skiprows=1) 44 | return 45 | raise ValueError("Not implemented load data of buyer") 46 | return 47 | 48 | def load_stretagy(self, 49 | stretagy=None): 50 | return 51 | 52 | def get_stretagy(self): 53 | return self.stretagy 54 | 55 | def load_mlmodel(self, 56 | mlmodel): 57 | self.mlmodel = mlmodel 58 | return 0 59 | 60 | def train_mlmodel(self, 61 | train_data): 62 | 63 | X = train_data[:,0:-1] 64 | y = numpy.ravel(train_data[:,-1]) 65 | self.mlmodel.fit(X,y) 66 | X_1 = self.data[:,0:-1] 67 | y_1 = numpy.ravel(self.data[:,-1]) 68 | eval_acc = self.mlmodel.score(X_1, y_1) 69 | return eval_acc 70 | 71 | 72 | def main(): 73 | print("test of the buyer") 74 | MyBuyer = Buyer() 75 | 76 | 77 | 78 | MyBuyer.loaddata(data=numpy.asmatrix([[0,1,1,1],[1,0,1,0]])) 79 | 80 | mlmodel1 = LogisticRegression(random_state=0) 81 | 82 | MyBuyer.load_mlmodel(mlmodel1) 83 | 84 | train_data = numpy.asmatrix([[0,1,1,1],[1,0,1,0],[1,1,1,1]]) 85 | 86 | eval1 = MyBuyer.train_mlmodel(train_data) 87 | 88 | print("eval acc",eval1) 89 
| 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /src/seller.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:35:29 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | from pricefunction import PriceFunction 26 | import numpy 27 | numpy.random.seed(1111) 28 | 29 | class Seller(object): 30 | 31 | def __init__(self): 32 | return 33 | 34 | def loaddata(self, 35 | data=None, 36 | datapath=None,): 37 | # data: a m x n matrix 38 | # datapath: a path to a csv file. 39 | # the file should be a matrix with column names. 
40 | if(not (data is None)): 41 | self.data = data 42 | return 43 | if(datapath != None): 44 | self.data = numpy.loadtxt(open(datapath, "rb"), 45 | delimiter=",", 46 | skiprows=1) 47 | return 48 | print("Not implemented load data of seller") 49 | return 50 | 51 | def setprice(self, pricefunc): 52 | self.pricefunc = pricefunc 53 | 54 | def getprice(self,data_size): 55 | q1 = data_size/(len(self.data)) 56 | return self.pricefunc.get_price(q1) 57 | 58 | def getdata(self, data_size, price): 59 | data = self.data 60 | q1 = data_size/(len(self.data)) 61 | if(q1>1): 62 | raise ValueError("The required number of samples is too large!") 63 | 64 | if(self.pricefunc.get_price(q1) <= price): 65 | number_of_rows = self.data.shape[0] 66 | random_indices = numpy.random.choice(number_of_rows, 67 | size=data_size, 68 | replace=True) 69 | rows = data[random_indices, :] 70 | return rows 71 | else: 72 | raise ValueError("The buyer's offer is too small!") 73 | return 74 | 75 | 76 | 77 | def main(): 78 | print("test of the seller") 79 | MySeller = Seller() 80 | 81 | MySeller.loaddata(data=numpy.asmatrix([[0,1,1],[1,0,1]])) 82 | 83 | MyPricing = PriceFunction() 84 | MyPricing.setup(max_p = 100, method="lin") 85 | 86 | MySeller.setprice(MyPricing) 87 | 88 | data = MySeller.getdata(1,60) 89 | 90 | print("get data is ",data) 91 | 92 | if __name__ == '__main__': 93 | main() 94 | 95 | -------------------------------------------------------------------------------- /src/dam.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | import numpy 18 | import pickle 19 | import json 20 | from pricefunction import PriceFunction 21 | import pandas 22 | class Dam(object): 23 | def __init__(self, 24 | instance=0, 25 | ): 26 | self._instance = instance 27 | self._marketpath="../marketinfo/" 28 | if(instance not in [0,1,2,3,4]): 29 | raise ValueError("the instance id is incorrect. it must be 0, 1, 2, 3, or 4.") 30 | return 31 | 32 | def getbudget(self,): 33 | budget = numpy.loadtxt(self._marketpath+str(self._instance)+"/price/"+"/budget.txt") 34 | return float(budget) 35 | 36 | def getbuyerdata(self,): 37 | path = self._marketpath+str(self._instance)+"/data_buyer/"+"/20.csv" 38 | buydata = pandas.read_csv(path,header=None,engine="pyarrow").to_numpy() 39 | return buydata 40 | 41 | def getmlmodel(self,): 42 | path = self._marketpath+str(self._instance)+"/data_buyer/"+"/mlmodel.pickle" 43 | with open(path, 'rb') as handle: 44 | model = pickle.load(handle) 45 | return model 46 | 47 | def getsellerid(self,): 48 | path = self._marketpath+str(self._instance)+"/sellerid.txt" 49 | ids = numpy.loadtxt(path) 50 | return ids 51 | 52 | def getsellerinfo(self,seller_id): 53 | path = self._marketpath+str(self._instance)+"/summary/"+str(seller_id)+".csv.json" 54 | f = open(path) 55 | ids = json.load(f) 56 | 57 | price = numpy.loadtxt(self._marketpath+str(self._instance)+"/price/"+"/price.txt", 58 | delimiter=',',dtype=str) 59 | price_i = price[seller_id] 60 | MyPricing1 = PriceFunction() 61 | #print("row number",ids['row_number']) 62 | MyPricing1.setup(max_p = 
float(price_i[1]), method=price_i[0], data_size=ids['row_number']) 63 | 64 | 65 | samples = numpy.loadtxt(self._marketpath+str(self._instance)+"/summary/"+str(seller_id)+".csvsamples.csv", 66 | delimiter=' ',dtype=float) 67 | 68 | 69 | return MyPricing1, ids, samples 70 | 71 | 72 | def main(): 73 | MyDam = Dam() 74 | budget = MyDam.getbudget() # get budget 75 | buyer_data = MyDam.getbuyerdata() # get buyer data 76 | mlmodel = MyDam.getmlmodel() # get ml model 77 | sellers_id = MyDam.getsellerid() 78 | i=0 79 | seller_i_price, seller_i_summary, seller_i_samples = MyDam.getsellerinfo(seller_id=i) 80 | 81 | return 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /src/evaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:36:30 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | 26 | from sklearn.linear_model import LogisticRegression 27 | 28 | 29 | 30 | import numpy 31 | from seller import Seller 32 | from buyer import Buyer 33 | from pricefunction import PriceFunction 34 | from marketengine import MarketEngine 35 | from helper import Helper 36 | 37 | class Evaluator(object): 38 | def __init__(self): 39 | self.Helper = Helper() 40 | return 41 | def eval_submission(self, 42 | submission, 43 | seller_data, 44 | buyer_data, 45 | seller_price, 46 | buyer_budget=100, 47 | mlmodel=LogisticRegression(random_state=0), 48 | ): 49 | ''' 50 | 51 | 52 | Parameters 53 | ---------- 54 | submission : TYPE 55 | DESCRIPTION. 56 | seller_data_path : TYPE 57 | DESCRIPTION. 58 | buyer_data_path : TYPE 59 | DESCRIPTION. 60 | price_data_path : TYPE 61 | mlmodel: TYPE 62 | DESCRIPTION. 63 | : TYPE 64 | DESCRIPTION. 65 | 66 | Returns 67 | ------- 68 | None. 
69 | 70 | ''' 71 | 72 | MyMarketEngine = MarketEngine() 73 | MyHelper = self.Helper 74 | 75 | 76 | # set up the market 77 | MyMarketEngine.setup_market(seller_data=seller_data, 78 | seller_prices = seller_price, 79 | buyer_data=buyer_data, 80 | buyer_budget=buyer_budget, 81 | mlmodel=mlmodel, 82 | ) 83 | 84 | # get train data 85 | traindata = MyHelper.load_data(submission, MyMarketEngine) 86 | # train the model 87 | model = MyHelper.train_model(mlmodel, traindata[:,0:-1], 88 | numpy.ravel(traindata[:,-1])) 89 | # eval the model 90 | acc1 = MyHelper.eval_model(model,test_X=buyer_data[:,0:-1],test_Y=buyer_data[:,-1]) 91 | return acc1 92 | 93 | def main(): 94 | print("test of the evaluator") 95 | submission = [[1,2],[50,50]] 96 | data_1 = numpy.asmatrix([[0,1,0],[1,0,0]]) 97 | data_2 = numpy.asmatrix([[0,1,1],[1,0,1],[1,1,1],[0,0,1]]) 98 | seller_data = [data_1, data_2] 99 | buyer_data = numpy.asmatrix([[0,1,0],[1,0,1],[0,1,1]]) 100 | MyPricing1 = PriceFunction() 101 | MyPricing1.setup(max_p = 100, method="lin") 102 | MyPricing2 = PriceFunction() 103 | MyPricing2.setup(max_p = 100, method="lin") 104 | seller_price = [MyPricing1, MyPricing2] 105 | 106 | MyEval = Evaluator() 107 | acc1 = MyEval.eval_submission( 108 | submission, 109 | seller_data, 110 | buyer_data, 111 | seller_price, 112 | buyer_budget=100, 113 | mlmodel=LogisticRegression(random_state=0), 114 | ) 115 | print("acc is:", acc1) 116 | if __name__ == '__main__': 117 | main() 118 | 119 | 120 | 121 | 122 | -------------------------------------------------------------------------------- /src/marketengine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:36:30 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | 26 | from sklearn.linear_model import LogisticRegression 27 | 28 | 29 | 30 | import numpy 31 | from seller import Seller 32 | from buyer import Buyer 33 | from pricefunction import PriceFunction 34 | 35 | 36 | class MarketEngine(object): 37 | def __init__(self): 38 | return 39 | 40 | def setup_market(self, 41 | seller_data=None, 42 | seller_prices=None, 43 | buyer_data=None, 44 | buyer_budget=None, 45 | mlmodel=None): 46 | sellers = list() 47 | for i in range(len(seller_data)): 48 | MySeller = Seller() 49 | MySeller.loaddata(data=seller_data[i]) 50 | MySeller.setprice(seller_prices[i]) 51 | sellers.append(MySeller) 52 | self.sellers = sellers 53 | 54 | MyBuyer = Buyer() 55 | MyBuyer.loaddata(data=buyer_data) 56 | mlmodel1 = mlmodel 57 | MyBuyer.load_mlmodel(mlmodel1) 58 | self.buyer = MyBuyer 59 | self.buyer_budget = buyer_budget 60 | #print("set up the market") 61 | return 62 | 63 | def load_stretagy(self, 64 | stretagy=None,): 65 | self.stretagy = stretagy 66 | 67 | return 68 | 69 | def train_buyer_model(self): 70 | print(" train buyer model ") 71 | 72 | 73 | # check if the budget constraint is satisified. 
74 | cost = sum(self.stretagy[1]) 75 | if(cost>self.buyer_budget): 76 | raise ValueError("The budget constraint is not satisifed!") 77 | return 78 | 79 | traindata = None 80 | for i in range(len(self.sellers)): 81 | d1 = self.sellers[i].getdata(self.stretagy[0][i],self.stretagy[1][i]) 82 | if(i==0): 83 | traindata = d1 84 | else: 85 | traindata = numpy.concatenate((traindata,d1)) 86 | print(i,d1) 87 | 88 | print("budget checked! data loaded!") 89 | #print("train data", traindata) 90 | acc = self.buyer.train_mlmodel(traindata) 91 | return acc 92 | 93 | 94 | def main(): 95 | print("test of the market engine") 96 | MyMarketEngine = MarketEngine() 97 | 98 | data_1 = numpy.asmatrix([[0,1,0],[1,0,0]]) 99 | data_2 = numpy.asmatrix([[0,1,1],[1,0,1],[1,1,1],[0,0,1]]) 100 | data_b = numpy.asmatrix([[0,1,0],[1,0,1],[0,1,1]]) 101 | 102 | buyer_budget = 100 103 | 104 | MyPricing1 = PriceFunction() 105 | MyPricing1.setup(max_p = 100, method="lin") 106 | MyPricing2 = PriceFunction() 107 | MyPricing2.setup(max_p = 100, method="lin") 108 | 109 | 110 | mlmodel1 = LogisticRegression(random_state=0) 111 | 112 | 113 | MyMarketEngine.setup_market(seller_data=[data_1,data_2], 114 | seller_prices = [MyPricing1,MyPricing2], 115 | buyer_data=data_b, 116 | buyer_budget=buyer_budget, 117 | mlmodel=mlmodel1, 118 | ) 119 | 120 | stretagy = [[1,2],[50,50]] 121 | MyMarketEngine.load_stretagy(stretagy) 122 | 123 | acc1 = MyMarketEngine.train_buyer_model() 124 | print("acc is ",acc1) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /src/marketengine_demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Copyright (c) Meta Platforms, Inc. 
and affiliates.\n", 8 | "\n", 9 | "Licensed under the Apache License, Version 2.0 (the \"License\");\n", 10 | "you may not use this file except in compliance with the License.\n", 11 | "You may obtain a copy of the License at\n", 12 | "\n", 13 | " http://www.apache.org/licenses/LICENSE-2.0\n", 14 | "\n", 15 | "Unless required by applicable law or agreed to in writing, software\n", 16 | "distributed under the License is distributed on an \"AS IS\" BASIS,\n", 17 | "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", 18 | "See the License for the specific language governing permissions and\n", 19 | "limitations under the License." 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 7, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# The market demo\n", 29 | "from marketengine import MarketEngine\n", 30 | "from pricefunction import PriceFunction\n", 31 | "from sklearn.linear_model import LogisticRegression\n", 32 | "import numpy\n", 33 | "MyMarketEngine = MarketEngine()" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 8, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "set up the market\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "# Set up the market\n", 51 | "MyMarketEngine = MarketEngine()\n", 52 | " \n", 53 | "# load the dataset \n", 54 | "data_1 = numpy.asmatrix([[0,1,0],[1,0,0]]) \n", 55 | "data_2 = numpy.asmatrix([[0,1,1],[1,0,1],[1,1,1],[0,0,1]])\n", 56 | "data_b = numpy.asmatrix([[0,1,0],[1,0,1],[0,1,1]])\n", 57 | "\n", 58 | "# buyer budget\n", 59 | "buyer_budget = 100\n", 60 | " \n", 61 | "# seller price \n", 62 | "MyPricing1 = PriceFunction()\n", 63 | "MyPricing1.setup(max_p = 100, method=\"lin\")\n", 64 | "MyPricing2 = PriceFunction()\n", 65 | "MyPricing2.setup(max_p = 100, method=\"lin\")\n", 66 | "\n", 67 | "\n", 68 | "mlmodel1 = LogisticRegression(random_state=0)\n", 69 | "\n", 70 | " \n", 71 | 
"MyMarketEngine.setup_market(seller_data=[data_1,data_2],\n", 72 | " seller_prices = [MyPricing1,MyPricing2],\n", 73 | " buyer_data=data_b,\n", 74 | " buyer_budget=buyer_budget,\n", 75 | " mlmodel=mlmodel1,\n", 76 | " )" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 11, 82 | "metadata": {}, 83 | "outputs": [ 84 | { 85 | "name": "stdout", 86 | "output_type": "stream", 87 | "text": [ 88 | " train buyer model \n", 89 | "0 [[0 1 0]]\n", 90 | "1 [[1 0 1]\n", 91 | " [0 1 1]]\n", 92 | "budget checked! data loaded!\n", 93 | "acc for the strategy is 0.6666666666666666\n" 94 | ] 95 | }, 96 | { 97 | "name": "stderr", 98 | "output_type": "stream", 99 | "text": [ 100 | "/Users/lingjiao/opt/anaconda3/lib/python3.9/site-packages/sklearn/utils/validation.py:593: FutureWarning: np.matrix usage is deprecated in 1.0 and will raise a TypeError in 1.2. Please convert to a numpy array with np.asarray. For more information see: https://numpy.org/doc/stable/reference/generated/numpy.matrix.html\n", 101 | " warnings.warn(\n", 102 | "/Users/lingjiao/opt/anaconda3/lib/python3.9/site-packages/sklearn/utils/validation.py:593: FutureWarning: np.matrix usage is deprecated in 1.0 and will raise a TypeError in 1.2. Please convert to a numpy array with np.asarray. 
For more information see: https://numpy.org/doc/stable/reference/generated/numpy.matrix.html\n", 103 | " warnings.warn(\n" 104 | ] 105 | } 106 | ], 107 | "source": [ 108 | "# Eval a stretagy\n", 109 | "stretagy = [[1,2],[50,50]]\n", 110 | "MyMarketEngine.load_stretagy(stretagy)\n", 111 | "acc1 = MyMarketEngine.train_buyer_model()\n", 112 | "print(\"acc for the strategy is\", acc1)" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [] 121 | } 122 | ], 123 | "metadata": { 124 | "kernelspec": { 125 | "display_name": "Python 3 (ipykernel)", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.9.12" 140 | } 141 | }, 142 | "nbformat": 4, 143 | "nbformat_minor": 4 144 | } 145 | -------------------------------------------------------------------------------- /src/evaluator_submission.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:36:30 2022 21 | 22 | @author: lingjiao 23 | """ 24 | from sklearn.linear_model import LogisticRegression 25 | from sklearn.ensemble import GradientBoostingClassifier 26 | import numpy 27 | from marketengine import MarketEngine 28 | from helper import Helper 29 | from sklearn.neighbors import KNeighborsClassifier 30 | import json 31 | 32 | def evaluate_batch(data_config, # Evaluate every id in data_config['instance_ids']; returns {instance_id: averaged {'cost', 'acc'}}. 33 | ): 34 | instance_ids = data_config['instance_ids'] 35 | result = dict() 36 | for id1 in instance_ids: 37 | result[id1] = evaluate_multiple_trial(data_config,instance_id=id1) 38 | return result 39 | 40 | def evaluate_multiple_trial(data_config, # Run evaluate_once num_trial times and return the mean 'cost' and 'acc'. 41 | instance_id, 42 | num_trial=10, 43 | ): 44 | 45 | results = [evaluate_once(data_config=data_config, 46 | instance_id=instance_id) for i in range(num_trial)] 47 | #print("results are:",results) 48 | results_avg = dict() 49 | results_avg['cost'] = 0 50 | results_avg['acc'] = 0 51 | for item in results: 52 | #print("item is:",item) 53 | results_avg['cost'] += item['cost']/len(results) 54 | results_avg['acc'] += item['acc']/len(results) 55 | return results_avg 56 | 57 | def evaluate_once(data_config, # Score one submission CSV: its total purchase price and the model's accuracy on the buyer data. 58 | instance_id): 59 | # load submission 60 | submission = load_submission(path = data_config['submission_path']+str(instance_id)+".csv") 61 | 62 | # get the helper 63 | model_name = data_config['model_name'] 64 | MarketHelper, MarketEngineObj, model, traindata, buyer_data = get_market_info(instance_id=instance_id, # NOTE(review): the 4th value is seller_data; traindata is reassigned below 65 | model_name=model_name) 66 | 67 | # calculate the cost of the submission 68 | cost = MarketHelper.get_cost(submission,MarketEngineObj) # NOTE(review): cost is only reported; it is never compared to the instance budget (get_market_info passes buyer_budget=1e10) 69 | 70 | # generate the accuracy of the submission 71 | traindata = MarketHelper.load_data(submission, MarketEngineObj) 72 | model = MarketHelper.train_model(model, traindata[:,0:-1], 73 | numpy.ravel(traindata[:,-1])) 74 | acc1 =
MarketHelper.eval_model(model,test_X=buyer_data[:,0:-1],test_Y=buyer_data[:,-1]) 75 | 76 | result = dict() 77 | result['cost'] = cost 78 | result['acc'] = acc1 79 | return result 80 | 81 | def load_submission(path): # Read the comma-separated integer purchase counts from a submission file. 82 | data = numpy.loadtxt(path,delimiter=",",dtype=int) 83 | return data 84 | 85 | def get_market_info(instance_id, # Build (Helper, MarketEngine, model, seller_data, buyer_data) for one instance from ../features and ../marketinfo. 86 | model_name="lr"): 87 | MyHelper = Helper() 88 | seller_data, seller_prices, buyer_data, buyer_budget, data_size = MyHelper.load_market_instance( 89 | feature_path="../features/"+str(instance_id)+"/", 90 | buyer_data_path="../marketinfo/"+str(instance_id)+"/data_buyer/20.csv", 91 | price_path="../marketinfo/"+str(instance_id)+"/price/price.txt", 92 | budget_path="../marketinfo/"+str(instance_id)+"/price/budget.txt", 93 | ) 94 | MyMarketEngine = MarketEngine() 95 | mlmodel1 = LogisticRegression(random_state=0) # used for any model_name other than 'knn' or 'rf' 96 | if(model_name=="knn"): 97 | mlmodel1 = KNeighborsClassifier(n_neighbors=9) 98 | if(model_name=='rf'): # NOTE(review): 'rf' selects gradient boosting, not a random forest 99 | mlmodel1 = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, 100 | max_depth=1, random_state=0) 101 | MyMarketEngine.setup_market(seller_data=seller_data, 102 | seller_prices = seller_prices, 103 | buyer_data=buyer_data, 104 | buyer_budget=1e10, # NOTE(review): the instance's buyer_budget loaded above is ignored, so Helper.load_data's budget check presumably never fires 105 | mlmodel=mlmodel1, 106 | ) 107 | 108 | return MyHelper, MyMarketEngine, mlmodel1,seller_data, buyer_data 109 | 110 | def main(): # Evaluate all configured instances and write the averaged results as JSON to save_path. 111 | data_config = json.load(open("../config/bilge20230301_rf.json")) # load the evaluation config (instance_ids, submission_path, model_name, save_path); NOTE(review): the file handle is never closed, and the repo tree only lists config/bilge20230301.json 112 | result = evaluate_batch(data_config) 113 | json_object = json.dumps(result, indent=4) 114 | save_path = data_config['save_path'] 115 | with open(save_path, "w") as outfile: 116 | outfile.write(json_object) 117 | print("The result is:",result) 118 | 119 | return 120 | 121 | if __name__ == '__main__': 122 | main() 123 | 124 | 125 | 126 | 127 | -------------------------------------------------------------------------------- /src/evaluator_acc_cost.py: -------------------------------------------------------------------------------- 1 | """
2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:36:30 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | 26 | from sklearn.linear_model import LogisticRegression 27 | 28 | from sklearn.ensemble import GradientBoostingClassifier 29 | 30 | import numpy 31 | from seller import Seller 32 | from buyer import Buyer 33 | from pricefunction import PriceFunction 34 | from marketengine import MarketEngine 35 | from helper import Helper 36 | import pandas 37 | from sklearn.neighbors import KNeighborsClassifier 38 | 39 | 40 | def evaluate( # Run `trial` random submissions at one cost_scale; return a DataFrame with columns trial, acc, cost, budget. 41 | MarketHelper, 42 | MarketEngineObj, 43 | model, 44 | buyer_data, 45 | trial=100, # number of trials per budget 46 | seller_data_size_list = [100,200,300], # NOTE(review): mutable default; only read here, never mutated 47 | cost_scale=0.1, 48 | method="single", 49 | full_price=100, 50 | ): 51 | trial_list = list(range(trial)) 52 | acc_list = list() 53 | cost_list = list() 54 | budget_list = list() 55 | for i in range(trial): 56 | print("trial:",i) 57 | # generate a submission 58 | submission = gen_submission(seller_data_size_list,cost_scale=cost_scale, 59 | method=method) 60 | # calculate the cost of the submission 61 | cost = MarketHelper.get_cost(submission,MarketEngineObj) 62 | # generate the accuracy of the submission 63 | traindata = MarketHelper.load_data(submission, MarketEngineObj) 64 | model =
MarketHelper.train_model(model, traindata[:,0:-1], 65 | numpy.ravel(traindata[:,-1])) 66 | acc1 = MarketHelper.eval_model(model,test_X=buyer_data[:,0:-1],test_Y=buyer_data[:,-1]) 67 | 68 | cost_list.append(cost) 69 | acc_list.append(acc1) 70 | budget_list.append(cost_scale*full_price) # nominal budget label (cost_scale*full_price), not the realized cost 71 | result = pandas.DataFrame() 72 | result['trial'] = trial_list 73 | result['acc'] = acc_list 74 | result['cost'] = cost_list 75 | result['budget'] = budget_list 76 | return result 77 | 78 | ''' generate a pandas dataframe 79 | 80 | trial,accuracy, cost 81 | ''' 82 | 83 | def gen_submission(seller_data_size_list=[100,200,300], # Random purchase counts: 'uniform' samples every seller, 'single' buys int(size*cost_scale) from one random seller. 84 | cost_scale=1, 85 | method="uniform"): 86 | if(method=="uniform"): 87 | d = len(seller_data_size_list) 88 | submission = [numpy.random.randint(0,int(a*cost_scale/d*2)) for a in seller_data_size_list] # NOTE(review): numpy randint raises ValueError when int(a*cost_scale/d*2) is 0 89 | if(method=="single"): 90 | submission = [0]*len(seller_data_size_list) 91 | index = numpy.random.randint(0,len(submission)) 92 | submission[index] = int(seller_data_size_list[index]*cost_scale) 93 | return submission # NOTE(review): UnboundLocalError if method is neither 'uniform' nor 'single' 94 | 95 | def evaluate_budget(MarketHelper, # Run evaluate() once per cost scale and stack the resulting DataFrames. 96 | MarketEngineObj, 97 | model, 98 | buyer_data, 99 | trial=100, # number of trials per budget 100 | seller_data_size_list = [100,200,300], 101 | cost_scale_list=[0.1], 102 | method="single", 103 | ): 104 | results = [evaluate( 105 | MarketHelper=MarketHelper, 106 | MarketEngineObj=MarketEngineObj, 107 | model=model, 108 | buyer_data=buyer_data, 109 | trial=trial, # number of trials per budget 110 | seller_data_size_list = seller_data_size_list, 111 | cost_scale=c1, 112 | method=method, 113 | ) for c1 in cost_scale_list] 114 | full_result = pandas.concat(results, ignore_index=True,axis=0) 115 | return full_result 116 | 117 | def evaluate_full(instance_id=0, # Load one instance, pick the model by name, and write the acc/cost tradeoff CSV under ../logs/instance_id/. 118 | method="single", 119 | model_name="knn",): 120 | print("evaluate acc and cost tradeoffs") 121 | # instance_id=0 122 | # method="single" 123 | # model_name="knn" 124 | MyHelper = Helper() 125 | seller_data, seller_prices, buyer_data,
buyer_budget, data_size = MyHelper.load_market_instance( 126 | feature_path="../features/"+str(instance_id)+"/", 127 | buyer_data_path="../marketinfo/"+str(instance_id)+"/data_buyer/20.csv", 128 | price_path="../marketinfo/"+str(instance_id)+"/price/price.txt", 129 | budget_path="../marketinfo/"+str(instance_id)+"/price/budget.txt", 130 | ) 131 | numpy.savetxt("../marketinfo/"+str(instance_id)+"/seller_datasize.csv",data_size,fmt="%d") 132 | 133 | MyMarketEngine = MarketEngine() 134 | mlmodel1 = LogisticRegression(random_state=0) # used for any model_name other than 'knn'/'rf' (e.g. 'logreg') 135 | if(model_name=="knn"): 136 | mlmodel1 = KNeighborsClassifier(n_neighbors=9) 137 | if(model_name=='rf'): # NOTE(review): 'rf' selects gradient boosting, not a random forest 138 | mlmodel1 = GradientBoostingClassifier(n_estimators=100, learning_rate=1.0, 139 | max_depth=1, random_state=0) 140 | MyMarketEngine.setup_market(seller_data=seller_data, 141 | seller_prices = seller_prices, 142 | buyer_data=buyer_data, 143 | buyer_budget=1e10, 144 | mlmodel=mlmodel1, 145 | ) 146 | 147 | result = evaluate( # NOTE(review): 10-trial sanity run; uses evaluate()'s default method="single" and is only printed, never saved 148 | MarketHelper=MyHelper, 149 | MarketEngineObj=MyMarketEngine, 150 | model=mlmodel1, 151 | buyer_data=buyer_data, 152 | trial=10, # number of trials per budget 153 | seller_data_size_list = numpy.loadtxt("../marketinfo/"+str(instance_id)+"/seller_datasize.csv"), 154 | cost_scale=0.1, 155 | ) 156 | result2 = evaluate_budget( 157 | MarketHelper=MyHelper, 158 | MarketEngineObj=MyMarketEngine, 159 | model=mlmodel1, 160 | buyer_data=buyer_data, 161 | trial=100, # number of trials per budget 162 | seller_data_size_list = numpy.loadtxt("../marketinfo/" + str(instance_id) +"/seller_datasize.csv"), 163 | # cost_scale_list=[0.005,0.0075,0.01,0.025,0.05,0.075,0.1], 164 | cost_scale_list=[0.01,0.025,0.05,0.1,0.2], 165 | method=method, 166 | # cost_scale_list=[0.05,0.1,0.5,1], 167 | # method="single", 168 | ) 169 | folder1 = "../logs/"+str(instance_id)+"/" 170 | 171 | result2.to_csv(folder1+"acc_cost_tradeoffs_"+method+"_"+model_name+".csv") # NOTE(review): to_csv does not create ../logs/instance_id/ if it is missing 172 | print("result is:",result) 173 | return 174 | 175 | def main(): # Sweep instances 3 and 4 over both sampling methods and the knn/logreg/rf model names. 176 |
instance_ids = [3,4] 177 | methods = ['single','uniform'] 178 | for instance_id in instance_ids: 179 | for method in methods: 180 | evaluate_full(instance_id=instance_id,method=method,model_name="knn") 181 | evaluate_full(instance_id=instance_id,method=method,model_name="logreg") 182 | evaluate_full(instance_id=instance_id,method=method,model_name="rf") 183 | 184 | 185 | return 186 | 187 | if __name__ == '__main__': 188 | main() 189 | 190 | 191 | 192 | 193 | -------------------------------------------------------------------------------- /src/visualize_acc_cost.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:36:30 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | 26 | from sklearn.linear_model import LogisticRegression 27 | 28 | import matplotlib.pyplot as plt 29 | import matplotlib 30 | 31 | import numpy 32 | from seller import Seller 33 | from buyer import Buyer 34 | from pricefunction import PriceFunction 35 | from marketengine import MarketEngine 36 | from helper import Helper 37 | import pandas 38 | from sklearn.neighbors import KNeighborsClassifier 39 | import seaborn as sns 40 | 41 | def visualize_acc_cost(data_path="../logs/0/acc_cost_tradeoffs_uniform_logreg.csv", # Print per-budget accuracy stats and save a histogram and an accuracy-vs-budget line plot as PDFs under savepath/instance_id/. 42 | savepath="../figures/", 43 | ): 44 | plt.clf() 45 | data = pandas.read_csv(data_path) 46 | print("data",data) 47 | mean1 = data.groupby("budget").mean() 48 | var1 = data.groupby("budget").var() 49 | max1 = data.groupby("budget").max() 50 | min1 = data.groupby("budget").min() 51 | print("mean1 of acc",mean1['acc']) 52 | print("var",var1['acc']) 53 | print("diff, max, and min",max1['acc']-min1['acc'],max1['acc'],min1['acc']) 54 | sns.color_palette("tab10") # NOTE(review): the returned palette is discarded, so this call has no effect 55 | swarm_plot = sns.histplot(data=data, x="acc", hue="budget",palette=["C0", "C1", "C2","C3","C4"]) 56 | #swarm_plot = sns.scatterplot(data=data, x= "cost",y="acc") 57 | plt.figure() # NOTE(review): opens a new empty figure; fig below is still the histplot's figure 58 | fig = swarm_plot.get_figure() 59 | data_parse = data_path.split("/") 60 | method = data_parse[-1].split("_")[-2] 61 | instanceid = data_parse[-2] 62 | ml = data_parse[-1].split("_")[-1] # NOTE(review): keeps the ".csv" suffix, so outputs are named e.g. "uniformlogreg.csv.pdf" 63 | fig.savefig(savepath+str(instanceid)+"/"+method+ml+".pdf") 64 | 65 | plt.figure() 66 | 67 | swarm_plot = sns.lineplot(data=data, y="acc", x="budget", err_style="band") 68 | fig2 = swarm_plot.get_figure() 69 | fig2.savefig(savepath+str(instanceid)+"/"+method+ml+"_line.pdf") 70 | 71 | 72 | return 73 | 74 | def evaluate( # Run `trial` random submissions at one cost_scale; return a DataFrame with columns trial, acc, cost (no budget column). 75 | MarketHelper, 76 | MarketEngineObj, 77 | model, 78 | buyer_data, 79 | trial=100, # number of trials per budget 80 | seller_data_size_list =
[100,200,300], 81 | cost_scale=0.1, 82 | method="single", 83 | ): 84 | trial_list = list(range(trial)) 85 | acc_list = list() 86 | cost_list = list() 87 | 88 | for i in range(trial): 89 | print("trial:",i) 90 | # generate a submission 91 | submission = gen_submission(seller_data_size_list,cost_scale=cost_scale, 92 | method=method) 93 | # calculate the cost of the submission 94 | cost = MarketHelper.get_cost(submission,MarketEngineObj) 95 | # generate the accuracy of the submission 96 | traindata = MarketHelper.load_data(submission, MarketEngineObj) 97 | model = MarketHelper.train_model(model, traindata[:,0:-1], 98 | numpy.ravel(traindata[:,-1])) 99 | acc1 = MarketHelper.eval_model(model,test_X=buyer_data[:,0:-1],test_Y=buyer_data[:,-1]) 100 | 101 | cost_list.append(cost) 102 | acc_list.append(acc1) 103 | 104 | result = pandas.DataFrame() 105 | result['trial'] = trial_list 106 | result['acc'] = acc_list 107 | result['cost'] = cost_list 108 | return result 109 | 110 | ''' generate a pandas dataframe 111 | 112 | trial,accuracy, cost 113 | ''' 114 | 115 | def gen_submission(seller_data_size_list=[100,200,300], # Random purchase counts; here 'uniform' draws from [0, int(a*cost_scale)) per seller, 'single' buys from one random seller. 116 | cost_scale=1, 117 | method="uniform"): 118 | if(method=="uniform"): 119 | submission = [numpy.random.randint(0,int(a*cost_scale)) for a in seller_data_size_list] # NOTE(review): numpy randint raises ValueError when int(a*cost_scale) is 0 120 | if(method=="single"): 121 | submission = [0]*len(seller_data_size_list) 122 | index = numpy.random.randint(0,len(submission)) 123 | submission[index] = int(seller_data_size_list[index]*cost_scale) 124 | return submission # NOTE(review): UnboundLocalError if method is neither 'uniform' nor 'single' 125 | 126 | def evaluate_budget(MarketHelper, # Run evaluate() once per cost scale and stack the resulting DataFrames. 127 | MarketEngineObj, 128 | model, 129 | buyer_data, 130 | trial=100, # number of trials per budget 131 | seller_data_size_list = [100,200,300], 132 | cost_scale_list=[0.1], 133 | method="single", 134 | ): 135 | results = [evaluate( 136 | MarketHelper=MarketHelper, 137 | MarketEngineObj=MarketEngineObj, 138 | model=model, 139 | buyer_data=buyer_data, 140 | trial=trial, # number of trials per budget 141 | seller_data_size_list =
seller_data_size_list, 142 | cost_scale=c1, 143 | method=method, 144 | ) for c1 in cost_scale_list] 145 | full_result = pandas.concat(results, ignore_index=True,axis=0) 146 | return full_result 147 | 148 | 149 | def main(): # Render the "_rf" tradeoff plots for instances 0-4 (only the 'uniform' method). 150 | matplotlib.pyplot.close('all') 151 | instance_ids = [0,1,2,3,4] 152 | methods = ['single','uniform'] 153 | 154 | methods=['uniform'] # overrides the list above; only 'uniform' is plotted 155 | for instance_id in instance_ids: 156 | for method in methods: 157 | #visualize_acc_cost(data_path="../logs/"+str(instance_id)+"/acc_cost_tradeoffs_"+method+"_knn.csv") 158 | visualize_acc_cost(data_path="../logs/"+str(instance_id)+"/acc_cost_tradeoffs_"+method+"_rf.csv") 159 | #visualize_acc_cost(data_path="../logs/"+str(instance_id)+"/acc_cost_tradeoffs_"+method+"_logreg.csv") 160 | # NOTE(review): the triple-quoted block below is dead code kept as a string literal 161 | ''' 162 | print("evaluate acc and cost tradeoffs") 163 | instance_id=0 164 | MyHelper = Helper() 165 | seller_data, seller_prices, buyer_data, buyer_budget, data_size = MyHelper.load_market_instance( 166 | feature_path="../features/"+str(instance_id)+"/", 167 | buyer_data_path="../marketinfo/"+str(instance_id)+"/data_buyer/20.csv", 168 | price_path="../marketinfo/"+str(instance_id)+"/price/price.txt", 169 | budget_path="../marketinfo/"+str(instance_id)+"/price/budget.txt", 170 | ) 171 | 172 | MyMarketEngine = MarketEngine() 173 | mlmodel1 = LogisticRegression(random_state=0) 174 | mlmodel1 = KNeighborsClassifier(n_neighbors=9) 175 | 176 | MyMarketEngine.setup_market(seller_data=seller_data, 177 | seller_prices = seller_prices, 178 | buyer_data=buyer_data, 179 | buyer_budget=1e10, 180 | mlmodel=mlmodel1, 181 | ) 182 | 183 | result = evaluate( 184 | MarketHelper=MyHelper, 185 | MarketEngineObj=MyMarketEngine, 186 | model=mlmodel1, 187 | buyer_data=buyer_data, 188 | trial=10, # number of trials per budget 189 | seller_data_size_list = numpy.loadtxt("../marketinfo/"+str(instance_id)+"/seller_datasize.csv"), 190 | cost_scale=0.1, 191 | ) 192 | result2 = evaluate_budget( 193 | MarketHelper=MyHelper, 194 |
MarketEngineObj=MyMarketEngine, 195 | model=mlmodel1, 196 | buyer_data=buyer_data, 197 | trial=100, # number of trials per budget 198 | seller_data_size_list = numpy.loadtxt("../marketinfo/" + str(instance_id) +"/seller_datasize.csv"), 199 | # cost_scale_list=[0.005,0.0075,0.01,0.025], 200 | # method="uniform", 201 | cost_scale_list=[0.05,0.1,0.5,1], 202 | method="single", 203 | ) 204 | folder1 = "../logs/"+str(instance_id)+"/" 205 | 206 | result2.to_csv(folder1+"acc_cost_tradeoffs.csv") 207 | print("result is:",result) 208 | ''' 209 | if __name__ == '__main__': 210 | main() 211 | 212 | 213 | 214 | 215 | -------------------------------------------------------------------------------- /src/helper.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License.
15 | """ 16 | 17 | #!/usr/bin/env python3 18 | # -*- coding: utf-8 -*- 19 | """ 20 | Created on Tue Aug 16 18:36:30 2022 21 | 22 | @author: lingjiao 23 | """ 24 | 25 | 26 | from sklearn.linear_model import LogisticRegression 27 | from sklearn.neighbors import KNeighborsClassifier 28 | from sklearn.ensemble import RandomForestClassifier 29 | from sklearn.dummy import DummyClassifier 30 | import numpy 31 | from seller import Seller 32 | from buyer import Buyer 33 | from pricefunction import PriceFunction 34 | from marketengine import MarketEngine 35 | import glob 36 | import pandas 37 | 38 | def sub2stretagy(submission,MarketEngineObj): # Return [counts, prices]: submission[i] and sellers[i].getprice(submission[i]) for each entry i. 39 | stretagy1 = list() 40 | cost1 = list() 41 | for i in range(len(submission)): # NOTE(review): assumes len(submission) <= number of sellers 42 | stretagy1.append(submission[i]) 43 | cost1.append(MarketEngineObj.sellers[i].getprice(submission[i])) 44 | stretagy = list() 45 | stretagy.append(stretagy1) 46 | stretagy.append(cost1) 47 | #print("stretagy is:",stretagy) 48 | return stretagy 49 | 50 | class Helper(object): # Stateless utilities to price, load, train on, and score a submission. 51 | def __init__(self): 52 | return 53 | 54 | def get_cost(self,submission,MarketEngineObj): # Sum of the per-seller prices for the submission. 55 | stretagy = sub2stretagy(submission,MarketEngineObj) 56 | cost = sum(stretagy[1]) 57 | return cost 58 | 59 | def load_data(self, submission, MarketEngineObj): 60 | ''' 61 | load submissions. 62 | return: the purchased rows of all sellers concatenated into one array (callers use [:,0:-1] as X and [:,-1] as y); raises ValueError if the total price exceeds MarketEngineObj.buyer_budget 63 | ''' 64 | 65 | #print(" train buyer model ") 66 | 67 | stretagy = sub2stretagy(submission,MarketEngineObj) 68 | buyer_budget = MarketEngineObj.buyer_budget 69 | print("strategy is:",stretagy) # NOTE(review): unconditional debug print on every call 70 | # check if the budget constraint is satisfied.
71 | cost = sum(stretagy[1]) 72 | if(cost>buyer_budget): 73 | raise ValueError("The budget constraint is not satisifed!") 74 | return # unreachable: the raise above always exits 75 | 76 | traindata = None # stays None if there are no sellers 77 | for i in range(len(MarketEngineObj.sellers)): # NOTE(review): indexes stretagy per seller, so submission needs one entry per seller 78 | d1 = MarketEngineObj.sellers[i].getdata(stretagy[0][i],stretagy[1][i]) 79 | if(i==0): 80 | traindata = d1 81 | else: 82 | traindata = numpy.concatenate((traindata,d1)) 83 | return traindata 84 | 85 | 86 | def train_model(self, model, train_X, train_Y): # Fit model in place and return it. 87 | model.fit(train_X,train_Y) 88 | return model 89 | 90 | def eval_model(self, model, test_X, test_Y): # Return model.score(test_X, test_Y) (mean accuracy for sklearn classifiers). 91 | eval_acc = model.score(test_X, test_Y) 92 | return eval_acc 93 | 94 | def load_market_instance(self, # Load seller CSVs, per-seller pricing, the budget, and the buyer CSV; returns (seller_data, seller_prices, buyer_data, buyer_budget, seller_datasize). 95 | feature_path="features/0/", 96 | buyer_data_path="buyerdata.csv", 97 | price_path="price.txt", 98 | budget_path="budget.txt", 99 | ): 100 | paths = glob.glob(feature_path+"*.csv") # NOTE(review): glob order is filesystem-dependent, yet seller i is paired with row i of price.txt below; consider sorted() 101 | print("paths:",paths) 102 | # 1. load seller data 103 | seller_data = list() 104 | seller_prices = list() 105 | buyer_budget = numpy.loadtxt(budget_path) 106 | buyer_budget = float(buyer_budget) 107 | #print('budget_ is', type(buyer_budget)) 108 | # datafull = [numpy.loadtxt(path,delimiter=',') for path in paths] 109 | datafull = [pandas.read_csv(path,header=None,engine="pyarrow").to_numpy() for path in paths] 110 | seller_datasize = [len(data1) for data1 in datafull] 111 | pricefull = numpy.loadtxt(price_path,delimiter=',',dtype=str) # each row is read as (method, max_price) strings 112 | for i in range(len(datafull)): 113 | if(1): # always true 114 | seller_data.append(datafull[i]) 115 | #print(pricefull[i]) 116 | MyPricing1 = PriceFunction() 117 | MyPricing1.setup(max_p = float(pricefull[i][1]), method=pricefull[i][0]) 118 | seller_prices.append(MyPricing1) 119 | # buyer_data = numpy.loadtxt(buyer_data_path,delimiter=',') 120 | buyer_data = pandas.read_csv(buyer_data_path,header=None,engine="pyarrow").to_numpy() 121 | return seller_data, seller_prices, buyer_data, buyer_budget, seller_datasize 122 | def main(): # Smoke test: a toy two-seller market, then instance 0 with a trained model vs. a most-frequent baseline. 123 | print("test of the helper") 124 | MyMarketEngine =
MarketEngine() 125 | 126 | data_1 = numpy.asmatrix([[0,1,0],[1,0,0]]) 127 | data_2 = numpy.asmatrix([[0,1,1],[1,0,1],[1,1,1],[0,0,1]]) 128 | data_b = numpy.asmatrix([[0,1,0],[1,0,1],[0,1,1]]) 129 | 130 | buyer_budget = 100 131 | 132 | MyPricing1 = PriceFunction() 133 | MyPricing1.setup(max_p = 100, method="lin") 134 | MyPricing2 = PriceFunction() 135 | MyPricing2.setup(max_p = 100, method="lin") 136 | 137 | 138 | mlmodel1 = LogisticRegression(random_state=0) 139 | 140 | 141 | MyMarketEngine.setup_market(seller_data=[data_1,data_2], 142 | seller_prices = [MyPricing1,MyPricing2], 143 | buyer_data=data_b, 144 | buyer_budget=buyer_budget, 145 | mlmodel=mlmodel1, 146 | ) 147 | 148 | stretagy = [[1,2],[50,50]] # unused: overwritten below before any read 149 | #MyMarketEngine.load_stretagy(stretagy) 150 | 151 | #acc1 = MyMarketEngine.train_buyer_model() 152 | #print("acc is ",acc1) 153 | 154 | MyHelper = Helper() 155 | seller_data, seller_prices, buyer_data, buyer_budget, seller_datasize = MyHelper.load_market_instance( 156 | feature_path="../features/0/", 157 | buyer_data_path="../marketinfo/0/data_buyer/20.csv", 158 | price_path="../marketinfo/0/price/price.txt", 159 | budget_path="../marketinfo/0/price/budget.txt", 160 | ) 161 | print("load data finished") 162 | print("seller data size:",seller_datasize) 163 | numpy.savetxt("../marketinfo/0/seller_datasize.csv",seller_datasize,fmt="%d") 164 | MyMarketEngine.setup_market(seller_data=seller_data, 165 | seller_prices = seller_prices, 166 | buyer_data=buyer_data, 167 | buyer_budget=buyer_budget, 168 | mlmodel=mlmodel1, 169 | ) 170 | print("set up market finished") 171 | stretagy=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,10,10,10,10,15] 172 | stretagy=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0,0,0,0,0] 173 | stretagy=[10,20,30,40,50,60,70,80,9,10,11,12,13,14,15,0,0,0,0,0] 174 | stretagy=[10,20,30,40,50,60,70,800,9,10,11,12,13,14,15,0,0,0,0,0] 175 | stretagy=[10,20,30,40,50,60,70,80,9,10,11,12,13,14,15,0,0,0,0,0] 176 |
stretagy=[50,20,30,40,5,6,7,80,9,10,11,12,13,14,15,0,400,0,50,0] 177 | 178 | stretagy=[100,200,300,400,500,600,70,80,9,10,11,12,13,14,15,50,50,50,50,50] 179 | stretagy=[10,20,30,40,50,60,70,80,9,10,11,12,13,14,15,50,50,50,50,50] 180 | stretagy=[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0,0,0,0,0] # only this last assignment is evaluated 181 | 182 | traindata = MyHelper.load_data(stretagy, MyMarketEngine) 183 | model = RandomForestClassifier() 184 | model = KNeighborsClassifier(n_neighbors=9) 185 | model = LogisticRegression(random_state=0) # only this last model is trained 186 | 187 | model = MyHelper.train_model(model, traindata[:,0:-1], 188 | numpy.ravel(traindata[:,-1])) 189 | acc1 = MyHelper.eval_model(model,test_X=buyer_data[:,0:-1],test_Y=buyer_data[:,-1]) 190 | 191 | print("acc is:", acc1) 192 | model2 = DummyClassifier(strategy="most_frequent") 193 | model2 = MyHelper.train_model(model2, traindata[:,0:-1], 194 | numpy.ravel(traindata[:,-1])) 195 | acc2 = MyHelper.eval_model(model2,test_X=buyer_data[:,0:-1],test_Y=buyer_data[:,-1]) 196 | print("dummy acc is:", acc2) 197 | 198 | if __name__ == '__main__': 199 | main() 200 | 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # dataperf-dam: A Data-centric Benchmark on Data Acquisition for Machine Learning 2 | 3 | This github repo serves as the starting point for submissions and evaluations for data acquisition for machine learning benchmark, or in short, DAM, as part of the DataPerf benchmark suite [https://dataperf.org/](https://dataperf.org/) 4 | 5 | 6 | ## 1. What is the DAM benchmark? 7 | 8 | An increasingly large amount of data is purchased for AI-enabled data science applications. How to select the right set of datasets for AI tasks of interest is an important decision that has, however, received limited attention. A naive approach is to acquire all available datasets and then select which ones to use empirically.
This requires expensive human supervision and incurs prohibitively high costs, posing unique challenges to budget-limited users. 9 | 10 | How can one decide which datasets to acquire, before actually purchasing the data, so as to optimize the performance of an ML model? In the DAM (Data-Acquisition-for-Machine-learning) benchmark, participants are asked to tackle this problem. Participants need to provide a data purchase strategy for a data buyer in K (=5 in the beta version) separate data marketplaces. In each data marketplace, there are a few data sellers offering datasets for sale, and one data buyer interested in acquiring some of those datasets to train an ML model. Each seller provides a pricing function that depends on the number of purchased samples. The buyer first decides how many data points to purchase from each seller, given a data acquisition budget b. Then those data points are compiled into one dataset to train an ML model f(). The buyer also has a dataset Db to evaluate the performance of the trained model. As in real-world data marketplaces, the buyer cannot observe the sellers’ datasets; it only sees some summary information provided by the sellers. 11 | 12 | ## 2. How to participate in this challenge? 13 | We suggest starting with the [colab notebook](https://colab.research.google.com/drive/1HYoFfKwd9Pr-Zg_e2uJxWF8yHqa9sRMn?usp=sharing). It is self-contained, and shows how to (i) install the needed library, (ii) access the buyer's observation, and (iii) create strategies ready to be submitted. In the following, we explain these steps in more detail. 14 | 15 | ## 3. How to access the buyer's observation? 16 | 17 | We provide a simple python library to access the buyer’s observation in each data marketplace.
18 | To use it, we recommend creating a virtual environment with 19 | ``` 20 | conda create -n DAM python=3.8 21 | conda activate DAM 22 | ``` 23 | and then cloning the GitHub repo, installing all libraries, and downloading the data with 24 | ``` 25 | git clone https://github.com/facebookresearch/Data_Acquisition_for_ML_Benchmark 26 | cd Data_Acquisition_for_ML_Benchmark 27 | pip install -r requirements.txt 28 | wget https://github.com/lchen001/Data_Acquisition_for_ML_Benchmark/releases/download/v0.0.1/marketinfo.zip 29 | unzip marketinfo.zip 30 | cd src 31 | ``` 32 | 33 | Now the library is ready to use. For example, to select a marketplace by its id, one can use 34 | 35 | ``` 36 | from dam import Dam 37 | MyDam = Dam(instance=0) 38 | ``` 39 | 40 | 41 | The following code retrieves the buyer’s budget, dataset, and ML model. 42 | 43 | ``` 44 | budget = MyDam.getbudget() 45 | buyer_data = MyDam.getbuyerdata() 46 | mlmodel = MyDam.getmlmodel() 47 | ``` 48 | 49 | 50 | To list all sellers’ ids, execute 51 | 52 | 53 | ``` 54 | sellers_id = MyDam.getsellerid() 55 | ``` 56 | 57 | To get seller i’s information, run 58 | 59 | ``` 60 | seller_i_price, seller_i_summary, seller_i_samples = MyDam.getsellerinfo(seller_id=i) 61 | ``` 62 | 63 | seller_i_price contains the pricing function. seller_i_summary includes (i) the number of rows, (ii) the number of columns, (iii) the histogram of each dimension, and (iv) the correlation between each column and the label. seller_i_samples contains 5 samples from the seller's dataset. 64 | 65 | Note: For simplicity, all sellers sell the same type of data; more formally, their data distributions share the same support. For example, all sellers have the same number of columns, and each column has the same semantic meaning. 66 | 67 | More details on the price function: given a sample size, the price can be calculated by calling the get_price_samplesize function.
For example, if the sample size is 100, then calling 68 | 69 | ``` 70 | seller_i_price.get_price_samplesize(samplesize=100) 71 | ``` 72 | gives the price. 73 | 74 | More details on the seller summary: the seller_i_summary contains four fields as follows: 75 | 76 | ``` 77 | seller_i_summary.keys() 78 | >>> dict_keys(['row_number', 'column_number', 'hist', 'label_correlation']) 79 | ``` 80 | Here, seller_i_summary['row_number'] encodes the number of data points. Similarly, seller_i_summary['column_number'] equals the number of features plus one (the label). seller_i_summary['hist'] is a dictionary containing the histogram for each feature. seller_i_summary['label_correlation'] is a dictionary that represents the Pearson correlation between each feature and the label. 81 | 82 | For example, one can print the histogram of feature '2' by 83 | ``` 84 | print(seller_i_summary['hist']['2']) 85 | >>> {'0_size': 3, '1_size': 35, '2_size': 198, '3_size': 821, '4_size': 2988, '5_size': 8496, '6_size': 11563, '7_size': 5155, '8_size': 704, '9_size': 37, '0_range': -0.7187578082084656, '1_range': -0.5989721298217774, '2_range': -0.4791864514350891, '3_range': -0.3594007730484009, '4_range': -0.23961509466171266, '5_range': -0.11982941627502441, '6_range': -4.373788833622605e-05, '7_range': 0.11974194049835207, '8_range': 0.23952761888504026, '9_range': 0.35931329727172856, '10_range': 0.47909897565841675} 86 | ``` 87 | How to read this? This representation documents (i) how the histogram bins are created (i_range), and (ii) how many points fall into each bin (i_size). Bins are indexed from 0, and bin i spans [i_range, (i+1)_range]. For example, '2_size': 198 means 198 data points fall into bin 2, and `'2_range': -0.4791864514350891, '3_range': -0.3594007730484009` means bin 2 spans [-0.4791864514350891, -0.3594007730484009].
88 | 89 | ``` 90 | print(seller_i_summary['label_correlation']['2']) 91 | >>> 0.08490820825406746 92 | ``` 93 | This means the correlation between feature '2' and the label is 0.08490820825406746. 94 | 95 | Note that all features in the sellers' and buyer's datasets are NOT in their raw form. In fact, we have extracted those features using a deep learning model (more specifically, a dist-bert model) from their original format. 96 | 97 | ## 4. How to submit a solution? 98 | 99 | The submission should contain K(=5) txt files. k.txt corresponds to the purchase strategy for the kth marketplace. The notebook will automatically generate txt files for submission under the folder ```\submission\my_submission```. For example, one submission may look like 100 | 101 | 102 | ``` 103 | 104 | \submission\my_submission\0.txt 105 | 106 | \submission\my_submission\1.txt 107 | 108 | \submission\my_submission\2.txt 109 | 110 | \submission\my_submission\3.txt 111 | 112 | \submission\my_submission\4.txt 113 | 114 | ``` 115 | 116 | Each txt file should contain one line of comma-separated numbers, where the ith number indicates the number of data points to purchase from the ith seller. For example, 0.txt containing 117 | 118 | ``` 119 | 100,50,200,500 120 | ``` 121 | 122 | means buying 100, 50, 200, and 500 samples from seller 1, seller 2, seller 3, and seller 4, respectively. 123 | 124 | Once you are ready, upload the txt files to DynaBench for evaluation: https://dynabench.org/tasks/DAM/ 125 | 126 | 127 | ## 5. How is a submission evaluated? 128 | 129 | Once we receive a submission, we first check whether the strategy is valid (e.g., whether it satisfies the budget constraint). Then we train an ML model on the dataset generated by the submitted strategy and evaluate its performance (standard accuracy) on the buyer’s data Db. We will report the performance averaged over all K marketplace instances. 130 | 131 | Which ML model is trained?
To focus on the data acquisition task, we train a simple logistic regression model. More specifically, we use the following model 132 | 133 | ``` 134 | from sklearn.linear_model import LogisticRegression 135 | model = LogisticRegression(random_state=0) 136 | ``` 137 | 138 | Requirements: 139 | 140 | (i) you may use any (open-source/commercial) software; 141 | 142 | (ii) you may not use external datasets; 143 | 144 | (iii) do not create multiple accounts for submission; 145 | 146 | (iv) follow the honor code. 147 | 148 | ## Contact and License 149 | _DAM_ is Apache 2.0 licensed. 150 | 151 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 
27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. 
For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. 
If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. 
You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. 
(Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2022 Meta Platforms 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | -------------------------------------------------------------------------------- /src/visualizetools.py: -------------------------------------------------------------------------------- 1 | import matplotlib # noqa 2 | matplotlib.use('Agg') # noqa 3 | 4 | import matplotlib.pyplot as plt 5 | plt.rcParams['axes.facecolor'] = 'white' 6 | 7 | import numpy as np 8 | import matplotlib.ticker as ticker 9 | import json 10 | import seaborn as sn 11 | import pandas as pd 12 | from matplotlib.colors import LogNorm 13 | import seaborn as sns 14 | from matplotlib.colors import LinearSegmentedColormap 15 | import umap 16 | #import matplotlib.pyplot as plt 17 | 18 | 19 | class VisualizeTools(object): 20 | def __init__(self,figuresize = (10,8),figureformat='jpg', 21 | colorset=['r','orange','k','yellow','g','b','k'], 22 | markersize=30, 23 | fontsize=30, 24 | usecommand=True): 25 | self.figuresize=figuresize 26 | self.figureformat = figureformat 27 | self.fontsize = fontsize 28 | self.linewidth = 5 29 | self.markersize = 
markersize 30 | self.folder = "../figures/" # use "../figures/" if needed 31 | self.colorset=colorset 32 | self.markerset = ['o','X','^','v','s','o','*','d','p'] 33 | self.marker = 'o' # from ['X','^','v','s','o','*','d','p'], 34 | self.linestyle = '-' # from ['-.','--','--','-.','-',':','--','-.'], 35 | self.linestyleset = ['-','-.','--','--','-.','-',':','--','-.'] 36 | self.usecommand = usecommand 37 | 38 | def plotline(self, 39 | xvalue, 40 | yvalue, 41 | xlabel='xlabel', 42 | ylabel='ylabel', 43 | legend=None, 44 | filename='lineplot', 45 | fig=None, 46 | color=None, 47 | ax=None): 48 | if(ax==None): 49 | # setup figures 50 | fig = plt.figure(figsize=self.figuresize) 51 | fig, ax = plt.subplots(figsize=self.figuresize) 52 | plt.rcParams.update({'font.size': self.fontsize}) 53 | plt.rcParams["font.weight"] = "bold" 54 | plt.rcParams["axes.labelweight"] = "bold" 55 | plt.rcParams["lines.linewidth"] = self.linewidth 56 | plt.rcParams["lines.markersize"] = self.markersize 57 | plt.rcParams["font.sans-serif"] = 'Arial' 58 | 59 | # plot it 60 | if(color==None): 61 | color = self.colorset[0] 62 | ax.plot(xvalue, 63 | yvalue, 64 | marker=self.marker, 65 | label=legend, 66 | color=color, 67 | linestyle = self.linestyle, 68 | zorder=0, 69 | ) 70 | plt.xlabel(xlabel) 71 | plt.ylabel(ylabel) 72 | 73 | plt.grid(True) 74 | ax.locator_params(axis='x', nbins=6) 75 | ax.locator_params(axis='y', nbins=6) 76 | 77 | formatter = ticker.FormatStrFormatter('%0.2e') 78 | 79 | formatterx = ticker.FormatStrFormatter('%0.2f') 80 | 81 | ax.yaxis.set_major_formatter(formatter) 82 | ax.xaxis.set_major_formatter(formatterx) 83 | 84 | filename =filename+'.'+self.figureformat 85 | 86 | if(self.figureformat=='jpg'): 87 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 88 | else: 89 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 90 | 91 | return fig, ax 92 | #plt.fill_between(bud, np.asarray(acc_mean)-np.asarray(acc_std), 
np.asarray(acc_mean)+np.asarray(acc_std),alpha=0.3,facecolor='lightgray') 93 | 94 | 95 | def plotlines(self, 96 | xvalue, 97 | yvalues, 98 | xlabel='xlabel', 99 | ylabel='ylabel', 100 | legend=None, 101 | filename='lineplot', 102 | fig=None, 103 | ax=None, 104 | showlegend=False, 105 | log=False, 106 | fontsize=60, 107 | basey=10, 108 | ylim=None): 109 | #if(-1): 110 | if(ax==None): 111 | # setup figures 112 | fig = plt.figure(figsize=self.figuresize) 113 | fig, ax = plt.subplots(figsize=self.figuresize,frameon=True) 114 | plt.rcParams.update({'font.size': fontsize}) 115 | plt.rcParams["font.weight"] = "bold" 116 | plt.rcParams["axes.labelweight"] = "bold" 117 | plt.rcParams["lines.linewidth"] = self.linewidth 118 | plt.rcParams["lines.markersize"] = self.markersize 119 | plt.rcParams["font.sans-serif"] = 'Arial' 120 | ax.set_facecolor("white") 121 | #ax.set_edgecolor("black") 122 | ax.grid("True",color="grey") 123 | ax.get_yaxis().set_visible(True) 124 | ax.get_xaxis().set_visible(True) 125 | # plot it 126 | for i in range(len(yvalues)): 127 | ax.plot(xvalue, 128 | yvalues[i], 129 | marker=self.markerset[i], 130 | label=legend[i], 131 | color=self.colorset[i], 132 | linestyle = self.linestyleset[i], 133 | zorder=0, 134 | markersize=self.markersize, 135 | markevery=1, 136 | ) 137 | plt.xlabel(xlabel,fontsize=fontsize) 138 | plt.ylabel(ylabel,fontsize=fontsize) 139 | 140 | plt.grid(True) 141 | #ax.locator_params(axis='x', nbins=6) 142 | #ax.locator_params(axis='y', nbins=6) 143 | ''' 144 | formatter = ticker.FormatStrFormatter('%d') 145 | 146 | formatterx = ticker.FormatStrFormatter('%d') 147 | 148 | ax.yaxis.set_major_formatter(formatter) 149 | ax.xaxis.set_major_formatter(formatterx) 150 | ''' 151 | ax.tick_params(axis='both', which='major', labelsize=fontsize) 152 | 153 | if(ylim!=None): 154 | plt.ylim(ylim) 155 | 156 | if(log==True): 157 | ax.set_yscale('log',base=basey) 158 | if(showlegend==True): 159 | ax.legend(legend,facecolor="white",prop={'size': 
fontsize}, 160 | markerscale=1, numpoints= 2,loc="best") 161 | 162 | filename =filename+'.'+self.figureformat 163 | 164 | if(self.figureformat=='jpg'): 165 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 166 | else: 167 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 168 | 169 | return fig, ax 170 | #plt.fill_between(bud, np.asarray(acc_mean)-np.asarray(acc_std), np.asarray(acc_mean)+np.asarray(acc_std),alpha=0.3,facecolor='lightgray') 171 | 172 | def Histogram(self, 173 | xvalue, 174 | 175 | xlabel='xlabel', 176 | ylabel='ylabel', 177 | legend=None, 178 | filename='lineplot', 179 | fig=None, 180 | ax=None, 181 | showlegend=False, 182 | log=False, 183 | fontsize=90, 184 | ylim=None, 185 | n_bins=20): 186 | #if(-1): 187 | if(ax==None): 188 | # setup figures 189 | fig = plt.figure(figsize=self.figuresize) 190 | fig, ax = plt.subplots(figsize=self.figuresize,frameon=True) 191 | plt.rcParams.update({'font.size': fontsize}) 192 | plt.rcParams["font.weight"] = "bold" 193 | plt.rcParams["axes.labelweight"] = "bold" 194 | plt.rcParams["lines.linewidth"] = self.linewidth 195 | plt.rcParams["lines.markersize"] = self.markersize 196 | plt.rcParams["font.sans-serif"] = 'Arial' 197 | ax.set_facecolor("white") 198 | #ax.set_edgecolor("black") 199 | ax.grid("True",color="grey") 200 | ax.get_yaxis().set_visible(True) 201 | ax.get_xaxis().set_visible(True) 202 | # plot it 203 | plt.hist(xvalue,bins=n_bins) 204 | ''' 205 | for i in range(len(yvalues)): 206 | 207 | ax.plot(xvalue, 208 | yvalues[i], 209 | marker=self.markerset[i], 210 | label=legend[i], 211 | color=self.colorset[i], 212 | linestyle = self.linestyleset[i], 213 | zorder=0, 214 | markersize=self.markersize, 215 | markevery=10, 216 | ) 217 | ''' 218 | plt.xlabel(xlabel,fontsize=fontsize) 219 | plt.ylabel(ylabel,fontsize=fontsize) 220 | 221 | plt.grid(True) 222 | #ax.locator_params(axis='x', nbins=6) 223 | #ax.locator_params(axis='y', nbins=6) 224 | ''' 225 | formatter 
= ticker.FormatStrFormatter('%d') 226 | 227 | formatterx = ticker.FormatStrFormatter('%d') 228 | 229 | ax.yaxis.set_major_formatter(formatter) 230 | ax.xaxis.set_major_formatter(formatterx) 231 | ''' 232 | ax.tick_params(axis='both', which='major', labelsize=fontsize) 233 | 234 | if(ylim!=None): 235 | plt.ylim(ylim) 236 | 237 | if(log==True): 238 | ax.set_yscale('log') 239 | if(showlegend==True): 240 | ax.legend(legend,facecolor="white",prop={'size': fontsize}, 241 | markerscale=2, numpoints= 2,loc=0) 242 | 243 | filename =filename+'.'+self.figureformat 244 | 245 | if(self.figureformat=='jpg'): 246 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 247 | else: 248 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 249 | 250 | return fig, ax 251 | #plt.fill_between(bud, np.asarray(acc_mean)-np.asarray(acc_std), np.asarray(acc_mean)+np.asarray(acc_std),alpha=0.3,facecolor='lightgray') 252 | 253 | 254 | def Histograms(self, 255 | xvalues, 256 | xlabel='xlabel', 257 | ylabel='ylabel', 258 | legend=None, 259 | filename='lineplot', 260 | fig=None, 261 | ax=None, 262 | showlegend=False, 263 | log=False, 264 | fontsize=90, 265 | color=['red','orange'], 266 | ylim=None, 267 | n_bins=20): 268 | #if(-1): 269 | if(ax==None): 270 | # setup figures 271 | fig = plt.figure(figsize=self.figuresize) 272 | fig, ax = plt.subplots(figsize=self.figuresize,frameon=True) 273 | plt.rcParams.update({'font.size': fontsize}) 274 | plt.rcParams["font.weight"] = "bold" 275 | plt.rcParams["axes.labelweight"] = "bold" 276 | plt.rcParams["lines.linewidth"] = self.linewidth 277 | plt.rcParams["lines.markersize"] = self.markersize 278 | plt.rcParams["font.sans-serif"] = 'Arial' 279 | ax.set_facecolor("white") 280 | #ax.set_edgecolor("black") 281 | ax.grid("True",color="grey") 282 | ax.get_yaxis().set_visible(True) 283 | ax.get_xaxis().set_visible(True) 284 | # plot it 285 | plt.hist(xvalues,bins=n_bins, density=True,color=color) 286 | ''' 287 | for i 
in range(len(yvalues)): 288 | 289 | ax.plot(xvalue, 290 | yvalues[i], 291 | marker=self.markerset[i], 292 | label=legend[i], 293 | color=self.colorset[i], 294 | linestyle = self.linestyleset[i], 295 | zorder=0, 296 | markersize=self.markersize, 297 | markevery=10, 298 | ) 299 | ''' 300 | plt.xlabel(xlabel,fontsize=fontsize) 301 | plt.ylabel(ylabel,fontsize=fontsize) 302 | 303 | plt.grid(True) 304 | #ax.locator_params(axis='x', nbins=6) 305 | #ax.locator_params(axis='y', nbins=6) 306 | ''' 307 | formatter = ticker.FormatStrFormatter('%d') 308 | 309 | formatterx = ticker.FormatStrFormatter('%d') 310 | 311 | ax.yaxis.set_major_formatter(formatter) 312 | ax.xaxis.set_major_formatter(formatterx) 313 | ''' 314 | ax.tick_params(axis='both', which='major', labelsize=fontsize) 315 | 316 | if(ylim!=None): 317 | plt.ylim(ylim) 318 | 319 | if(log==True): 320 | ax.set_yscale('log') 321 | if(showlegend==True): 322 | ax.legend(legend,facecolor="white",prop={'size': fontsize}, 323 | markerscale=2, numpoints= 2,loc=0) 324 | 325 | filename =filename+'.'+self.figureformat 326 | 327 | if(self.figureformat=='jpg'): 328 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 329 | else: 330 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 331 | 332 | return fig, ax 333 | #plt.fill_between(bud, np.asarray(acc_mean)-np.asarray(acc_std), np.asarray(acc_mean)+np.asarray(acc_std),alpha=0.3,facecolor='lightgray') 334 | 335 | 336 | def plotscatter(self, 337 | xvalue=0.3, 338 | yvalue=0.5, 339 | filename='lineplot', 340 | markersize=10, 341 | legend='Learned Thres', 342 | color='blue', 343 | showlegend=False, 344 | fig=None, 345 | ax=None): 346 | if(ax==None): 347 | # setup figures 348 | fig = plt.figure(figsize=self.figuresize) 349 | fig, ax = plt.subplots(figsize=self.figuresize) 350 | plt.rcParams.update({'font.size': self.fontsize}) 351 | plt.rcParams["font.weight"] = "bold" 352 | plt.rcParams["axes.labelweight"] = "bold" 353 | 
plt.rcParams["lines.linewidth"] = self.linewidth 354 | plt.rcParams["lines.markersize"] = self.markersize 355 | plt.rcParams["font.sans-serif"] = 'Arial' 356 | 357 | ax.plot(xvalue,yvalue,'*',markersize=markersize,color=color, 358 | label=legend) 359 | if(showlegend): 360 | handles, labels = ax.get_legend_handles_labels() 361 | print("labels",labels) 362 | ax.legend(handles[::-1],labels[::-1], prop={'size': 35},markerscale=3, numpoints= 1,loc=0) 363 | 364 | 365 | filename =filename+'.'+self.figureformat 366 | if(self.figureformat=='jpg'): 367 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 368 | else: 369 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 370 | 371 | return fig, ax 372 | 373 | 374 | def plotscatter(self, 375 | xvalue=0.3, 376 | yvalue=0.5, 377 | filename='lineplot', 378 | markersize=10, 379 | legend='Learned Thres', 380 | color='blue', 381 | showlegend=False, 382 | fig=None, 383 | ax=None): 384 | if(ax==None): 385 | # setup figures 386 | fig = plt.figure(figsize=self.figuresize) 387 | fig, ax = plt.subplots(figsize=self.figuresize) 388 | plt.rcParams.update({'font.size': self.fontsize}) 389 | plt.rcParams["font.weight"] = "bold" 390 | plt.rcParams["axes.labelweight"] = "bold" 391 | plt.rcParams["lines.linewidth"] = self.linewidth 392 | plt.rcParams["lines.markersize"] = self.markersize 393 | plt.rcParams["font.sans-serif"] = 'Arial' 394 | 395 | ax.plot(xvalue,yvalue,'*',markersize=markersize,color=color, 396 | label=legend) 397 | if(showlegend): 398 | handles, labels = ax.get_legend_handles_labels() 399 | print("labels",labels) 400 | ax.legend(handles[::-1],labels[::-1], prop={'size': 35},markerscale=3, numpoints= 1,loc=0) 401 | 402 | 403 | filename =filename+'.'+self.figureformat 404 | if(self.figureformat=='jpg'): 405 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 406 | else: 407 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 408 | 409 | 
return fig, ax 410 | 411 | def plotscatters_annotation(self, 412 | xvalue=[0.3], 413 | yvalue=[0.5], 414 | filename='lineplot', 415 | markersize=10, 416 | legend='Learned Thres', 417 | color='blue', 418 | showlegend=False, 419 | fig=None, 420 | ax=None, 421 | annotation=None): 422 | if(ax==None): 423 | # setup figures 424 | fig = plt.figure(figsize=self.figuresize) 425 | fig, ax = plt.subplots(figsize=self.figuresize) 426 | plt.rcParams.update({'font.size': self.fontsize}) 427 | plt.rcParams["font.weight"] = "bold" 428 | plt.rcParams["axes.labelweight"] = "bold" 429 | plt.rcParams["lines.linewidth"] = self.linewidth 430 | plt.rcParams["lines.markersize"] = self.markersize 431 | plt.rcParams["font.sans-serif"] = 'Arial' 432 | 433 | ax.scatter(xvalue,yvalue,) 434 | # '*',markersize=markersize,color=color, 435 | # ) 436 | for i in range(len(xvalue)): 437 | ax.annotate(annotation[i], xy=[xvalue[i],yvalue[i]]) 438 | if(showlegend): 439 | handles, labels = ax.get_legend_handles_labels() 440 | print("labels",labels) 441 | ax.legend(handles[::-1],labels[::-1], prop={'size': 35},markerscale=3, numpoints= 1,loc=0) 442 | 443 | 444 | filename =filename+'.'+self.figureformat 445 | if(self.figureformat=='jpg'): 446 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 447 | else: 448 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 449 | 450 | return fig, ax 451 | 452 | def plot_bar(self,barname,barvalue, 453 | filename='barplot', 454 | markersize=2, 455 | yname='Frequency', 456 | xname="", 457 | color='blue', 458 | ylim=None, 459 | fig=None, 460 | showlegend=False, 461 | ax=None, 462 | labelpad=None, 463 | fontsize=30, 464 | threshold=10, 465 | add_thresline=False,): 466 | if(ax==None): 467 | # setup figures 468 | fig = plt.figure(figsize=self.figuresize) 469 | fig, ax = plt.subplots(figsize=self.figuresize) 470 | ax.set_facecolor("white") 471 | plt.rcParams.update({'font.size': 1}) 472 | plt.rcParams["font.weight"] = "bold" 473 | 
plt.rcParams["axes.labelweight"] = "bold" 474 | plt.rcParams["lines.linewidth"] = self.linewidth 475 | plt.rcParams["lines.markersize"] = markersize 476 | plt.rcParams["font.sans-serif"] = 'Arial' 477 | plt.rc('font', size=1) # controls default text sizes 478 | print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") 479 | plt.grid(True,color="grey") 480 | x = np.arange(len(barname)) 481 | ax.bar(x,barvalue,color=color, 482 | label=barname) 483 | ax.set_ylabel(yname,fontsize=fontsize) 484 | if(xname!=""): 485 | ax.set_xlabel(xname,fontsize=fontsize) 486 | 487 | #ax.set_title('Scores by group and gender') 488 | ax.set_xticks(x) 489 | ax.set_xticklabels(barname,rotation='horizontal',fontsize=fontsize) 490 | #ax.set_xticklabels(barname,rotation='vertical') 491 | plt.xlim(x[0]-0.5,x[-1]+0.5) 492 | 493 | if(add_thresline==True): 494 | ax.plot([min(x)-0.5, max(x)+0.5], [threshold, threshold], "k--") 495 | 496 | matplotlib.rc('xtick', labelsize=fontsize) 497 | 498 | ax.tick_params(axis='both', which='major', labelsize=fontsize) 499 | 500 | if(not(labelpad==None)): 501 | ax.tick_params(axis='x', which='major', pad=labelpad) 502 | 503 | #matplotlib.rc('ytick', labelsize=fontsize) 504 | #ax.text(0.5,0.5,"hello") 505 | 506 | #ax.legend() 507 | 508 | if(showlegend): 509 | handles, labels = ax.get_legend_handles_labels() 510 | print("labels",labels) 511 | ax.legend(handles[::-1],labels[::-1], prop={'size': 10},markerscale=3, numpoints= 1,loc=0) 512 | 513 | 514 | #ticks = [tick for tick in plt.gca().get_xticklabels()] 515 | #print("ticks 0 is",ticks[0].get_window_extent()) 516 | ''' 517 | plt.text(-0.07, -0.145, 'label:', horizontalalignment='center',fontsize=fontsize, 518 | verticalalignment='center', transform=ax.transAxes) 519 | plt.text(-0.07, -0.25, 'qs:', horizontalalignment='center',fontsize=fontsize, 520 | verticalalignment='center', transform=ax.transAxes) 521 | ''' 522 | filename =filename+'.'+self.figureformat 523 | if(not(ylim==None)): 524 | plt.ylim(ylim) 
525 | if(self.figureformat=='jpg'): 526 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 527 | else: 528 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 529 | 530 | return fig, ax 531 | 532 | 533 | def plot_bar2value(self,barname,barvalue, barvalue2, 534 | filename='barplot', 535 | markersize=2, 536 | yname='Frequency', 537 | color='blue', 538 | fig=None, 539 | showlegend=False, 540 | legend=['precision','recall'], 541 | yrange = None, 542 | ax=None, 543 | fontsize=25, 544 | showvalues = False, 545 | legend_loc="upper left", 546 | hatch=None): 547 | if(ax==None): 548 | # setup figures 549 | fig = plt.figure(figsize=self.figuresize) 550 | fig, ax = plt.subplots(figsize=self.figuresize) 551 | plt.rcParams.update({'font.size': fontsize}) 552 | plt.rcParams["font.weight"] = "bold" 553 | plt.rcParams["axes.labelweight"] = "bold" 554 | plt.rcParams["lines.linewidth"] = self.linewidth 555 | plt.rcParams["lines.markersize"] = markersize 556 | plt.rcParams["font.sans-serif"] = 'Arial' 557 | width=0.3 558 | x = np.arange(len(barname)) 559 | ax.bar(x-width/2,barvalue,width,color=color[0], 560 | label=legend[0]) 561 | ax.bar(x+width/2,barvalue2,width, color=color[1], 562 | hatch=hatch, 563 | label=legend[1]) 564 | 565 | 566 | 567 | ax.set_ylabel(yname,fontsize=fontsize) 568 | #ax.set_title('Scores by group and gender') 569 | ax.set_xticks(x) 570 | #ax.set_xticklabels(barname,rotation='vertical') 571 | #ax.set_xticklabels(barname,rotation=45) 572 | ax.set_xticklabels(barname,rotation='horizontal') 573 | plt.xlim(x[0]-0.5,x[-1]+0.5) 574 | if(not(yrange==None)): 575 | plt.ylim(yrange[0],yrange[1]) 576 | 577 | matplotlib.rc('xtick', labelsize=fontsize) 578 | matplotlib.rc('ytick', labelsize=fontsize) 579 | 580 | #ax.legend() 581 | 582 | if(showvalues==True): 583 | for i, v in enumerate(barvalue): 584 | ax.text(i - 0.33,v + 0.1, "{:.1f}".format(v), color=color[0], fontweight='bold',) 585 | 586 | for i, v in enumerate(barvalue2): 
587 | ax.text(i + .10,v + 0.2, "{:.1f}".format(v), color=color[1], fontweight='bold',) 588 | 589 | if(showlegend): 590 | handles, labels = ax.get_legend_handles_labels() 591 | print("labels",labels) 592 | ax.legend(handles[::-1],labels[::-1], prop={'size': fontsize},markerscale=3, numpoints= 1, 593 | loc=legend_loc,ncol=1, )#bbox_to_anchor=(0, 1.05)) 594 | 595 | 596 | filename =filename+'.'+self.figureformat 597 | if(self.figureformat=='jpg'): 598 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 599 | else: 600 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 601 | 602 | return fig, ax 603 | 604 | def plotconfusionmaitrix(self,confmatrix, 605 | xlabel=None,ylabel=None, 606 | filename='confmatrix', 607 | keywordsize = 16, 608 | font_scale=2, 609 | figuresize=(10,10), 610 | cmap="coolwarm", # "Blues" 611 | vmin=0, 612 | vmax=10, 613 | fonttype='Arial', 614 | title1="", 615 | fmt=".1f", 616 | xlabel1 = "Predicted label", 617 | ylabel1="True label",): 618 | if(self.usecommand==True): 619 | return self.plotconfusionmaitrix_common1(confmatrix=confmatrix, 620 | xlabel=xlabel, 621 | ylabel=ylabel, 622 | filename=filename, 623 | keywordsize = keywordsize, 624 | font_scale=font_scale, 625 | figuresize=figuresize, 626 | cmap=cmap, 627 | vmin=vmin, 628 | vmax=vmax, 629 | fonttype=fonttype, 630 | title1=title1, 631 | xlabel1=xlabel1, 632 | ylabel1=ylabel1, 633 | fmt=fmt) 634 | 635 | sn.set(font=fonttype) 636 | #boundaries = [0.0, 0.045, 0.05, 0.055, 0.06,0.065,0.07,0.08,0.1,0.15, 1.0] # custom boundaries 637 | boundaries = [0.0, 0.06,0.2, 0.25,0.3, 0.4,0.5,0.6,0.7, 0.8, 1.0] # custom boundaries 638 | 639 | # here I generated twice as many colors, 640 | # so that I could prune the boundaries more clearly 641 | #hex_colors = sns.light_palette('blue', n_colors=len(boundaries) * 2 + 2, as_cmap=False).as_hex() 642 | #hex_colors = [hex_colors[i] for i in range(0, len(hex_colors), 2)] 643 | #print("hex",hex_colors) 644 | # My color 
645 | hex_colors = ['#ffffff','#ebf1f7', 646 | '#d3e4f3', 647 | '#bfd8ed', 648 | '#a1cbe2', 649 | '#7db8da', 650 | '#5ca4d0', 651 | '#3f8fc5', 652 | '#2676b8', 653 | '#135fa7', 654 | '#08488e'] 655 | ''' 656 | ['#e5eff9', 657 | '#d3e4f3', 658 | '#bfd8ed', 659 | '#a1cbe2', 660 | '#7db8da', 661 | '#5ca4d0', 662 | '#3f8fc5', 663 | '#2676b8', 664 | '#135fa7', 665 | '#08488e'] 666 | ''' 667 | 668 | boundaries = [0.0, 0.03, 0.06,0.1,0.2,0.29,0.3,0.8,1.0] 669 | hex_colors = ['#F2F6FA','#ebf1f7','#FFB9C7','#FF1242', '#FF1242','#FF1242','#2676b8','#135fa7','#08488e'] 670 | 671 | colors=list(zip(boundaries, hex_colors)) 672 | 673 | custom_color_map = LinearSegmentedColormap.from_list( 674 | name='custom_navy', 675 | colors=colors, 676 | ) 677 | 678 | tol=1e-4 679 | labels = confmatrix 680 | confmatrix=confmatrix*(confmatrix>0.35) 681 | print("confmatrix",confmatrix+tol) 682 | df_cm = pd.DataFrame(confmatrix+tol,xlabel,ylabel) 683 | plt.figure(figsize=figuresize) 684 | sn.set(font_scale=font_scale) # for label size 685 | g = sn.heatmap(df_cm, 686 | linewidths=0.3, 687 | linecolor="grey", 688 | cmap=custom_color_map, 689 | #annot=True, 690 | annot = labels, 691 | annot_kws={"size": keywordsize},fmt=".1f", 692 | #mask=df_cm < 0.02, 693 | vmin=vmin+tol, 694 | vmax=vmax, 695 | cbar=False, 696 | #cbar_kws={"ticks":[0.1,0.3,1,3,10]}, 697 | #norm=LogNorm(), 698 | #legend=False, 699 | ) # font size 700 | #g.cax.set_visible(False) 701 | #sn.heatmap(df, cbar=False) 702 | 703 | g.set_yticklabels(labels=g.get_yticklabels(), va='center') 704 | filename =filename+'.'+self.figureformat 705 | plt.ylabel(ylabel1) 706 | plt.xlabel(xlabel1) 707 | plt.title("Overall accuracy:"+"{:.1f}".format(np.trace(confmatrix)), 708 | fontweight="bold", 709 | pad=32) 710 | g.set_xticklabels(g.get_xticklabels(), rotation = 0) 711 | 712 | 713 | if(self.figureformat=='jpg'): 714 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 715 | else: 716 | plt.savefig(filename, 
format=self.figureformat, bbox_inches='tight') 717 | 718 | return 0 719 | 720 | 721 | def plotconfusionmaitrix_common1(self,confmatrix, 722 | xlabel=None,ylabel=None, 723 | filename='confmatrix', 724 | keywordsize = 16, 725 | font_scale=2, 726 | figuresize=(10,10), 727 | cmap="vlag", 728 | vmin=0, 729 | vmax=10, 730 | fonttype='Arial', 731 | title1="", 732 | fmt=".1f", 733 | xlabel1 = "Predicted label", 734 | ylabel1="True label", 735 | ): 736 | print("Use common confusion matrix plot!") 737 | sn.set(font=fonttype) 738 | #boundaries = [0.0, 0.045, 0.05, 0.055, 0.06,0.065,0.07,0.08,0.1,0.15, 1.0] # custom boundaries 739 | boundaries = [0.0, 0.06,0.2, 0.25,0.3, 0.4,0.5,0.6,0.7, 0.8, 1.0] # custom boundaries 740 | 741 | # here I generated twice as many colors, 742 | # so that I could prune the boundaries more clearly 743 | #hex_colors = sns.light_palette('blue', n_colors=len(boundaries) * 2 + 2, as_cmap=False).as_hex() 744 | #hex_colors = [hex_colors[i] for i in range(0, len(hex_colors), 2)] 745 | #print("hex",hex_colors) 746 | # My color 747 | hex_colors = ['#ffffff','#ebf1f7', 748 | '#d3e4f3', 749 | '#bfd8ed', 750 | '#a1cbe2', 751 | '#7db8da', 752 | '#5ca4d0', 753 | '#3f8fc5', 754 | '#2676b8', 755 | '#135fa7', 756 | '#08488e'] 757 | ''' 758 | ['#e5eff9', 759 | '#d3e4f3', 760 | '#bfd8ed', 761 | '#a1cbe2', 762 | '#7db8da', 763 | '#5ca4d0', 764 | '#3f8fc5', 765 | '#2676b8', 766 | '#135fa7', 767 | '#08488e'] 768 | ''' 769 | 770 | boundaries = [0.0, 0.03, 0.06,0.1,0.2,0.29,0.3,0.8,1.0] 771 | hex_colors = ['#F2F6FA','#ebf1f7','#FFB9C7','#FF1242', '#FF1242','#FF1242','#2676b8','#135fa7','#08488e'] 772 | 773 | colors=list(zip(boundaries, hex_colors)) 774 | 775 | custom_color_map = LinearSegmentedColormap.from_list( 776 | name='custom_navy', 777 | colors=colors, 778 | ) 779 | 780 | tol=1e-4 781 | labels = confmatrix 782 | #confmatrix=confmatrix*(confmatrix>0.35) 783 | #print("confmatrix",confmatrix+tol) 784 | df_cm = pd.DataFrame(confmatrix+tol,xlabel,ylabel) 785 | 
plt.figure(figsize=figuresize) 786 | sn.set(font_scale=font_scale) # for label size 787 | g = sn.heatmap(-df_cm, 788 | linewidths=0.3, 789 | linecolor="grey", 790 | cmap=cmap, 791 | #annot=True, 792 | annot = labels, 793 | annot_kws={"size": keywordsize},fmt=fmt, 794 | #mask=df_cm < 0.02, 795 | #vmin=vmin+tol, 796 | #vmax=vmax, 797 | cbar=False, 798 | center=0, 799 | #cbar_kws={"ticks":[0.1,0.3,1,3,10]}, 800 | #norm=LogNorm(), 801 | #legend=False, 802 | ) # font size 803 | #g.cax.set_visible(False) 804 | #sn.heatmap(df, cbar=False) 805 | 806 | g.set_yticklabels(labels=g.get_yticklabels(), va='center') 807 | filename =filename+'.'+self.figureformat 808 | plt.ylabel(ylabel1) 809 | plt.xlabel(xlabel1) 810 | print("trece",np.trace(confmatrix),confmatrix) 811 | plt.title(title1, 812 | fontweight="bold", 813 | fontsize=keywordsize*1.1, 814 | pad=40) 815 | g.set_xticklabels(g.get_xticklabels(), rotation = 0) 816 | 817 | 818 | if(self.figureformat=='jpg'): 819 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=300) 820 | else: 821 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 822 | 823 | return 0 824 | 825 | def plotconfusionmaitrix_common(self,confmatrix, 826 | xlabel=None,ylabel=None, 827 | filename='confmatrix', 828 | keywordsize = 16, 829 | font_scale=2, 830 | figuresize=(10,10), 831 | cmap='vlag',#sn.diverging_palette(240, 10, n=9), 832 | vmin=-5, 833 | vmax=10, 834 | center=0, 835 | fonttype='Arial'): 836 | 837 | cmap = LinearSegmentedColormap.from_list('RedWhiteGreen', ['red', 'white', 'green']) 838 | 839 | 840 | sn.set(font=fonttype) 841 | 842 | tol=1e-4 843 | labels = (confmatrix+0.05)*(np.abs(confmatrix)>0.1) 844 | labels = list() 845 | for i in range(confmatrix.shape[0]): 846 | temp = list() 847 | for j in range(confmatrix.shape[1]): 848 | a = confmatrix[i,j] 849 | if(a>0.1): 850 | temp.append("+"+"{0:.1f}".format(a)) 851 | if(a<-0.1): 852 | temp.append("{0:.1f}".format(a)) 853 | if(a<=0.1 and a>=-0.1): 854 | 
temp.append(str(0.0)) 855 | labels.append(temp) 856 | #labels = (confmatrix+0.05)*(np.abs(confmatrix)>0.1) 857 | 858 | print("labels",labels) 859 | 860 | confmatrix=confmatrix=confmatrix*(np.abs(confmatrix)>0.7) 861 | 862 | print("confmatrix",confmatrix+tol) 863 | df_cm = pd.DataFrame(confmatrix+tol,xlabel,ylabel) 864 | plt.figure(figsize=figuresize) 865 | sn.set(font_scale=font_scale) # for label size 866 | g = sn.heatmap(df_cm, 867 | linewidths=12.0, 868 | linecolor="grey", 869 | cmap=cmap, 870 | center=center, 871 | #annot=True, 872 | annot = labels, 873 | annot_kws={"size": keywordsize},fmt="s",#fmt="{0:+.1f}", 874 | #mask=df_cm < 0.02, 875 | vmin=vmin, 876 | vmax=vmax, 877 | cbar=False, 878 | #cbar_kws={"ticks":[0.1,0.3,1,3,10]}, 879 | #norm=LogNorm(), 880 | #legend=False, 881 | ) # font size 882 | #g.cax.set_visible(False) 883 | #sn.heatmap(df, cbar=False) 884 | 885 | g.set_yticklabels(labels=g.get_yticklabels(), va='center') 886 | filename =filename+'.'+self.figureformat 887 | plt.ylabel("ML API") 888 | plt.xlabel("Dataset",) 889 | #plt.title("Overall accuracy:"+"{:.1f}".format(np.trace(confmatrix)), 890 | # fontweight="bold", 891 | # pad=32) 892 | g.set_xticklabels(g.get_xticklabels(), rotation = 0) 893 | 894 | 895 | if(self.figureformat=='jpg'): 896 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight',dpi=40) 897 | else: 898 | plt.savefig(filename, format=self.figureformat, bbox_inches='tight') 899 | 900 | return 0 901 | 902 | def reward_vs_confidence(self, 903 | BaseID = 100, 904 | ModelID=[100,0,1,2], 905 | confidencerange = (0.1,0.2,0.3,0.4,0.5,0.6,0.7,.99,1), 906 | prob_range=None, 907 | datapath='path/to/imagenet/result/val_performance'): 908 | """ 909 | Run a small experiment on solving a Bernoulli bandit with K slot machines, 910 | each with a randomly initialized reward probability. 911 | 912 | Args: 913 | K (int): number of slot machiens. 914 | N (int): number of time steps to try. 
915 | """ 916 | datapath = self.datapath 917 | print('reward datapath',datapath) 918 | b0 = BernoulliBanditwithData(ModelID=ModelID,datapath=datapath) 919 | K = len(ModelID) 920 | print ("Data generated Bernoulli bandit has reward probabilities:\n", b0.probas) 921 | print ("The best machine has index: {} and proba: {}".format( 922 | max(range(K), key=lambda i: b0.probas[i]), max(b0.probas))) 923 | Params0 = context_params(ModelID=ModelID,datapath=datapath) 924 | #confidencerange = (0.02,0.03,0.04,0.05,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95,0.99,0.9999,1) 925 | #confidencerange = (0.99,0.991,0.992,0.993,0.994,0.995,0.996,0.997,0.9999,1) 926 | if(not(prob_range==None)): 927 | confidencerange = self.mlmodels.prob2qvalue(prob_interval=prob_range,conf_id=BaseID) 928 | BaseAccuracy, Others =self.mlmodels.accuracy_condition_score_List(ScoreRange=confidencerange,BaseID=BaseID,ModelID=ModelID) 929 | 930 | print(BaseAccuracy, Others) 931 | CDF = Params0.BaseModel.Compute_Prob_vs_Score(ScoreRange=confidencerange) 932 | print(CDF) 933 | plot_reward_vs_confidence(confidencerange, BaseAccuracy,Others, ModelID,"model reward compare_ModelID_{}.png".format(ModelID),CDF) 934 | 935 | def reward_vs_prob(self, 936 | BaseID = 100, 937 | ModelID=[100,0,1,2], 938 | confidencerange = (0.1,0.2,0.3,0.4,0.5,0.6,0.7,.99,1), 939 | prob_range=None, 940 | datapath='path/to/imagenet/result/val_performance', 941 | dataname='imagenet_val', 942 | context=None): 943 | """ 944 | compute and plot reward as a function of the probability of not using 945 | the basemodel. 946 | 947 | Args: 948 | See the name. 
949 | """ 950 | datapath = self.datapath 951 | print('reward datapath',datapath) 952 | if(not(prob_range==None)): 953 | confidencerange = self.mlmodels.prob2qvalue(prob_interval=prob_range,conf_id=BaseID,context = context) 954 | BaseAccuracy, Others =self.mlmodels.accuracy_condition_score_list(ScoreRange=confidencerange,BaseID=BaseID,ModelID=ModelID,context=context) 955 | print('Base Accuracy', BaseAccuracy, 'Other',Others) 956 | CDF = self.mlmodels.compute_prob_vs_score(ScoreRange=confidencerange,context = context) 957 | print('CDF',CDF) 958 | self._plot_reward_vs_prob(CDF, BaseAccuracy,Others, ModelID,self.folder+"Reward_vs_Prob_BaseID_{}_{}_context_{}.{}".format(BaseID,dataname,context,self.figureformat),CDF) 959 | 960 | def reward_vs_prob_pdf(self, 961 | BaseID = 100, 962 | ModelID=[100,0,1,2], 963 | confidencerange = (0.1,0.2,0.3,0.4,0.5,0.6,0.7,.99,1), 964 | prob_range=None, 965 | datapath='path/to/imagenet/result/val_performance', 966 | dataname='imagenet_val', 967 | context=None): 968 | """ 969 | compute and plot reward as a function of the probability of not using 970 | the basemodel. 971 | 972 | Args: 973 | See the name. 
974 | """ 975 | datapath = self.datapath 976 | print('reward datapath',datapath) 977 | if(not(prob_range==None)): 978 | confidencerange = self.mlmodels.prob2qvalue(prob_interval=prob_range,conf_id=BaseID,context = context) 979 | BaseAccuracy, Others =self.mlmodels.accuracy_condition_score_list(ScoreRange=confidencerange,BaseID=BaseID,ModelID=ModelID,context=context) 980 | print('Base Accuracy', BaseAccuracy, 'Other',Others) 981 | CDF = self.mlmodels.compute_prob_vs_score(ScoreRange=confidencerange,context = context) 982 | print('CDF',CDF) 983 | self._plot_reward_vs_prob(CDF, BaseAccuracy,Others, ModelID,self.folder+"Reward_vs_Prob_BaseID_{}_{}_context_{}.{}".format(BaseID,dataname,context,self.figureformat),CDF) 984 | 985 | if(not(prob_range==None)): 986 | base_pdf,other_pdf = self.mlmodels.accuracy_condition_score_list_cdf2pdf(prob_range,BaseAccuracy,Others,diff = False) 987 | print('base pdf',base_pdf) 988 | print('other pdf',other_pdf) 989 | self._plot_reward_vs_prob(CDF, base_pdf,other_pdf, ModelID,self.folder+"Reward_vs_Probpdf_diff_BaseID_{}_{}_context_{}.{}".format(BaseID,dataname,context,self.figureformat),CDF) 990 | self._plot_reward_vs_prob(confidencerange, base_pdf,other_pdf, ModelID,self.folder+"Reward_vs_conf_pdf_diff_BaseID_{}_{}_context_{}.{}".format(BaseID,dataname,context,self.figureformat),CDF) 991 | 992 | 993 | def qvalue_vs_prob(self, 994 | confidence_range = None, 995 | BaseID = 100, 996 | prob_range = None, 997 | dataname = 'imagenet_val', 998 | context=None): 999 | if(not(prob_range==None)): 1000 | confidence_range = self.mlmodels.prob2qvalue(prob_interval=prob_range,conf_id=BaseID,context=context) 1001 | filename = self.folder+"Conf_vs_prob_BaseID_{}_{}_context_{}.{}".format(BaseID,dataname,context,self.figureformat) 1002 | prob = self.mlmodels.compute_prob_wrt_confidence(confidence_range=confidence_range,BaseID = BaseID,context=context) 1003 | self._plot_q_value_vs_prob(confidence_range,prob,filename) 1004 | return 0 1005 | 1006 | def 
_plot_reward_vs_prob(self, confidence_range, base_acc, model_acc, model_names, figname, CDF): 1007 | """ 1008 | Plot the results by multi-armed bandit solvers. 1009 | 1010 | Args: 1011 | solvers (list): All of them should have been fitted. 1012 | solver_names (list): All of them should have been fitted. 1129 | solver_names (list 0, solvers)) 1135 | 1136 | b = solvers[0].bandit 1137 | 1138 | fig = plt.figure(figsize=(14, 4)) 1139 | fig.subplots_adjust(bottom=0.3, wspace=0.3) 1140 | 1141 | ax1 = fig.add_subplot(131) 1142 | ax2 = fig.add_subplot(132) 1143 | ax3 = fig.add_subplot(133) 1144 | 1145 | # Sub.fig. 1: Regrets in time. 1146 | for i, s in enumerate(solvers): 1147 | ax1.plot(range(len(s.regrets)), s.regrets, label=solver_names[i]) 1148 | 1149 | ax1.set_xlabel('Time step') 1150 | ax1.set_ylabel('Cumulative regret') 1151 | ax1.legend(loc=9, bbox_to_anchor=(1.82, -0.25), ncol=5) 1152 | ax1.grid('k', ls='--', alpha=0.3) 1153 | 1154 | # Sub.fig. 2: Probabilities estimated by solvers. 1155 | sorted_indices = sorted(range(b.n), key=lambda x: b.probas[x]) 1156 | ax2.plot(range(b.n), [b.probas[x] for x in sorted_indices], 'k--', markersize=12) 1157 | for s in solvers: 1158 | ax2.plot(range(b.n), [s.estimated_probas[x] for x in sorted_indices], 'x', markeredgewidth=2) 1159 | ax2.set_xlabel('Actions sorted by ' + r'$\theta$') 1160 | ax2.set_ylabel('Estimated') 1161 | ax2.grid('k', ls='--', alpha=0.3) 1162 | 1163 | # Sub.fig. 3: Action counts 1164 | for s in solvers: 1165 | ax3.plot(range(b.n), np.array(s.counts) / float(len(solvers[0].regrets)), ls='steps', lw=2) 1166 | ax3.set_xlabel('Actions') 1167 | ax3.set_ylabel('Frac. # trials') 1168 | ax3.grid('k', ls='--', alpha=0.3) 1169 | 1170 | plt.savefig(figname) 1171 | 1172 | def plot_reward_vs_confidence(confidence_range, base_acc, model_acc, model_names, figname, CDF): 1173 | """ 1174 | Plot the results by multi-armed bandit solvers. 1175 | 1176 | Args: 1177 | solvers (list): All of them should have been fitted. 
1178 | solver_names (list): All of them should have been fitted. 1221 | solver_names (list