├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── data ├── .gitattributes ├── crag_task_1_and_2_dev_v4.jsonl.bz2 ├── crag_task_3_dev_v4.tar.bz2.part1 ├── crag_task_3_dev_v4.tar.bz2.part2 ├── crag_task_3_dev_v4.tar.bz2.part3 └── crag_task_3_dev_v4.tar.bz2.part4 ├── docs ├── baselines.md ├── dataset.md └── download_baseline_model_weights.md ├── example_data └── dev_data.jsonl.bz2 ├── local_evaluation.py ├── mock_api ├── .gitignore ├── README.md ├── apiwrapper │ ├── example_call.ipynb │ └── pycragapi.py ├── cragapi │ ├── __init__.py │ ├── fast_bm25.py │ ├── finance.py │ ├── movie.py │ ├── music.py │ ├── open.py │ └── sports.py ├── cragkg │ ├── .gitattributes │ ├── finance │ │ ├── company_name.dict │ │ ├── finance_detailed_price.sqlite │ │ ├── finance_dividend.sqlite │ │ ├── finance_info.sqlite │ │ ├── finance_marketcap.sqlite │ │ └── finance_price.sqlite │ ├── movie │ │ ├── movie_db.json │ │ ├── person_db.json │ │ └── year_db.json │ ├── music │ │ ├── artist_dict_simplified.pickle │ │ ├── artist_work_dict.pickle │ │ ├── grammy_df.pickle │ │ ├── rank_dict_hot100.pickle │ │ ├── song_dict_hot100.pickle │ │ └── song_dict_simplified.pickle │ ├── open │ │ ├── kg.0.jsonl.bz2 │ │ └── kg.1.jsonl.bz2 │ └── sports │ │ ├── nba.sqlite │ │ └── soccer_team_match_stats.pkl ├── requirements.txt └── server.py ├── models ├── README.md ├── dummy_model.py ├── rag_knowledge_graph_baseline.py ├── rag_llama_baseline.py ├── user_config.py ├── utils.py └── vanilla_llama_baseline.py ├── prompts └── templates.py ├── requirements.txt ├── tokenizer ├── README.md ├── special_tokens_map.json ├── tokenizer.json ├── tokenizer.model └── tokenizer_config.json └── utils └── cragapi_wrapper.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__ 3 | .DS_Store -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: 
-------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Meta has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://opensource.fb.com/code-of-conduct/) 5 | so that you can understand what actions will and will not be tolerated. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to CRAG 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## Coding Style 30 | * 4 spaces for indentation rather than tabs 31 | * 120 character line length 32 | 33 | ## License 34 | By contributing to this project, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. 
This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 
69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. 
Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. 
You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. 
retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. 
Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CRAG: Comprehensive RAG Benchmark 2 | 3 | The Comprehensive RAG Benchmark (CRAG) is a rich and comprehensive factual question answering benchmark designed to advance research in RAG. Besides question-answer pairs, CRAG provides mock APIs to simulate web and knowledge graph search. CRAG is designed to encapsulate a diverse array of questions across five domains and eight question categories, reflecting varied entity popularity from popular to long-tail, and temporal dynamisms ranging from years to seconds. 4 | 5 | This repository is migrated from [meta-comprehensive-rag-benchmark-kdd-cup-2024](https://gitlab.aicrowd.com/aicrowd/challenges/meta-comprehensive-rag-benchmark-kdd-cup-2024). 
5 | 6 | 7 | ## 📊 Dataset and Mock APIs 8 | 9 | Please find more details about the CRAG dataset (download, schema, etc.) in [docs/dataset.md](docs/dataset.md) and mock APIs in [mock_api](mock_api). 10 | 11 | 12 | ## 📏 Evaluation Metrics 13 | RAG systems are evaluated using a scoring method that measures response quality to questions in the evaluation set. Responses are rated as perfect, acceptable, missing, or incorrect: 14 | 15 | - Perfect: The response correctly answers the user question and contains no hallucinated content. 16 | 17 | - Acceptable: The response provides a useful answer to the user question, but may contain minor errors that do not harm the usefulness of the answer. 18 | 19 | - Missing: The answer does not provide the requested information, such as “I don’t know”, “I’m sorry I can’t find …” or similar sentences without providing a concrete answer to the question. 20 | 21 | - Incorrect: The response provides wrong or irrelevant information to answer the user question. 22 | 23 | 24 | Auto-evaluation: 25 | - Automatic evaluation employs rule-based matching and LLM assessment to check answer correctness. It will assign three scores: correct (1 point), missing (0 points), and incorrect (-1 point). 26 | 27 | 28 | Please refer to [local_evaluation.py](local_evaluation.py) for more details on how the evaluation was implemented. 29 | 30 | ## ✍️ How to run end-to-end evaluation? 31 | 1. **Install** specific dependencies 32 | ```bash 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | 2. Please refer to [models/README.md](models/README.md) for instructions and examples on how to write your own models. 37 | 38 | 3. After writing your own model(s), update [models/user_config.py](models/user_config.py) 39 | 40 | For example, in models/user_config.py, specify InstructModel to call llama3-8b-instruct model 41 | ```python 42 | from models.vanilla_llama_baseline import InstructModel 43 | UserModel = InstructModel 44 | 45 | ``` 46 | 47 | 4. 
Test your model locally using `python local_evaluation.py`. This script will run answer generation and auto-evaluation. 48 | 49 | 50 | ## 🏁 Baselines 51 | We include three baselines for demonstration purposes, and you can read more about them in [docs/baselines.md](docs/baselines.md). 52 | 53 | 54 | ## Citations 55 | 56 | ``` 57 | @article{yang2024crag, 58 | title={CRAG -- Comprehensive RAG Benchmark}, 59 | author={Xiao Yang and Kai Sun and Hao Xin and Yushi Sun and Nikita Bhalla and Xiangsen Chen and Sajal Choudhary and Rongze Daniel Gui and Ziran Will Jiang and Ziyu Jiang and Lingkun Kong and Brian Moran and Jiaqi Wang and Yifan Ethan Xu and An Yan and Chenyu Yang and Eting Yuan and Hanwen Zha and Nan Tang and Lei Chen and Nicolas Scheffer and Yue Liu and Nirav Shah and Rakesh Wanga and Anuj Kumar and Wen-tau Yih and Xin Luna Dong}, 60 | year={2024}, 61 | journal={arXiv preprint arXiv:2406.04744}, 62 | url={https://arxiv.org/abs/2406.04744} 63 | } 64 | ``` 65 | 66 | ## License 67 | 68 | This project is licensed under the [Creative Commons Attribution-NonCommercial 4.0 International License (CC BY-NC 4.0)](LICENSE). This license permits sharing and adapting the work, provided it's not used for commercial purposes and appropriate credit is given. For a quick overview, visit [Creative Commons License](https://creativecommons.org/licenses/by-nc/4.0/). 
69 | -------------------------------------------------------------------------------- /data/.gitattributes: -------------------------------------------------------------------------------- 1 | *.bz2 filter=lfs diff=lfs merge=lfs -text 2 | *.bz2.part* filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /data/crag_task_1_and_2_dev_v4.jsonl.bz2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:afa29f2b3facfb5d15aa9cded00d5ec90ff76f3e67279e7b99cfe86659a641ca 3 | size 739384484 4 | -------------------------------------------------------------------------------- /data/crag_task_3_dev_v4.tar.bz2.part1: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3a3f642643adc427b356b34fd078ba867c3905aa4341fc181457197aee9a55cf 3 | size 2097152000 4 | -------------------------------------------------------------------------------- /data/crag_task_3_dev_v4.tar.bz2.part2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:abd76c029079b93065aba7fb4fc0e129a803de7ee0e1511432d0037d1782aa7c 3 | size 2097152000 4 | -------------------------------------------------------------------------------- /data/crag_task_3_dev_v4.tar.bz2.part3: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:d82e3c94e29a361635ed6b56797747b77ef3cf07dc454736d661461f32c23dab 3 | size 2097152000 4 | -------------------------------------------------------------------------------- /data/crag_task_3_dev_v4.tar.bz2.part4: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid 
sha256:ed48b1ca55606a7d22be4809b10d7de0251d429f01e689a066c6def2428995ad 3 | size 1432877269 4 | -------------------------------------------------------------------------------- /docs/baselines.md: -------------------------------------------------------------------------------- 1 | # CRAG Baselines 2 | 3 | For the CRAG benchmark, we provide users with three baseline models to help get started. Detailed implementations of these baseline models are accessible through the links provided below. Refer to [this page](download_baseline_model_weights.md) for steps to download (and check in) the model weights required for the baseline models. 4 | 5 | 6 | ## Available Baseline Models: 7 | 8 | 1. [**Vanilla Llama 3 Model**](../models/vanilla_llama_baseline.py): For an implementation guide and further details, refer to the Vanilla Llama 3 model inline documentation [here](../models/vanilla_llama_baseline.py). 9 | 10 | 2. [**RAG Baseline Model**](../models/rag_llama_baseline.py): For an implementation guide and further details, refer to the RAG Baseline model inline documentation [here](../models/rag_llama_baseline.py). 11 | 12 | 3. [**RAG Knowledge Graph Baseline Model**](../models/rag_knowledge_graph_baseline.py): For an implementation guide and further details, refer to the RAG Knowledge Graph Baseline model inline documentation [here](../models/rag_knowledge_graph_baseline.py). 13 | 14 | -------------------------------------------------------------------------------- /docs/dataset.md: -------------------------------------------------------------------------------- 1 | # CRAG Dataset Documentation 2 | 3 | ## Overview 4 | 5 | The CRAG dataset is designed to support the development and evaluation of Retrieval-Augmented Generation (RAG) models. It consists of two main types of data: 6 | 7 | 1. **Question Answering Pairs:** Pairs of questions and their corresponding answers. 2. **Retrieval Contents:** Contents for information retrieval to support answer generation. 
9 | 10 | Retrieval contents are divided into two types to simulate practical scenarios for RAG: 11 | 12 | 1. **Web Search Results:** For each question, up to `50` **full HTML pages** are stored, retrieved using the question text as a search query. For Task 1 & 2, `5 pages` are **randomly selected** from the `top-10 pages`. These pages are likely relevant to the question, but relevance is not guaranteed. 13 | 2. **Mock KGs and APIs:** The Mock API is designed to mimic real-world **Knowledge Graphs (KGs)** or **API searches**. Given some input parameters, they output relevant results, which may or may not be helpful in answering the user's question. 14 | 15 | ## Download CRAG Data 16 | 17 | - **Task #1:** [QA Pairs & Retrieval Contents](https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/crag_task_1_and_2_dev_v4.jsonl.bz2?download=) 18 | - **Task #2:** [QA Pairs & Retrieval Contents](https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/crag_task_1_and_2_dev_v4.jsonl.bz2?download=), [Mock KGs and APIs](/mock_api) 19 | - **Task #3:** QA Pairs & Retrieval Contents (download part [1](https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/crag_task_3_dev_v4.tar.bz2.part1?download=), [2](https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/crag_task_3_dev_v4.tar.bz2.part2?download=), [3](https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/crag_task_3_dev_v4.tar.bz2.part3?download=), and [4](https://github.com/facebookresearch/CRAG/raw/refs/heads/main/data/crag_task_3_dev_v4.tar.bz2.part4?download=); then merge them by `cat crag_task_3_dev_v4.tar.bz2.part* > crag_task_3_dev_v4.tar.bz2`), [Mock KGs and APIs](/mock_api) 20 | 21 | ## Data Schema 22 | 23 | | Field Name | Type | Description | 24 | |------------------------|---------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------| 25 | | 
`interaction_id` | string | A unique identifier for each example. | 26 | | `query_time` | string | Date and time when the query and the web search occurred. | 27 | | `domain` | string | Domain label for the query. Possible values: "finance", "music", "movie", "sports", "open". "Open" includes any factual queries not among the previous four domains. | 28 | | `question_type` | string | Type label about the query. Possible values include: "simple", "simple_w_condition", "comparison", "aggregation", "set", "false_premise", "post-processing", "multi-hop". | 29 | | `static_or_dynamic` | string | Indicates whether the answer to a question changes and the expected rate of change. Possible values: "static", "slow-changing", "fast-changing", and "real-time". | 30 | | `query` | string | The question for RAG to answer. | 31 | | `answer` | string | The gold standard answer to the question. | 32 | | `alt_ans` | list | Other valid gold standard answers to the question. | 33 | | `split` | integer | Data split indicator, where 0 is for validation and 1 is for the public test. | 34 | | `search_results` | list of JSON | Contains up to `k` HTML pages for each query (`k=5` for Task #1 and `k=50` for Task #3), including page name, URL, snippet, full HTML, and last modified time. | 35 | 36 | ### Search Results Detail 37 | 38 | | Key | Type | Description | 39 | |----------------------|--------|---------------------------------------------------------| 40 | | `page_name` | string | The name of the webpage. | 41 | | `page_url` | string | The URL of the webpage. | 42 | | `page_snippet` | string | A short paragraph describing the major content of the page. | 43 | | `page_result` | string | The full HTML of the webpage. | 44 | | `page_last_modified` | string | The time when the page was last modified. 
| 45 | 46 | -------------------------------------------------------------------------------- /docs/download_baseline_model_weights.md: -------------------------------------------------------------------------------- 1 | ### Setting Up and Downloading Baseline Model weights with Hugging Face 2 | 3 | This guide outlines the steps to download (and check in) the models weights required for the baseline models. 4 | We will focus on the `Meta-Llama-3-8B-Instruct` and `all-MiniLM-L6-v2` models. 5 | But the steps should work equally well for any other models on hugging face. 6 | 7 | #### Preliminary Steps: 8 | 9 | 1. **Install the Hugging Face Hub Package**: 10 | 11 | Begin by installing the `huggingface_hub` package, which includes the `hf_transfer` utility, by running the following command in your terminal: 12 | 13 | ```bash 14 | pip install huggingface_hub[hf_transfer] 15 | ``` 16 | 17 | 2. **Accept the Llama Terms**: 18 | 19 | You must accept the Llama model's terms of use by visiting: [meta-llama/Meta-Llama-3-8B-Instruct Terms](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct). 20 | 21 | 3. **Create a Hugging Face CLI Token**: 22 | 23 | Generate a CLI token by navigating to: [Hugging Face Token Settings](https://huggingface.co/settings/tokens). You will need this token for authentication. 24 | 25 | #### Hugging Face Authentication: 26 | 27 | 1. **Login via CLI**: 28 | 29 | Authenticate yourself with the Hugging Face CLI using the token created in the previous step. Run: 30 | 31 | ```bash 32 | huggingface-cli login 33 | ``` 34 | 35 | When prompted, enter the token. 36 | 37 | #### Model Downloads: 38 | 39 | 1. **Download LLaMA-2-7b Model**: 40 | 41 | Execute the following command to download the `Meta-Llama-3-8B-Instruct` model to a local subdirectory. 
This command excludes unnecessary files to save space: 42 | 43 | ```bash 44 | HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download \ 45 | meta-llama/Meta-Llama-3-8B-Instruct \ 46 | --local-dir-use-symlinks False \ 47 | --local-dir models/meta-llama/Meta-Llama-3-8B-Instruct \ 48 | --exclude *.pth # These are alternates to the safetensors hence not needed 49 | ``` 50 | 51 | 3. **Download MiniLM-L6-v2 Model (for sentence embeddings)**: 52 | 53 | Similarly, download the `sentence-transformers/all-MiniLM-L6-v2` model using the following command: 54 | 55 | ```bash 56 | HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download \ 57 | sentence-transformers/all-MiniLM-L6-v2 \ 58 | --local-dir-use-symlinks False \ 59 | --local-dir models/sentence-transformers/all-MiniLM-L6-v2 \ 60 | --exclude *.bin *.h5 *.ot # These are alternates to the safetensors hence not needed 61 | ``` 62 | -------------------------------------------------------------------------------- /example_data/dev_data.jsonl.bz2: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/CRAG/c71ad61ea4f18ab0cc4b5b009932bc76e21be394/example_data/dev_data.jsonl.bz2 -------------------------------------------------------------------------------- /local_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import bz2 8 | import json 9 | import os 10 | import re 11 | from datetime import datetime 12 | 13 | from loguru import logger 14 | from openai import APIConnectionError, OpenAI, RateLimitError 15 | from prompts.templates import IN_CONTEXT_EXAMPLES, INSTRUCTIONS 16 | from tqdm.auto import tqdm 17 | from transformers import LlamaTokenizerFast 18 | from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig 19 | 20 | tokenizer = LlamaTokenizerFast.from_pretrained("tokenizer") 21 | 22 | 23 | def load_json_file(file_path): 24 | """Load and return the content of a JSON file.""" 25 | logger.info(f"Loading JSON from {file_path}") 26 | with open(file_path) as f: 27 | return json.load(f) 28 | 29 | 30 | def get_system_message(): 31 | """Returns the system message containing instructions and in context examples.""" 32 | return INSTRUCTIONS + "\n" + IN_CONTEXT_EXAMPLES 33 | 34 | 35 | def attempt_api_call(client, model_name, messages, max_retries=10): 36 | """Attempt an API call with retries upon encountering specific errors.""" 37 | # todo: add default response when all efforts fail 38 | for attempt in range(max_retries): 39 | try: 40 | response = client.chat.completions.create( 41 | model=model_name, 42 | messages=messages, 43 | response_format={"type": "json_object"}, 44 | temperature=0.0, 45 | ) 46 | return response.choices[0].message.content 47 | except (APIConnectionError, RateLimitError): 48 | logger.warning(f"API call failed on attempt {attempt + 1}, retrying...") 49 | except Exception as e: 50 | logger.error(f"Unexpected error: {e}") 51 | break 52 | return None 53 | 54 | 55 | def log_response(messages, response, output_directory="api_responses"): 56 | """Save the response from the API to a file.""" 57 | os.makedirs(output_directory, exist_ok=True) 58 | file_name = datetime.now().strftime("%d-%m-%Y-%H-%M-%S.json") 59 | file_path = os.path.join(output_directory, file_name) 60 | with open(file_path, "w") as f: 61 | json.dump({"messages": messages, "response": 
response}, f) 62 | 63 | 64 | def parse_response(response: str): 65 | """ 66 | Return a tuple of (explanation, score) from the response, 67 | where score is 0 if the prediction is wrong, 1 if the prediction is correct. 68 | 69 | Need to handle 70 | Corner case 1: 71 | {"explanation": ...} 72 | Wait, no! I made a mistake. The prediction does not exactly match the ground truth. ... 73 | {...} 74 | 75 | Corner case 2: 76 | {"score": 0, "explanation": "The prediction does not contain item, nick "goose" bradshaw, that is in the ground truth."} 77 | return a tuple of (explanation, score) 78 | """ 79 | matches = re.findall(r"{([^}]*)}", response) 80 | text = "" 81 | for match in matches: 82 | text = "{" + match + "}" 83 | try: 84 | score = -1 85 | # Pattern to match the score 86 | score_pattern = r'"score"\s*:\s*(\d+)' 87 | score_match = re.search(score_pattern, text) 88 | if score_match: 89 | score = int(score_match.group(1)) 90 | if score != 0 and score != 1: 91 | raise Exception("bad score: " + response) 92 | else: 93 | return "Parse Err: Score not found", -1 94 | 95 | # Pattern to match the explanation 96 | explanation_pattern = r'"explanation"\s*:\s*"(.+)"' 97 | explanation_match = re.search(explanation_pattern, text) 98 | if explanation_match: 99 | explanation = explanation_match.group(1) 100 | return explanation, score 101 | else: 102 | return text, score 103 | except Exception as e: 104 | print(f"Parsing Error with resp: {response}") 105 | print(f"Error: {e}") 106 | return response, -1 107 | 108 | 109 | def trim_predictions_to_max_token_length(prediction): 110 | """Trims prediction output to 75 tokens using Llama2 tokenizer""" 111 | max_token_length = 75 112 | tokenized_prediction = tokenizer.encode(prediction) 113 | trimmed_tokenized_prediction = tokenized_prediction[1 : max_token_length + 1] 114 | trimmed_prediction = tokenizer.decode(trimmed_tokenized_prediction) 115 | return trimmed_prediction 116 | 117 | 118 | def load_data_in_batches(dataset_path, 
batch_size): 119 | """ 120 | Generator function that reads data from a compressed file and yields batches of data. 121 | Each batch is a dictionary containing lists of interaction_ids, queries, search results, query times, and answers. 122 | 123 | Args: 124 | dataset_path (str): Path to the dataset file. 125 | batch_size (int): Number of data items in each batch. 126 | 127 | Yields: 128 | dict: A batch of data. 129 | """ 130 | def initialize_batch(): 131 | """ Helper function to create an empty batch. """ 132 | return {"interaction_id": [], "query": [], "search_results": [], "query_time": [], "answer": []} 133 | 134 | try: 135 | with bz2.open(dataset_path, "rt") as file: 136 | batch = initialize_batch() 137 | for line in file: 138 | try: 139 | item = json.loads(line) 140 | for key in batch: 141 | batch[key].append(item[key]) 142 | 143 | if len(batch["query"]) == batch_size: 144 | yield batch 145 | batch = initialize_batch() 146 | except json.JSONDecodeError: 147 | logger.warn("Warning: Failed to decode a line.") 148 | # Yield any remaining data as the last batch 149 | if batch["query"]: 150 | yield batch 151 | except FileNotFoundError as e: 152 | logger.error(f"Error: The file {dataset_path} was not found.") 153 | raise e 154 | except IOError as e: 155 | logger.error(f"Error: An error occurred while reading the file {dataset_path}.") 156 | raise e 157 | 158 | 159 | 160 | def generate_predictions(dataset_path, participant_model): 161 | """ 162 | Processes batches of data from a dataset to generate predictions using a model. 163 | 164 | Args: 165 | dataset_path (str): Path to the dataset. 166 | participant_model (object): UserModel that provides `get_batch_size()` and `batch_generate_answer()` interfaces. 167 | 168 | Returns: 169 | tuple: A tuple containing lists of queries, ground truths, and predictions. 
170 | """ 171 | queries, ground_truths, predictions = [], [], [] 172 | batch_size = participant_model.get_batch_size() 173 | 174 | for batch in tqdm(load_data_in_batches(dataset_path, batch_size), desc="Generating predictions"): 175 | batch_ground_truths = batch.pop("answer") # Remove answers from batch and store them 176 | batch_predictions = participant_model.batch_generate_answer(batch) 177 | 178 | queries.extend(batch["query"]) 179 | ground_truths.extend(batch_ground_truths) 180 | predictions.extend(batch_predictions) 181 | 182 | return queries, ground_truths, predictions 183 | 184 | 185 | def evaluate_predictions(queries, ground_truths_list, predictions, evaluation_model_name): 186 | """ 187 | Evaluates the predictions generated by a model against ground truth answers. 188 | 189 | Args: 190 | queries (List[str]): List of queries. 191 | ground_truths_list (List[List[str]]): List of lists of ground truth answers. 192 | Note each query can have multiple ground truth answers. 193 | predictions (list): List of predictions generated by the model. 194 | evaluation_model_name (str): Name of the evaluation model. 195 | 196 | Returns: 197 | dict: A dictionary containing evaluation results. 
198 | """ 199 | 200 | if "chat" in evaluation_model_name.lower(): 201 | # now we are using chatgpt 202 | openai_client = OpenAI() 203 | n_miss, n_correct = 0, 0 204 | system_message = get_system_message() 205 | 206 | for _idx, prediction in enumerate(tqdm( 207 | predictions, total=len(predictions), desc="Evaluating Predictions" 208 | )): 209 | query = queries[_idx] 210 | ground_truths = ground_truths_list[_idx].strip() 211 | # trim prediction to 75 tokens using Llama2 tokenizer 212 | prediction = trim_predictions_to_max_token_length(prediction) 213 | prediction = prediction.strip() 214 | 215 | if "i don't know" in prediction_lowercase: 216 | n_miss += 1 217 | continue 218 | 219 | accuracy = -1 220 | 221 | for ground_truth in ground_truths: 222 | ground_truth_lowercase = ground_truth.lower() 223 | prediction_lowercase = prediction.lower() 224 | messages = [ 225 | {"role": "system", "content": system_message}, 226 | { 227 | "role": "user", 228 | "content": f"Question: {query}\n Ground truth: {ground_truth}\n Prediction: {prediction}\n", 229 | }, 230 | ] 231 | if prediction_lowercase == ground_truth_lowercase: 232 | # exact correct 233 | accuracy = 1 234 | break 235 | elif "invalid" in prediction_lowercase and "invalid" in ground_truth_lowercase: 236 | accuracy = 1 237 | break 238 | elif "invalid" in prediction_lowercase and "invalid" not in ground_truth_lowercase: 239 | # hallucination 240 | accuracy = 0 241 | continue 242 | elif "invalid" not in prediction_lowercase and "invalid" in ground_truth_lowercase: 243 | # hallucination 244 | accuracy = 0 245 | continue 246 | else: 247 | # need to use the OpenAI evaluation model to get the accuracy result (0 means wrong, 1 means correct) 248 | response = attempt_api_call(openai_client, evaluation_model_name, messages) 249 | if response: 250 | log_response(messages, response) 251 | _, accuracy = parse_response(response) 252 | if accuracy == 1: 253 | # no need to check other ground truth(s) 254 | break 255 | 256 | if accuracy 
== 1: 257 | n_correct += 1 258 | 259 | n = len(predictions) 260 | results = { 261 | "score": (2 * n_correct + n_miss) / n - 1, 262 | "accuracy": n_correct / n, 263 | "hallucination": (n - n_correct - n_miss) / n, 264 | "missing": n_miss / n, 265 | "n_miss": n_miss, 266 | "n_correct": n_correct, 267 | "n_hallucination": n - n_correct - n_miss, 268 | "total": n, 269 | } 270 | logger.info(results) 271 | return results 272 | elif "llama" in evaluation_model_name.lower(): 273 | # now we are using llama model to evaluate 274 | # to be filled by Jiaqi 275 | raise NotImplementedError("Llama evaluation model is not implemented yet.") 276 | else: 277 | raise NotImplementedError(f"Unknown evaluation model: {evaluation_model_name}") 278 | 279 | 280 | if __name__ == "__main__": 281 | from models.user_config import UserModel 282 | 283 | DATASET_PATH = "example_data/dev_data.jsonl.bz2" 284 | EVALUATION_MODEL_NAME = os.getenv("EVALUATION_MODEL_NAME", "gpt-4-0125-preview") 285 | 286 | # Generate predictions 287 | participant_model = UserModel() 288 | queries, ground_truths, predictions = generate_predictions(DATASET_PATH, participant_model) 289 | # Evaluate Predictions 290 | openai_client = OpenAI() 291 | evaluation_results = evaluate_predictions( 292 | queries, ground_truths, predictions, EVALUATION_MODEL_NAME, openai_client 293 | ) 294 | -------------------------------------------------------------------------------- /mock_api/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints -------------------------------------------------------------------------------- /mock_api/README.md: -------------------------------------------------------------------------------- 1 | # Comprehensive RAG Benchmark (CRAG) Mock API 2 | 3 | ## Prerequisites 4 | 5 | Before diving into the setup and usage of the CRAG Mock API, ensure you have the following prerequisites installed and set up on your system: 6 | - Git (for cloning 
the repository) 7 | - Python 3.10 8 | 9 | ## Installation Guide 10 | 11 | ### Setting Up Your Environment 12 | 13 | First, clone the repository to your local machine using Git. Then, navigate to the repository directory and install the necessary dependencies: 14 | 15 | ``` 16 | cd mock_api 17 | pip install -r requirements.txt 18 | ``` 19 | 20 | ## Running the API Server 21 | 22 | To launch the API server on your local machine, use the following Uvicorn command. This starts a fast, asynchronous server to handle API requests. 23 | 24 | ``` 25 | uvicorn server:app --reload 26 | ``` 27 | 28 | Access the API documentation and test the endpoints at `http://127.0.0.1:8000/docs`. 29 | 30 | For custom server configurations, specify the host and port as follows: 31 | 32 | ``` 33 | uvicorn server:app --reload --host [HOST] --port [PORT] 34 | ``` 35 | 36 | ## System Requirements 37 | 38 | - **Supported OS**: Linux, Windows, macOS 39 | - **Python Version**: 3.10 40 | - See `requirements.txt` for a complete list of Python package dependencies. 41 | 42 | ## Python API Wrapper 43 | 44 | For Python developers, the [/mock_api/apiwrapper/pycragapi.py](/mock_api/apiwrapper/pycragapi.py) provides a convenient way to interact with the API. An example usage is demonstrated in [/mock_api/apiwrapper/example_call.ipynb](/mock_api/apiwrapper/example_call.ipynb), showcasing how to efficiently integrate the API into your development workflow. 45 | -------------------------------------------------------------------------------- /mock_api/apiwrapper/pycragapi.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import json 8 | import os 9 | from typing import List 10 | 11 | import requests 12 | 13 | 14 | class CRAG(object): 15 | """ 16 | A client for interacting with the CRAG server, offering methods to query various domains such as Open, Movie, Finance, Music, and Sports. Each method corresponds to an API endpoint on the CRAG server. 17 | 18 | Attributes: 19 | server (str): The base URL of the CRAG server. Defaults to "http://127.0.0.1:8000". 20 | 21 | Methods: 22 | open_search_entity_by_name(query: str) -> dict: Search for entities by name in the Open domain. 23 | open_get_entity(entity: str) -> dict: Retrieve detailed information about an entity in the Open domain. 24 | movie_get_person_info(person_name: str) -> dict: Get information about a person related to movies. 25 | movie_get_movie_info(movie_name: str) -> dict: Get information about a movie. 26 | movie_get_year_info(year: str) -> dict: Get information about movies released in a specific year. 27 | movie_get_movie_info_by_id(movie_id: int) -> dict: Get movie information by its unique ID. 28 | movie_get_person_info_by_id(person_id: int) -> dict: Get person information by their unique ID. 29 | finance_get_company_name(query: str) -> dict: Search for company names in the finance domain. 30 | finance_get_ticker_by_name(query: str) -> dict: Retrieve the ticker symbol for a given company name. 31 | finance_get_price_history(ticker_name: str) -> dict: Get the price history for a given ticker symbol. 32 | finance_get_detailed_price_history(ticker_name: str) -> dict: Get detailed price history for a ticker symbol. 33 | finance_get_dividends_history(ticker_name: str) -> dict: Get dividend history for a ticker symbol. 34 | finance_get_market_capitalization(ticker_name: str) -> dict: Retrieve market capitalization for a ticker symbol. 35 | finance_get_eps(ticker_name: str) -> dict: Get earnings per share (EPS) for a ticker symbol. 
36 | finance_get_pe_ratio(ticker_name: str) -> dict: Get the price-to-earnings (PE) ratio for a ticker symbol. 37 | finance_get_info(ticker_name: str) -> dict: Get financial information for a ticker symbol. 38 | music_search_artist_entity_by_name(artist_name: str) -> dict: Search for music artists by name. 39 | music_search_song_entity_by_name(song_name: str) -> dict: Search for songs by name. 40 | music_get_billboard_rank_date(rank: int, date: str = None) -> dict: Get Billboard ranking for a specific rank and date. 41 | music_get_billboard_attributes(date: str, attribute: str, song_name: str) -> dict: Get attributes of a song from Billboard rankings. 42 | music_grammy_get_best_artist_by_year(year: int) -> dict: Get the Grammy Best New Artist for a specific year. 43 | music_grammy_get_award_count_by_artist(artist_name: str) -> dict: Get the total Grammy awards won by an artist. 44 | music_grammy_get_award_count_by_song(song_name: str) -> dict: Get the total Grammy awards won by a song. 45 | music_grammy_get_best_song_by_year(year: int) -> dict: Get the Grammy Song of the Year for a specific year. 46 | music_grammy_get_award_date_by_artist(artist_name: str) -> dict: Get the years an artist won a Grammy award. 47 | music_grammy_get_best_album_by_year(year: int) -> dict: Get the Grammy Album of the Year for a specific year. 48 | music_grammy_get_all_awarded_artists() -> dict: Get all artists awarded the Grammy Best New Artist. 49 | music_get_artist_birth_place(artist_name: str) -> dict: Get the birthplace of an artist. 50 | music_get_artist_birth_date(artist_name: str) -> dict: Get the birth date of an artist. 51 | music_get_members(band_name: str) -> dict: Get the member list of a band. 52 | music_get_lifespan(artist_name: str) -> dict: Get the lifespan of an artist. 53 | music_get_song_author(song_name: str) -> dict: Get the author of a song. 54 | music_get_song_release_country(song_name: str) -> dict: Get the release country of a song. 
55 | music_get_song_release_date(song_name: str) -> dict: Get the release date of a song. 56 | music_get_artist_all_works(artist_name: str) -> dict: Get all works by an artist. 57 | sports_soccer_get_games_on_date(team_name: str, date: str) -> dict: Get soccer games on a specific date. 58 | Result includes game attributes such as date, time, 59 | GF: GF: Goals For - the number of goals scored by the team in question during the match, 60 | GA: Goals Against - the number of goals conceded by the team during the match, 61 | xG: Expected Goals - a statistical measure that estimates the number of goals a team should have scored based on the quality of chances they created, 62 | xGA: Expected Goals Against - a measure estimating the number of goals a team should have conceded based on the quality of chances allowed to the opponent, 63 | Poss: Possession - the percentage of the match time during which the team had possession of the ball. 64 | sports_nba_get_games_on_date(team_name: str, date: str) -> dict: Get NBA games on a specific date. 65 | Result includes game attributes such as 66 | team_name_home: The full name of the home team, 67 | wl_home: The outcome of the game for the home team, 68 | pts_home: The total number of points scored by the home team. 69 | sports_nba_get_play_by_play_data_by_game_ids(game_ids: List[str]) -> dict: Get NBA play by play data for a set of game ids. 70 | Result includes play-by-play event time, description, player etc. 71 | 72 | Note: 73 | Each method performs a POST request to the corresponding API endpoint and returns the response as a JSON dictionary. 
74 | """ 75 | def __init__(self): 76 | self.server = os.environ.get('CRAG_SERVER', "http://127.0.0.1:8000") 77 | 78 | def open_search_entity_by_name(self, query:str): 79 | url = self.server + '/open/search_entity_by_name' 80 | headers={'accept': "application/json"} 81 | data = {'query': query} 82 | result = requests.post(url, json=data, headers=headers) 83 | return json.loads(result.text) 84 | 85 | def open_get_entity(self, entity:str): 86 | url = self.server + '/open/get_entity' 87 | headers={'accept': "application/json"} 88 | data = {'query': entity} 89 | result = requests.post(url, json=data, headers=headers) 90 | return json.loads(result.text) 91 | 92 | def movie_get_person_info(self, person_name:str): 93 | url = self.server + '/movie/get_person_info' 94 | headers={'accept': "application/json"} 95 | data = {'query': person_name} 96 | result = requests.post(url, json=data, headers=headers) 97 | return json.loads(result.text) 98 | 99 | def movie_get_movie_info(self, movie_name:str): 100 | url = self.server + '/movie/get_movie_info' 101 | headers={'accept': "application/json"} 102 | data = {'query': movie_name} 103 | result = requests.post(url, json=data, headers=headers) 104 | return json.loads(result.text) 105 | 106 | def movie_get_year_info(self, year:str): 107 | url = self.server + '/movie/get_year_info' 108 | headers={'accept': "application/json"} 109 | data = {'query': year} 110 | result = requests.post(url, json=data, headers=headers) 111 | return json.loads(result.text) 112 | 113 | def movie_get_movie_info_by_id(self, movid_id:int): 114 | url = self.server + '/movie/get_movie_info_by_id' 115 | headers={'accept': "application/json"} 116 | data = {'query': movid_id} 117 | result = requests.post(url, json=data, headers=headers) 118 | return json.loads(result.text) 119 | 120 | def movie_get_person_info_by_id(self, person_id:int): 121 | url = self.server + '/movie/get_person_info_by_id' 122 | headers={'accept': "application/json"} 123 | data = {'query': 
person_id} 124 | result = requests.post(url, json=data, headers=headers) 125 | return json.loads(result.text) 126 | 127 | def finance_get_company_name(self, query:str): 128 | url = self.server + '/finance/get_company_name' 129 | headers={'accept': "application/json"} 130 | data = {'query': query} 131 | result = requests.post(url, json=data, headers=headers) 132 | return json.loads(result.text) 133 | 134 | def finance_get_ticker_by_name(self, query:str): 135 | url = self.server + '/finance/get_ticker_by_name' 136 | headers={'accept': "application/json"} 137 | data = {'query': query} 138 | result = requests.post(url, json=data, headers=headers) 139 | return json.loads(result.text) 140 | 141 | def finance_get_price_history(self, ticker_name:str): 142 | url = self.server + '/finance/get_price_history' 143 | headers={'accept': "application/json"} 144 | data = {'query': ticker_name} 145 | result = requests.post(url, json=data, headers=headers) 146 | return json.loads(result.text) 147 | 148 | def finance_get_detailed_price_history(self, ticker_name:str): 149 | url = self.server + '/finance/get_detailed_price_history' 150 | headers={'accept': "application/json"} 151 | data = {'query': ticker_name} 152 | result = requests.post(url, json=data, headers=headers) 153 | return json.loads(result.text) 154 | 155 | def finance_get_dividends_history(self, ticker_name:str): 156 | url = self.server + '/finance/get_dividends_history' 157 | headers={'accept': "application/json"} 158 | data = {'query': ticker_name} 159 | result = requests.post(url, json=data, headers=headers) 160 | return json.loads(result.text) 161 | 162 | def finance_get_market_capitalization(self, ticker_name:str): 163 | url = self.server + '/finance/get_market_capitalization' 164 | headers={'accept': "application/json"} 165 | data = {'query': ticker_name} 166 | result = requests.post(url, json=data, headers=headers) 167 | return json.loads(result.text) 168 | 169 | def finance_get_eps(self, ticker_name:str): 170 | url 
= self.server + '/finance/get_eps' 171 | headers={'accept': "application/json"} 172 | data = {'query': ticker_name} 173 | result = requests.post(url, json=data, headers=headers) 174 | return json.loads(result.text) 175 | 176 | def finance_get_pe_ratio(self, ticker_name:str): 177 | url = self.server + '/finance/get_pe_ratio' 178 | headers={'accept': "application/json"} 179 | data = {'query': ticker_name} 180 | result = requests.post(url, json=data, headers=headers) 181 | return json.loads(result.text) 182 | 183 | def finance_get_info(self, ticker_name:str): 184 | url = self.server + '/finance/get_info' 185 | headers={'accept': "application/json"} 186 | data = {'query': ticker_name} 187 | result = requests.post(url, json=data, headers=headers) 188 | return json.loads(result.text) 189 | 190 | def music_search_artist_entity_by_name(self, artist_name:str): 191 | url = self.server + '/music/search_artist_entity_by_name' 192 | headers={'accept': "application/json"} 193 | data = {'query': artist_name} 194 | result = requests.post(url, json=data, headers=headers) 195 | return json.loads(result.text) 196 | 197 | def music_search_song_entity_by_name(self, song_name:str): 198 | url = self.server + '/music/search_song_entity_by_name' 199 | headers={'accept': "application/json"} 200 | data = {'query': song_name} 201 | result = requests.post(url, json=data, headers=headers) 202 | return json.loads(result.text) 203 | 204 | def music_get_billboard_rank_date(self, rank:int, date:str=None): 205 | url = self.server + '/music/get_billboard_rank_date' 206 | headers={'accept': "application/json"} 207 | data = {'rank': rank, 'date': date} 208 | result = requests.post(url, json=data, headers=headers) 209 | return json.loads(result.text) 210 | 211 | def music_get_billboard_attributes(self, date:str, attribute:str, song_name:str): 212 | url = self.server + '/music/get_billboard_attributes' 213 | headers={'accept': "application/json"} 214 | data = {'date': date, 'attribute': attribute, 
'song_name': song_name} 215 | result = requests.post(url, json=data, headers=headers) 216 | return json.loads(result.text) 217 | 218 | def music_grammy_get_best_artist_by_year(self, year:int): 219 | url = self.server + '/music/grammy_get_best_artist_by_year' 220 | headers={'accept': "application/json"} 221 | data = {'query': year} 222 | result = requests.post(url, json=data, headers=headers) 223 | return json.loads(result.text) 224 | 225 | def music_grammy_get_award_count_by_artist(self, artist_name:str): 226 | url = self.server + '/music/grammy_get_award_count_by_artist' 227 | headers={'accept': "application/json"} 228 | data = {'query': artist_name} 229 | result = requests.post(url, json=data, headers=headers) 230 | return json.loads(result.text) 231 | 232 | def music_grammy_get_award_count_by_song(self, song_name:str): 233 | url = self.server + '/music/grammy_get_award_count_by_song' 234 | headers={'accept': "application/json"} 235 | data = {'query': song_name} 236 | result = requests.post(url, json=data, headers=headers) 237 | return json.loads(result.text) 238 | 239 | def music_grammy_get_best_song_by_year(self, year:int): 240 | url = self.server + '/music/grammy_get_best_song_by_year' 241 | headers={'accept': "application/json"} 242 | data = {'query': year} 243 | result = requests.post(url, json=data, headers=headers) 244 | return json.loads(result.text) 245 | 246 | def music_grammy_get_award_date_by_artist(self, artist_name:str): 247 | url = self.server + '/music/grammy_get_award_date_by_artist' 248 | headers={'accept': "application/json"} 249 | data = {'query': artist_name} 250 | result = requests.post(url, json=data, headers=headers) 251 | return json.loads(result.text) 252 | 253 | def music_grammy_get_best_album_by_year(self, year:int): 254 | url = self.server + '/music/grammy_get_best_album_by_year' 255 | headers={'accept': "application/json"} 256 | data = {'query': year} 257 | result = requests.post(url, json=data, headers=headers) 258 | return 
json.loads(result.text) 259 | 260 | def music_grammy_get_all_awarded_artists(self): 261 | url = self.server + '/music/grammy_get_all_awarded_artists' 262 | headers={'accept': "application/json"} 263 | result = requests.post(url, headers=headers) 264 | return json.loads(result.text) 265 | 266 | def music_get_artist_birth_place(self, artist_name:str): 267 | url = self.server + '/music/get_artist_birth_place' 268 | headers={'accept': "application/json"} 269 | data = {'query': artist_name} 270 | result = requests.post(url, json=data, headers=headers) 271 | return json.loads(result.text) 272 | 273 | def music_get_artist_birth_date(self, artist_name:str): 274 | url = self.server + '/music/get_artist_birth_date' 275 | headers={'accept': "application/json"} 276 | data = {'query': artist_name} 277 | result = requests.post(url, json=data, headers=headers) 278 | return json.loads(result.text) 279 | 280 | def music_get_members(self, band_name:str): 281 | url = self.server + '/music/get_members' 282 | headers={'accept': "application/json"} 283 | data = {'query': band_name} 284 | result = requests.post(url, json=data, headers=headers) 285 | return json.loads(result.text) 286 | 287 | def music_get_lifespan(self, artist_name:str): 288 | url = self.server + '/music/get_lifespan' 289 | headers={'accept': "application/json"} 290 | data = {'query': artist_name} 291 | result = requests.post(url, json=data, headers=headers) 292 | return json.loads(result.text) 293 | 294 | def music_get_song_author(self, song_name:str): 295 | url = self.server + '/music/get_song_author' 296 | headers={'accept': "application/json"} 297 | data = {'query': song_name} 298 | result = requests.post(url, json=data, headers=headers) 299 | return json.loads(result.text) 300 | 301 | def music_get_song_release_country(self, song_name:str): 302 | url = self.server + '/music/get_song_release_country' 303 | headers={'accept': "application/json"} 304 | data = {'query': song_name} 305 | result = requests.post(url, 
json=data, headers=headers) 306 | return json.loads(result.text) 307 | 308 | def music_get_song_release_date(self, song_name:str): 309 | url = self.server + '/music/get_song_release_date' 310 | headers={'accept': "application/json"} 311 | data = {'query': song_name} 312 | result = requests.post(url, json=data, headers=headers) 313 | return json.loads(result.text) 314 | 315 | def music_get_artist_all_works(self, song_name:str): 316 | url = self.server + '/music/get_artist_all_works' 317 | headers={'accept': "application/json"} 318 | data = {'query': song_name} 319 | result = requests.post(url, json=data, headers=headers) 320 | return json.loads(result.text) 321 | 322 | def sports_soccer_get_games_on_date(self, date:str, team_name:str=None): 323 | url = self.server + '/sports/soccer/get_games_on_date' 324 | headers={'accept': "application/json"} 325 | data = {'team_name': team_name, 'date': date} 326 | result = requests.post(url, json=data, headers=headers) 327 | return json.loads(result.text) 328 | 329 | def sports_nba_get_games_on_date(self, date:str, team_name:str=None): 330 | url = self.server + '/sports/nba/get_games_on_date' 331 | headers={'accept': "application/json"} 332 | data = {'team_name': team_name, 'date': date} 333 | result = requests.post(url, json=data, headers=headers) 334 | return json.loads(result.text) 335 | 336 | def sports_nba_get_play_by_play_data_by_game_ids(self, game_ids:List[str]): 337 | url = self.server + '/sports/nba/get_play_by_play_data_by_game_ids' 338 | headers={'accept': "application/json"} 339 | data = {'game_ids': game_ids} 340 | result = requests.post(url, json=data, headers=headers) 341 | return json.loads(result.text) 342 | 343 | 344 | -------------------------------------------------------------------------------- /mock_api/cragapi/__init__.py: -------------------------------------------------------------------------------- 
# https://raw.githubusercontent.com/facebookresearch/CRAG/c71ad61ea4f18ab0cc4b5b009932bc76e21be394/mock_api/cragapi/__init__.py
# ------------------------------------------------------------------------------
# /mock_api/cragapi/fast_bm25.py:
# ------------------------------------------------------------------------------
'''
MIT License

Copyright (c) 2021 Teo Orthlieb

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''

import collections
import heapq
import math
import pickle
import sys

PARAM_K1 = 1.5
PARAM_B = 0.75
IDF_CUTOFF = 4


class BM25:
    """Fast implementation of the Best Matching 25 ranking function.

    Attributes
    ----------
    t2d : dict[str, dict[int, int]]
        Term -> {document index -> term frequency} for each document in `corpus`.
    idf : dict[str, float]
        Precomputed IDF score for every term kept by the cutoff.
    doc_len : list of int
        List of document lengths.
    avgdl : float
        Average length of a document in `corpus`.
    """

    def __init__(self, corpus, k1=PARAM_K1, b=PARAM_B, alpha=IDF_CUTOFF):
        """
        Parameters
        ----------
        corpus : list of list of str
            Given corpus (each document pre-tokenized).
        k1 : float
            Term-frequency saturation constant; 1.2 < k1 < 2 is typical.
        b : float
            Document-length normalization constant; 0.5 < b < 0.8 is typical.
        alpha : float
            IDF cutoff: terms with an IDF below *alpha* are dropped. A higher
            alpha lowers accuracy but improves performance.
        """
        self.k1 = k1
        self.b = b
        self.alpha = alpha

        self.avgdl = 0
        self.t2d = {}
        self.idf = {}
        self.doc_len = []
        if corpus:
            self._initialize(corpus)

    @property
    def corpus_size(self):
        # Number of indexed documents.
        return len(self.doc_len)

    def _initialize(self, corpus):
        """Calculate term frequencies per document/corpus and the IDF table."""
        for i, document in enumerate(corpus):
            self.doc_len.append(len(document))

            for word in document:
                if word not in self.t2d:
                    self.t2d[word] = {}
                if i not in self.t2d[word]:
                    self.t2d[word][i] = 0
                self.t2d[word][i] += 1

        self.avgdl = sum(self.doc_len) / len(self.doc_len)
        to_delete = []
        for word, docs in self.t2d.items():
            idf = math.log(self.corpus_size - len(docs) + 0.5) - math.log(len(docs) + 0.5)
            # only store the idf score if it's above the threshold
            if idf > self.alpha:
                self.idf[word] = idf
            else:
                to_delete.append(word)
        for word in to_delete:
            del self.t2d[word]

        # BUG FIX: with a small corpus the cutoff can drop *every* term
        # (e.g. any corpus under ~55 docs at the default alpha=4), which
        # previously raised ZeroDivisionError here.
        self.average_idf = sum(self.idf.values()) / len(self.idf) if self.idf else 0.0

        if self.average_idf < 0:
            print(
                f'Average inverse document frequency is less than zero. Your corpus of {self.corpus_size} documents'
                ' is either too small or it does not originate from natural text. BM25 may produce'
                ' unintuitive results.',
                file=sys.stderr
            )

    def get_top_n(self, query, documents, n=5):
        """
        Retrieve the top n documents for the query.

        Parameters
        ----------
        query : list of str
            The tokenized query.
        documents : list
            The documents to return from (must align with the index corpus).
        n : int
            The number of documents to return.

        Returns
        -------
        list
            The top n documents, ranked by BM25 score.
        """
        assert self.corpus_size == len(documents), "The documents given don't match the index corpus!"
        scores = collections.defaultdict(float)
        for token in query:
            if token in self.t2d:
                for index, freq in self.t2d[token].items():
                    denom_cst = self.k1 * (1 - self.b + self.b * self.doc_len[index] / self.avgdl)
                    scores[index] += self.idf[token] * freq * (self.k1 + 1) / (freq + denom_cst)

        return [documents[i] for i in heapq.nlargest(n, scores.keys(), key=scores.__getitem__)]

    def save(self, filename):
        """Pickle this index to ``<filename>.pkl``."""
        # BUG FIX: previously wrote to the literal path "(unknown).pkl",
        # ignoring the *filename* argument entirely.
        with open(f"{filename}.pkl", "wb") as fsave:
            pickle.dump(self, fsave, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def load(filename):
        """Load an index previously written by :meth:`save`."""
        # BUG FIX: honour *filename* instead of the hard-coded "(unknown).pkl".
        with open(f"{filename}.pkl", "rb") as fsave:
            return pickle.load(fsave)
# ------------------------------------------------------------------------------
# /mock_api/cragapi/finance.py:
# ------------------------------------------------------------------------------
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import random
import string

import numpy as np
import pandas as pd
from loguru import logger
from rank_bm25 import BM25Okapi
from sqlitedict import SqliteDict

KG_BASE_DIRECTORY = os.getenv("KG_BASE_DIRECTORY", "cragkg")
##########################################################################################
# The following are the mock API functions needed for the Finance domain.
21 | ########################################################################################## 22 | 23 | class FinanceKG(): 24 | def __init__(self): 25 | self.fuzzy_n = 10 26 | company_dict_file_path = os.path.join(KG_BASE_DIRECTORY, "finance", 'company_name.dict') 27 | logger.info(f"Reading {company_dict_file_path}") 28 | df = pd.read_csv(company_dict_file_path)[["Name", "Symbol"]] 29 | self.name_dict = dict(df.values) 30 | 31 | self.key_map = dict() 32 | self.corpus = [] 33 | for e in self.name_dict: 34 | ne = self.normalize(e) 35 | if ne not in self.key_map: 36 | self.key_map[ne] = [] 37 | self.key_map[ne].append(e) 38 | self.corpus.append(ne.split()) 39 | self.bm25 = BM25Okapi(self.corpus) 40 | self._load_db() 41 | logger.info("Finance KG initialized ✅") 42 | 43 | def _load_db(self): 44 | # Price history 45 | price_history_path = os.path.join(KG_BASE_DIRECTORY, "finance", "finance_price.sqlite") 46 | logger.info(f"Reading price history from: {price_history_path}") 47 | self.price_history = SqliteDict(price_history_path) 48 | 49 | # Detailed price history 50 | detailed_price_history_path = os.path.join(KG_BASE_DIRECTORY, "finance", "finance_detailed_price.sqlite") 51 | logger.info(f"Reading detailed price history from: {detailed_price_history_path}") 52 | self.detailed_price_history = SqliteDict(detailed_price_history_path) 53 | 54 | # Dividend history 55 | dividend_history_path = os.path.join(KG_BASE_DIRECTORY, "finance", "finance_dividend.sqlite") 56 | logger.info(f"Reading dividend history from: {dividend_history_path}") 57 | self.dividend_history = SqliteDict(dividend_history_path) 58 | 59 | # Market cap 60 | market_cap_path = os.path.join(KG_BASE_DIRECTORY, "finance", "finance_marketcap.sqlite") 61 | logger.info(f"Reading market capitalization from: {market_cap_path}") 62 | self.market_cap = SqliteDict(market_cap_path) 63 | 64 | # Financial info 65 | financial_info_path = os.path.join(KG_BASE_DIRECTORY, "finance", "finance_info.sqlite") 66 | 
logger.info(f"Reading financial information from: {financial_info_path}") 67 | self.financial_info = SqliteDict(financial_info_path) 68 | 69 | 70 | def normalize(self, x:str) -> str: 71 | """ 72 | Normalize a given string. 73 | arg: 74 | x: str 75 | output: 76 | normalized string value: str 77 | """ 78 | return " ".join(x.lower().replace("_", " ").translate(str.maketrans('', '', string.punctuation)).split()) 79 | 80 | def get_company_name(self, query:str) -> list[str]: 81 | """ 82 | Given a query, return top matched company names. 83 | arg: 84 | query: str 85 | output: 86 | top matched company names: list[str] 87 | """ 88 | 89 | query = self.normalize(query) 90 | scores = self.bm25.get_scores(query.split()) 91 | top_idx = np.argsort(scores)[::-1][:self.fuzzy_n] 92 | top_ne = [" ".join(self.corpus[i]) for i in top_idx if scores[i] != 0] 93 | top_e = [] 94 | for ne in top_ne: 95 | assert(ne in self.key_map) 96 | top_e += self.key_map[ne] 97 | return top_e[:self.fuzzy_n] 98 | 99 | def get_ticker_by_name(self, company_name:str) -> str: 100 | """ 101 | Return ticker name by company name. 102 | arg: 103 | company_name: the company name: str 104 | output: 105 | the ticker name of the company: str 106 | """ 107 | return self.name_dict.get(company_name, None) 108 | 109 | def get_price_history(self, ticker_name:str): 110 | """ 111 | Return 1 year history of daily Open price, Close price, High price, Low price and trading Volume. 112 | arg: 113 | ticker_name: str 114 | output: 115 | 1 year daily price history: json 116 | example: 117 | {'2023-02-28 00:00:00 EST': {'Open': 17.258894515434886, 118 | 'High': 17.371392171233836, 119 | 'Low': 17.09014892578125, 120 | 'Close': 17.09014892578125, 121 | 'Volume': 45100}, 122 | '2023-03-01 00:00:00 EST': {'Open': 17.090151299382544, 123 | 'High': 17.094839670907174, 124 | 'Low': 16.443295499989794, 125 | 'Close': 16.87453269958496, 126 | 'Volume': 104300}, 127 | ... 
128 | } 129 | """ 130 | db = self.price_history 131 | if ticker_name in db: 132 | return db[ticker_name] 133 | 134 | def get_detailed_price_history(self, ticker_name:str): 135 | """ 136 | Return the past 5 days' history of 1 minute Open price, Close price, High price, Low price and trading Volume, starting from 09:30:00 EST to 15:59:00 EST. Note that the Open, Close, High, Low, Volume are the data for the 1 min duration. However, the Open at 9:30:00 EST may not be equal to the daily Open price, and Close at 15:59:00 EST may not be equal to the daily Close price, due to handling of the paper trade. The sum of the 1 minute Volume may not be equal to the daily Volume. 137 | arg: 138 | ticker_name: str 139 | output: 140 | past 5 days' 1 min price history: json 141 | example: 142 | {'2024-02-22 09:30:00 EST': {'Open': 15.920000076293945, 143 | 'High': 15.920000076293945, 144 | 'Low': 15.920000076293945, 145 | 'Close': 15.920000076293945, 146 | 'Volume': 629}, 147 | '2024-02-22 09:31:00 EST': {'Open': 15.989999771118164, 148 | 'High': 15.989999771118164, 149 | 'Low': 15.989999771118164, 150 | 'Close': 15.989999771118164, 151 | 'Volume': 108}, 152 | ... 153 | } 154 | """ 155 | db = self.detailed_price_history 156 | if ticker_name in db: 157 | return db[ticker_name] 158 | 159 | def get_dividends_history(self, ticker_name:str): 160 | """ 161 | Return dividend history of a ticker. 162 | arg: 163 | ticker_name: str 164 | output: 165 | dividend distribution history: json 166 | example: 167 | {'2019-12-19 00:00:00 EST': 0.058, 168 | '2020-03-19 00:00:00 EST': 0.2, 169 | '2020-06-12 00:00:00 EST': 0.2, 170 | ... 171 | } 172 | """ 173 | db = self.dividend_history 174 | if ticker_name in db: 175 | return db[ticker_name] 176 | 177 | def get_market_capitalization(self, ticker_name: str) -> float: 178 | """ 179 | Return the market capitalization of a ticker. 
180 | arg: 181 | ticker_name: str 182 | output: 183 | market capitalization: float 184 | """ 185 | db = self.market_cap 186 | if ticker_name in db: 187 | return db[ticker_name] 188 | 189 | def get_eps(self, ticker_name:str) -> float: 190 | """ 191 | Return earnings per share of a ticker. 192 | arg: 193 | ticker_name: str 194 | output: 195 | earnings per share: float 196 | """ 197 | db = self.financial_info 198 | if ticker_name in db and 'forwardEps' in db[ticker_name]: 199 | return db[ticker_name]['forwardEps'] 200 | 201 | def get_pe_ratio(self, ticker_name:str) -> float: 202 | """ 203 | Return price-to-earnings ratio of a ticker. 204 | arg: 205 | ticker_name: str 206 | output: 207 | price-to-earnings ratio: float 208 | """ 209 | db = self.financial_info 210 | if ticker_name in db and 'forwardPE' in db[ticker_name]: 211 | return db[ticker_name]['forwardPE'] 212 | 213 | def get_info(self, ticker_name:str): 214 | """ 215 | Return meta data of a ticker. 216 | arg: 217 | ticker_name: str 218 | output: 219 | meta information: json 220 | """ 221 | db = self.financial_info 222 | if ticker_name in db: 223 | return db[ticker_name] 224 | -------------------------------------------------------------------------------- /mock_api/cragapi/movie.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import os 9 | import string 10 | from typing import Any, Dict, List, Tuple 11 | 12 | import numpy as np 13 | from loguru import logger 14 | from rank_bm25 import BM25Okapi 15 | 16 | KG_BASE_DIRECTORY = os.getenv("KG_BASE_DIRECTORY", "cragkg") 17 | 18 | class MovieKG: 19 | '''Knowledge Graph API for movie domain 20 | 21 | Mock KG API for movie domain. 
Supports getting information of movies and of persons including cast and crew. 22 | ''' 23 | def __init__(self, top_n: int=10) -> None: 24 | '''Initialize API and load data. Loads 3 dbs from json 25 | 26 | Args: 27 | top_n: max number of entities to return in entity search 28 | ''' 29 | # Reading the year database 30 | year_db_path = os.path.join(KG_BASE_DIRECTORY, "movie", "year_db.json") 31 | logger.info(f"Reading year database from: {year_db_path}") 32 | with open(year_db_path) as f: 33 | self._year_db = json.load(f) 34 | 35 | # Reading the person database 36 | person_db_path = os.path.join(KG_BASE_DIRECTORY, "movie", "person_db.json") 37 | logger.info(f"Reading person database from: {person_db_path}") 38 | with open(person_db_path) as f: 39 | self._person_db = json.load(f) 40 | 41 | # Reading the movie database 42 | movie_db_path = os.path.join(KG_BASE_DIRECTORY, "movie", "movie_db.json") 43 | logger.info(f"Reading movie database from: {movie_db_path}") 44 | with open(movie_db_path) as f: 45 | self._movie_db = json.load(f) 46 | 47 | self._top_n = top_n 48 | self._person_db_lookup = self._get_direct_lookup_db(self._person_db) 49 | self._movie_db_lookup = self._get_direct_lookup_db(self._movie_db) 50 | self._movie_corpus, self._movie_bm25 = self._get_ranking_db(self._movie_db) 51 | self._person_corpus, self._person_bm25 = self._get_ranking_db(self._person_db) 52 | 53 | logger.info("Movie KG initialized ✅") 54 | 55 | def _normalize(self, x: str) -> str: 56 | '''Helper function for normalizing text 57 | 58 | Args: 59 | x: string to be normalized 60 | 61 | Returns: 62 | normalized string 63 | ''' 64 | return " ".join(x.lower().replace("_", " ").translate(str.maketrans('', '', string.punctuation)).split()) 65 | 66 | def _get_ranking_db(self, db: Dict[str, Any]) -> Tuple[List[str], BM25Okapi]: 67 | '''Helper function to get BM25 index 68 | 69 | Args: 70 | db: dictionary of entities keyed by entity name 71 | 72 | Returns: 73 | corpus: list of entity names corresponding 
to BM25 index position 74 | bm25: BM25 index 75 | ''' 76 | corpus = [i.split() for i in db.keys()] 77 | bm25 = BM25Okapi(corpus) 78 | return corpus, bm25 79 | 80 | def _get_direct_lookup_db(self, db: Dict[str, Any]) -> Dict[int, Any]: 81 | '''Converts name-indexed db to id-indexed db for latency optimization 82 | 83 | Args: 84 | db: dictionary of entities keyed by normalized entity name 85 | 86 | Returns: 87 | dictionary of entities keyed by unique entity id 88 | ''' 89 | temp_db = {} 90 | for key, value in db.items(): 91 | if 'id' in value: 92 | temp_db[value['id']] = value 93 | return temp_db 94 | 95 | def _search_entity_by_name(self, query: str, bm25: BM25Okapi, corpus: List[str], map_db: Dict[str, Any]) -> List[Dict[str, Any]]: 96 | '''BM25 search for top n=10 matching entities 97 | 98 | Args: 99 | query: string to be searched 100 | bm25: BM25 index 101 | corpus: list of entity names corresponding to BM25 index position 102 | map_db: dictionary of entities keyed by normalized entity name 103 | 104 | Returns: 105 | list of top n matching entities. Each entity is a tuple of (normalized entity name, entity info) 106 | ''' 107 | n = self._top_n 108 | query = self._normalize(query) 109 | scores = bm25.get_scores(query.split()) 110 | top_idx = np.argsort(scores)[::-1][:n] 111 | top_ne = [" ".join(corpus[i]) for i in top_idx if scores[i] != 0] 112 | top_e = [] 113 | for ne in top_ne[:n]: 114 | assert(ne in map_db) 115 | top_e.append(map_db[ne]) 116 | return top_e[:n] 117 | 118 | def get_person_info(self, person_name: str) -> List[Dict[str, Any]]: 119 | '''Gets person info in database through BM25. 120 | 121 | Gets person info through BM25 Search. 
The returned entities MAY contain the following fields: 122 | - name (string): name of person 123 | - id (int): unique id of person 124 | - acted_movies (list[int]): list of movie ids in which person acted 125 | - directed_movies (list[int]): list of movie ids in which person directed 126 | - birthday (string): string of person's birthday, in the format of "YYYY-MM-DD" 127 | - oscar_awards: list of oscar awards (dict), win or nominated, in which the person was the entity. The format for oscar award entity are: 128 | 'year_ceremony' (int): year of the oscar ceremony, 129 | 'ceremony' (int): which ceremony. for example, ceremony = 50 means the 50th oscar ceremony, 130 | 'category' (string): category of this oscar award, 131 | 'name' (string): name of the nominee, 132 | 'film' (string): name of the film, 133 | 'winner' (bool): whether the person won the award 134 | 135 | Args: 136 | person_name: string to be searched 137 | 138 | Returns: 139 | list of top n matching entities. Entities are ranked by BM25 score. 140 | ''' 141 | res = self._search_entity_by_name(person_name, self._person_bm25, self._person_corpus, self._person_db) 142 | return res 143 | 144 | def get_movie_info(self, person_name: str) -> List[Dict[str, Any]]: 145 | '''Gets movie info in database through BM25. 146 | 147 | Gets movie info through BM25 Search. The returned entities MAY contain the following fields: 148 | - title (string): title of movie 149 | - id (int): unique id of movie 150 | - release_date (string): string of movie's release date, in the format of "YYYY-MM-DD" 151 | - original_title (string): original title of movie, if in another language other than english 152 | - original_language (string): original language of movie. Example: 'en', 'fr' 153 | - budget (int): budget of movie, in USD 154 | - revenue (int): revenue of movie, in USD 155 | - rating (float): rating of movie, in range [0, 10] 156 | - genres (list[dict]): list of genres of movie. 
Sample genre object is {'id': 123, 'name': 'action'} 157 | - oscar_awards: list of oscar awards (dict), win or nominated, in which the movie was the entity. The format for oscar award entity are: 158 | 'year_ceremony' (int): year of the oscar ceremony, 159 | 'ceremony' (int): which ceremony. for example, ceremony = 50 means the 50th oscar ceremony, 160 | 'category' (string): category of this oscar award, 161 | 'name' (string): name of the nominee, 162 | 'film' (string): name of the film, 163 | 'winner' (bool): whether the person won the award 164 | - cast (list [dict]): list of cast members of the movie and their roles. The format of the cast member entity is: 165 | 'name' (string): name of the cast member, 166 | 'id' (int): unique id of the cast member, 167 | 'character' (string): character played by the cast member in the movie, 168 | 'gender' (string): the reported gender of the cast member. Use 2 for actor and 1 for actress, 169 | 'order' (int): order of the cast member in the movie. For example, the actress with the lowest order is the main actress, 170 | - crew' (list [dict]): list of crew members of the movie and their roles. The format of the crew member entity is: 171 | 'name' (string): name of the crew member, 172 | 'id' (int): unique id of the crew member, 173 | 'job' (string): job of the crew member, 174 | 175 | Args: 176 | movie_name: string to be searched 177 | 178 | Returns: 179 | list of top n matching entities. Entities are ranked by BM25 score. 180 | ''' 181 | res = self._search_entity_by_name(person_name, self._movie_bm25, self._movie_corpus, self._movie_db) 182 | return res 183 | 184 | def get_year_info(self, year: str) -> Dict[str, Any]: 185 | '''Gets info of a specific year 186 | 187 | Gets year info. The returned entity MAY contain the following fields: 188 | - movie_list: list of movie ids in the year. This field can be very long to a few thousand films 189 | - oscar_awards: list of oscar awards (dict), held in that particular year. 
The format for oscar award entity are: 190 | 'year_ceremony' (int): year of the oscar ceremony, 191 | 'ceremony' (int): which ceremony. for example, ceremony = 50 means the 50th oscar ceremony, 192 | 'category' (string): category of this oscar award, 193 | 'name' (string): name of the nominee, 194 | 'film' (string): name of the film, 195 | 'winner' (bool): whether the person won the award 196 | 197 | Args: 198 | year: string of year. Note that we only support years between 1990 and 2021 199 | 200 | Returns: 201 | an entity representing year information 202 | ''' 203 | if int(year) not in range(1990, 2022): 204 | raise ValueError("Year must be between 1990 and 2021") 205 | return self._year_db.get(str(year), None) 206 | 207 | def get_movie_info_by_id(self, movie_id: int) -> Dict[str, Any]: 208 | '''Helper fast lookup function to get movie info directly by id 209 | 210 | Return a movie entity with same format as the entity in get_movie_info. 211 | 212 | Args: 213 | movie_id: unique id of movie 214 | 215 | Returns: 216 | an entity representing movie information 217 | ''' 218 | return self._movie_db_lookup.get(movie_id, None) 219 | 220 | def get_person_info_by_id(self, person_id: int) -> Dict[str, Any]: 221 | '''Helper fast lookup function to get person info directly by id 222 | 223 | Return a person entity with same format as the entity in get_person_info. 224 | 225 | Args: 226 | person_id: unique id of person 227 | 228 | Returns: 229 | an entity representing person information 230 | ''' 231 | return self._person_db_lookup.get(person_id, None) 232 | -------------------------------------------------------------------------------- /mock_api/cragapi/music.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import csv 8 | import json 9 | import os 10 | import pickle 11 | import string 12 | import time 13 | from datetime import datetime 14 | 15 | import numpy as np 16 | import pandas as pd 17 | from loguru import logger 18 | from tqdm import tqdm 19 | 20 | from .fast_bm25 import BM25 21 | 22 | KG_BASE_DIRECTORY = os.getenv("KG_BASE_DIRECTORY", "cragkg") 23 | 24 | 25 | class MusicKG(object): 26 | def __init__(self): 27 | # Reading the artist dictionary 28 | artist_dict_path = os.path.join(KG_BASE_DIRECTORY, "music", "artist_dict_simplified.pickle") 29 | logger.info(f"Reading artist dictionary from: {artist_dict_path}") 30 | with open(artist_dict_path, 'rb') as file: 31 | self.artist_dict = pickle.load(file) 32 | 33 | # Reading the song dictionary 34 | song_dict_path = os.path.join(KG_BASE_DIRECTORY, "music", "song_dict_simplified.pickle") 35 | logger.info(f"Reading song dictionary from: {song_dict_path}") 36 | with open(song_dict_path, 'rb') as file: 37 | self.song_dict = pickle.load(file) 38 | 39 | # Reading the Grammy DataFrame 40 | grammy_df_path = os.path.join(KG_BASE_DIRECTORY, "music", "grammy_df.pickle") 41 | logger.info(f"Reading Grammy DataFrame from: {grammy_df_path}") 42 | with open(grammy_df_path, 'rb') as file: 43 | self.grammy_df = pickle.load(file) 44 | 45 | # Reading the rank dictionary for Hot 100 46 | rank_dict_hot_path = os.path.join(KG_BASE_DIRECTORY, "music", "rank_dict_hot100.pickle") 47 | logger.info(f"Reading rank dictionary for Hot 100 from: {rank_dict_hot_path}") 48 | with open(rank_dict_hot_path, 'rb') as file: 49 | self.rank_dict_hot = pickle.load(file) 50 | 51 | # Reading the song dictionary for Hot 100 52 | song_dict_hot_path = os.path.join(KG_BASE_DIRECTORY, "music", "song_dict_hot100.pickle") 53 | logger.info(f"Reading song dictionary for Hot 100 from: {song_dict_hot_path}") 54 | with open(song_dict_hot_path, 'rb') as file: 55 | self.song_dict_hot = pickle.load(file) 56 | 57 | # Reading the artist work dictionary 58 | 
# ---------------------------------------------------------------------------
# MusicKG instance methods.  The class header and __init__ are loaded earlier
# in this file; these methods read the state built there: artist_dict,
# song_dict, artist_work_dict, rank_dict_hot, song_dict_hot, grammy_df, the
# BM25 indexes bm25_artist / bm25_song and their corpus / key_map tables.
# ---------------------------------------------------------------------------

def normalize(self, x):
    """Lower-case *x*, map underscores to spaces, strip ASCII punctuation and
    collapse runs of whitespace into single spaces."""
    cleaned = x.lower().replace("_", " ")
    cleaned = cleaned.translate(str.maketrans('', '', string.punctuation))
    return " ".join(cleaned.split())

def search_artist_entity_by_name(self, query):
    """ Return the fuzzy matching results of the query (artist name); we only return the top-10 similar results from our KB

    Args:
        query (str): artist name

    Returns:
        Top-10 similar entity name in a list

    """
    n = 10
    tokens = self.normalize(query).split()
    top_e = []
    for matched in self.bm25_artist.get_top_n(tokens, self.corpus_artist, n=n):
        # corpus_artist was built from key_map_artist's keys, so a miss here
        # would mean an inconsistent index.  Skip defensively instead of
        # `assert`, which is stripped under `python -O`.
        top_e += self.key_map_artist.get(str(matched), [])
    return top_e[:n]

def search_song_entity_by_name(self, query):
    """ Return the fuzzy matching results of the query (song name); we only return the top-10 similar results from our KB

    Args:
        query (str): song name

    Returns:
        Top-10 similar entity name in a list

    """
    n = 10
    tokens = self.normalize(query).split()
    top_e = []
    for matched in self.bm25_song.get_top_n(tokens, self.corpus_song, n=n):
        # Same defensive lookup as search_artist_entity_by_name.
        top_e += self.key_map_song.get(str(matched), [])
    return top_e[:n]

def get_billboard_rank_date(self, rank, date=None):
    """ Return the song name(s) and the artist name(s) of a certain rank on a certain date;
        If no date is given, return the list of of a certain rank of all dates.

    Args:
        rank (int): the interested rank in billboard; from 1 to 100.
        date (Optional, str, in YYYY-MM-DD format): the interested date; leave it blank if do not want to specify the date.

    Returns:
        rank_list (list): a list of song names of a certain rank (on a certain date).
        artist_list (list): a list of author names corresponding to the song names returned.
    """
    key = str(rank)
    if key not in self.rank_dict_hot:
        return None, None
    rank_list = []
    artist_list = []
    if date:
        for item in self.rank_dict_hot[key]:
            if item['Date'] == date:
                return [item['Song']], [item['Artist']]
        # Date given but never charted at this rank: empty lists (matches the
        # original fall-through behavior).
    else:
        for item in self.rank_dict_hot[key]:
            rank_list.append(item['Song'])
            artist_list.append(item['Artist'])
    return rank_list, artist_list

def get_billboard_attributes(self, date, attribute, song_name):
    """ Return the attributes of a certain song on a certain date

    Args:
        date (str, in YYYY-MM-DD format): the interested date of the song
        attribute (str): attributes from ['rank_last_week', 'weeks_in_chart', 'top_position', 'rank']
        song_name (str): the interested song name

    Returns:
        cur_value (str): the value of the interested attribute of a song on a certain date
    """
    if song_name not in self.song_dict_hot:
        return None
    rows_by_date = self.song_dict_hot[song_name]
    if date not in rows_by_date:
        return None
    row = rows_by_date[date]
    # The raw chart rows come in two column layouts; a '-' in column 6 marks
    # the variant where the song had no rank last week.  NOTE(review): the
    # index mappings below are copied verbatim from the original code —
    # confirm against the hot100 pickle layout before changing them.
    if row[6] == '-':
        index_for = {'rank_last_week': 6, 'weeks_in_chart': 5, 'top_position': 4}
    else:
        index_for = {'rank_last_week': 4, 'weeks_in_chart': 6, 'top_position': 5}
    # Any other attribute (including 'rank') falls back to column 3.
    return row[index_for.get(attribute, 3)]

def grammy_get_best_artist_by_year(self, year):
    """ Return the Best New Artist of a certain year in between 1958 and 2019

    Args:
        year (int, in YYYY format): the interested year

    Returns:
        artist_list (list): the list of artists who win the award
    """
    # NOTE: lower bound intentionally kept at 1957 (the original check) even
    # though the docstring says 1958.
    if year < 1957 or year > 2019:
        return None
    mask = (self.grammy_df['category'] == 'Best New Artist') & (self.grammy_df['year'] == year)
    return self.grammy_df[mask]['nominee'].tolist()

def grammy_get_award_count_by_artist(self, artist_name):
    """ Return the number of awards won by a certain artist between 1958 and 2019

    Args:
        artist_name (str): the name of the artist

    Returns:
        the number of total awards (int)
    """
    # Substring match against three columns; a row can be counted up to three
    # times (once per column), matching the original behavior.
    total = 0
    for column in ('nominee', 'artist', 'workers'):
        total += sum(artist_name in str(value) for value in self.grammy_df[column])
    return total

def grammy_get_award_count_by_song(self, song_name):
    """ Return the number of awards won by a certain song between 1958 and 2019

    Args:
        song_name (str): the name of the song

    Returns:
        the number of total awards (int)
    """
    return len(self.grammy_df[self.grammy_df['nominee'] == song_name])

def grammy_get_best_song_by_year(self, year):
    """ Return the Song Of The Year in a certain year between 1958 and 2019

    Args:
        year (int, in YYYY format): the interested year

    Returns:
        song_list (list): the list of the song names that win the Song Of The Year in a certain year
    """
    # Same (historical) 1957 lower bound as grammy_get_best_artist_by_year.
    if year < 1957 or year > 2019:
        return None
    mask = (self.grammy_df['category'] == 'Song Of The Year') & (self.grammy_df['year'] == year)
    return self.grammy_df[mask]['nominee'].tolist()

def grammy_get_award_date_by_artist(self, artist_name):
    """ Return the award winning years of a certain artist

    Args:
        artist_name (str): the name of the artist

    Returns:
        selected_years (list): the list of years the artist is awarded
    """
    matched_rows = set()
    for column in ('nominee', 'artist', 'workers'):
        for i, value in enumerate(self.grammy_df[column]):
            if artist_name in str(value):
                matched_rows.add(i)
    # NOTE(review): positional index is used as a label here, which assumes
    # grammy_df keeps its default RangeIndex — same assumption as before.
    years = {int(self.grammy_df['year'][i]) for i in matched_rows}
    return list(years)

def grammy_get_best_album_by_year(self, year):
    """ Return the Album Of The Year of a certain year between 1958 and 2019

    Args:
        year (int, in YYYY format): the interested year

    Returns:
        song_list (list): the list of albums that won the Album Of The Year in a certain year
    """
    if year < 1957 or year > 2019:
        return None
    mask = (self.grammy_df['category'] == 'Album Of The Year') & (self.grammy_df['year'] == year)
    return self.grammy_df[mask]['nominee'].tolist()

def grammy_get_all_awarded_artists(self):
    """Return all the artists ever awarded Grammy Best New Artist between 1958 and 2019

    Args:
        None

    Returns:
        nominee_values (list): the list of artist ever awarded Grammy Best New Artist

    """
    winners = self.grammy_df[self.grammy_df['category'] == 'Best New Artist']
    return winners['nominee'].dropna().unique().tolist()

def get_artist_birth_place(self, artist_name):
    """ Return the birth place country code (2-digit) for the input artist

    Args:
        artist_name (str): the name of the artist

    Returns:
        country (str): the two-digit country code following ISO-3166
    """
    # Narrowed from a bare `except:` — only missing keys / non-dict entries
    # are expected failure modes here.
    try:
        country = self.artist_dict[artist_name]['country']
    except (KeyError, TypeError):
        return None
    return country or None

def get_artist_birth_date(self, artist_name):
    """ Return the birth date of the artist

    Args:
        artist_name (str): the name of the artist

    Returns:
        life_span_begin (str, in YYYY-MM-DD format if possible): the birth date of the person or the begin date of a band

    """
    try:
        life_span_begin = self.artist_dict[artist_name]['birth_date']
    except (KeyError, TypeError):
        return None
    return life_span_begin or None

def get_members(self, band_name):
    """ Return the member list of a band

    Args:
        band_name (str): the name of the band

    Returns:
        the list of members' names.
    """
    try:
        return list(set(self.artist_dict[band_name]['members']))
    except (KeyError, TypeError):
        return None

def get_lifespan(self, artist_name):
    """ Return the lifespan of the artist

    Args:
        artist_name (str): the name of the artist

    Returns:
        the birth and death dates in a list

    """
    try:
        entry = self.artist_dict[artist_name]
        return [entry['birth_date'], entry['end_date']]
    except (KeyError, TypeError):
        return [None, None]

def get_song_author(self, song_name):
    """ Return the author of the song

    Args:
        song_name (str): the name of the song

    Returns:
        author (str): the author of the song
    """
    try:
        author = self.song_dict[song_name]['author']
    except (KeyError, TypeError):
        return None
    return author or None

def get_song_release_country(self, song_name):
    """ Return the release country of the song

    Args:
        song_name (str): the name of the song

    Returns:
        country (str): the two-digit country code following ISO-3166
    """
    try:
        country = self.song_dict[song_name]['country']
    except (KeyError, TypeError):
        return None
    return country or None

def get_song_release_date(self, song_name):
    """ Return the release date of the song

    Args:
        song_name (str): the name of the song

    Returns:
        date (str in YYYY-MM-DD format): the date of the song
    """
    try:
        date = self.song_dict[song_name]['date']
    except (KeyError, TypeError):
        return None
    return date or None

def get_artist_all_works(self, artist_name):
    """ Return the list of all works of a certain artist

    Args:
        artist_name (str): the name of the artist

    Returns:
        work_list (list): the list of all work names

    """
    # .get returns None for unknown artists, matching the original branch.
    return self.artist_work_dict.get(artist_name)
def search_entity_by_name(self, query):
    """OpenKG: return up to 10 entity names whose normalized form best
    matches *query* under BM25; zero-score candidates are dropped."""
    n = 10
    tokens = self.normalize(query).split()
    scores = self.bm25.get_scores(tokens)
    top_idx = np.argsort(scores)[::-1][:n]
    top_e = []
    for i in top_idx:
        if scores[i] == 0:
            continue
        normalized = " ".join(self.corpus[i])
        # corpus was built from key_map's keys, so a miss would mean an
        # inconsistent index.  Skip defensively instead of `assert`, which is
        # stripped under `python -O`.
        top_e += self.key_map.get(normalized, [])
    return top_e[:n]

def get_entity(self, entity):
    """OpenKG: return the KG record for *entity*, or None when absent."""
    return self.kg.get(entity)


class SoccerKG:
    def __init__(self, file_name='soccer_team_match_stats.pkl'):
        """
        Load soccer KG at different time stamp for public and private set.
        Rows whose 'league' index level is missing are dropped.
        """
        soccer_kg_file = os.path.join(KG_BASE_DIRECTORY, "sports", file_name)
        logger.info(f"Reading soccer KG from: {soccer_kg_file}")
        team_match_stats = pd.read_pickle(soccer_kg_file)
        self.team_match_stats = team_match_stats[team_match_stats.index.get_level_values('league').notna()]
        logger.info("Soccer KG initialized ✅")

    # ==================== APIs for competitors ====================

    def get_games_on_date(self, date_str, soccer_team_name=None):
        """
        Description: Get all soccer game rows given date_str
        Input:
        - soccer_team_name: soccer team name, if None, get results for all teams
        - date_str: in format of %Y-%m-%d, %Y-%m or %Y, e.g. 2024-03-01, 2024-03, 2024
        Output: a json contains info of the games (None when nothing matches)
        """
        if soccer_team_name is None:
            games = self.team_match_stats
        else:
            # Index levels are (league, season, team, match); select the team level.
            games = self.team_match_stats.loc[(slice(None), slice(None), soccer_team_name, slice(None)), :]
        # Match at day / month / year granularity depending on date_str shape.
        fmt_by_parts = {3: '%Y-%m-%d', 2: '%Y-%m', 1: '%Y'}
        fmt = fmt_by_parts.get(len(date_str.split('-')))
        if fmt is None:
            return None
        games = games[games['date'].dt.strftime(fmt) == date_str]
        if len(games) > 0:
            return games.to_json(date_format='iso')
        return None


class NBAKG:
    """Thin wrapper over the nba.sqlite database."""

    # ==================== Helper funcs ====================
    def __init__(self):
        nba_kg_file = os.path.join(KG_BASE_DIRECTORY, "sports", 'nba.sqlite')
        logger.info(f"Reading NBA KG from: {nba_kg_file}")
        self.conn = sql.connect(nba_kg_file)  # create connection object to database
        logger.info("NBA KG initialized ✅")

    def get_time_cond(self, date_str):
        """Helper: SQL predicate matching game_date at day/month/year
        granularity (by the number of '-' separated parts of date_str).
        NOTE(review): date_str is interpolated into SQL verbatim — injection
        risk if callers ever pass untrusted input.
        """
        fmt_by_parts = {3: '%Y-%m-%d', 2: '%Y-%m', 1: '%Y'}
        fmt = fmt_by_parts.get(len(date_str.split('-')))
        if fmt is None:
            return "1"
        return f"strftime('{fmt}',game_date) = '{date_str}'"

    def team_in_game_cond(self, basketball_team_name):
        """Helper: SQL predicate matching the team on either side.
        NOTE(review): same injection caveat as get_time_cond."""
        return f"(team_name_home = '{basketball_team_name}' or team_name_away = '{basketball_team_name}')"
==================== 88 | 89 | def get_games_on_date(self, date_str, basketball_team_name=None): 90 | """ 91 | Description: Get all nba game rows given date_str 92 | Input: date_str in format of %Y-%m-%d, %Y-%m, or %Y, e.g. 2023-01-01, 2023-01, 2023, basketball_team_name (Optional) 93 | Output: a json contains info of the game 94 | """ 95 | if basketball_team_name is not None: 96 | team_cond = self.team_in_game_cond(basketball_team_name) 97 | time_cond = self.get_time_cond(date_str) 98 | df_game_by_team = pd.read_sql(f"select * from game where {team_cond} and {time_cond}", self.conn) 99 | if len(df_game_by_team) > 0: 100 | return df_game_by_team.to_json(date_format='iso') 101 | else: 102 | time_cond = self.get_time_cond(date_str) 103 | df_game_by_team = pd.read_sql(f"select * from game where {time_cond}", self.conn) 104 | if len(df_game_by_team) > 0: 105 | return df_game_by_team.to_json(date_format='iso') 106 | 107 | def get_play_by_play_data_by_game_ids(self, game_ids): 108 | """ 109 | Description: Get all nba play by play rows given game ids 110 | Input: list of nba game ids, e.g., ["0022200547", "0029600027"] 111 | Output: info of the play by play events of given game id 112 | """ 113 | game_ids_str = ', '.join(f"'{game_id}'" for game_id in game_ids) 114 | df_play_by_play_by_gameids = pd.read_sql(f"select * from play_by_play where game_id in ({game_ids_str})", self.conn) 115 | if len(df_play_by_play_by_gameids) > 0: 116 | return df_play_by_play_by_gameids.to_json(date_format='iso') 117 | -------------------------------------------------------------------------------- /mock_api/cragkg/.gitattributes: -------------------------------------------------------------------------------- 1 | *.pickle filter=lfs diff=lfs merge=lfs -text 2 | *.dict filter=lfs diff=lfs merge=lfs -text 3 | *.sqlite filter=lfs diff=lfs merge=lfs -text 4 | *.json filter=lfs diff=lfs merge=lfs -text 5 | *.bz2 filter=lfs diff=lfs merge=lfs -text 6 | *.pkl filter=lfs diff=lfs merge=lfs -text 7 | 
-------------------------------------------------------------------------------- /mock_api/cragkg/finance/company_name.dict: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e92fcfb64911508705b59af046cfff8b4edc38ddd6dff391c27870f2cac94ca6 3 | size 614789 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/finance/finance_detailed_price.sqlite: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:6833ea730d6c7bf14532566b450562f153e06ff0fdd04551722552371448fec2 3 | size 717529088 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/finance/finance_dividend.sqlite: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:5e18259fd76d39f8b06801db911d28a152e54e2fa77abe12dffbd66474e81fe2 3 | size 10448896 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/finance/finance_info.sqlite: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:30e80405de86ded41cfde9020409c0c192042eb58f40de78cff77f896c4e7894 3 | size 34877440 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/finance/finance_marketcap.sqlite: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3357344b67e039dc424c54d19f561ff2c97b7507df74d0388430984b62ca89cf 3 | size 303104 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/finance/finance_price.sqlite: 
-------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:05fa78f8d319ae2a7f5990b50f8b6f1d17be1f650bf5fb814721e1fc2d4a56ac 3 | size 136196096 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/movie/movie_db.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:673de2da21bd594a8acb8cdb7765346f07ccf4d4a88a74942416d1748c12964b 3 | size 189194170 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/movie/person_db.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:85992666f528e05193cc3c97bde43e8c24a25ad59bd1a0d359df050799f79c79 3 | size 45103919 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/movie/year_db.json: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:05d32ffc43e749ab6e38484b04ec6c6b0f60830f306f514309812c3cba4ed91b 3 | size 1150550 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/music/artist_dict_simplified.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:7f935fc9b76b26b52843d2f52c46aaff00f1193b6e5bc9b1529d1ae9456b6a2a 3 | size 136772303 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/music/artist_work_dict.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bc6db55ab0e32e13769a2473e14050f0828eea82aef5f6e279ef6f9543a54ffc 3 | size 97658647 
4 | -------------------------------------------------------------------------------- /mock_api/cragkg/music/grammy_df.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:4c6ee52e6b083a980063ee290eb0b2d6926e1478c0d696380218ca19da34f166 3 | size 731396 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/music/rank_dict_hot100.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:965cf61f1daae796d274ed90a20e724c00568eb2586339e148d97606cb608e1e 3 | size 29121007 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/music/song_dict_hot100.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:a5d5943fdd69a5c7730dbc3419730559abd4ba31704cf7bc3ea48ab2e98ded4c 3 | size 4084771 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/music/song_dict_simplified.pickle: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:bb8e988448b76518eb0e16ae894afa982187f972956dd12bd8e787b1a9f732f4 3 | size 234549979 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/open/kg.0.jsonl.bz2: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3c25a6533bab6d13254f627ddb2f69f60a127233ec51990c9109bf88e800ee30 3 | size 26290049 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/open/kg.1.jsonl.bz2: -------------------------------------------------------------------------------- 1 | version 
https://git-lfs.github.com/spec/v1 2 | oid sha256:a636adabb4b5af258f734bf2ad1a24a82ef84f451aaaec5e5324c4ce5753971a 3 | size 17313388 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/sports/nba.sqlite: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:2ff9d5943b32c0fb247a3e7d6b1b8f268d37baf88668ecf83418169234a32d1f 3 | size 2902269952 4 | -------------------------------------------------------------------------------- /mock_api/cragkg/sports/soccer_team_match_stats.pkl: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:86c78cb0e43f294f94d329083b4cfd31f05efed23d3ba51ed9e40dbcad656724 3 | size 462407 4 | -------------------------------------------------------------------------------- /mock_api/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.103.0 2 | uvicorn[standard] 3 | pydantic==2.5.3 4 | rank_bm25 5 | tqdm 6 | pandas==2.0.3 7 | lxml 8 | sqlitedict==2.1.0 9 | loguru -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # Guide to Writing Your Own Models 2 | 3 | ## Model Base Class 4 | Your models should follow the format from the `DummyModel` class found in [dummy_model.py](dummy_model.py). We provide the example model, `dummy_model.py`, to illustrate the structure your own model. Crucially, your model class must implement the `batch_generate_answer` method. 5 | 6 | ## Selecting which model to use 7 | To ensure your model is recognized and utilized correctly, please specify your model class name in the [`user_config.py`](user_config.py) file, by following the instructions in the inline comments. 
class DummyModel:
    """Minimal baseline model: it answers every query with "i don't know".

    Serves as a template for competitor models; the evaluator relies only on
    `get_batch_size` and `batch_generate_answer`.
    """

    def __init__(self):
        """No initialization is required for this baseline."""
        pass

    def get_batch_size(self) -> int:
        """Return the batch size (an integer between 1 and 16) the evaluator
        uses when calling `batch_generate_answer`.  It may vary across calls
        or stay constant; this baseline always uses 4.
        """
        self.batch_size = 4
        return self.batch_size

    def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]:
        """Produce one plain-text answer per query in *batch*.

        Parameters:
            batch: dict with parallel lists under the keys 'interaction_id',
                'query', 'search_results' and 'query_time'.

        Returns:
            One response string per query, each limited to 75 tokens by the
            evaluator.  Answering "i don't know" avoids the hallucination
            penalty when the correct answer is uncertain; each query must be
            answered within 30 seconds.
        """
        batch_interaction_ids = batch["interaction_id"]
        search_results = batch["search_results"]
        query_times = batch["query_time"]
        # One placeholder response per incoming query.
        return ["i don't know" for _query in batch["query"]]
32 | ###################################################################################################### 33 | 34 | 35 | # Load the environment variable that specifies the URL of the MockAPI. This URL is essential 36 | # for accessing the correct API endpoint in Task 2 and Task 3. The value of this environment variable 37 | # may vary across different evaluation settings, emphasizing the importance of dynamically obtaining 38 | # the API URL to ensure accurate endpoint communication. 39 | 40 | 41 | CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000") 42 | 43 | #### CONFIG PARAMETERS --- 44 | 45 | # Define the number of context sentences to consider for generating an answer. 46 | NUM_CONTEXT_SENTENCES = 20 47 | # Set the maximum length for each context sentence (in characters). 48 | MAX_CONTEXT_SENTENCE_LENGTH = 1000 49 | # Set the maximum context references length (in characters). 50 | MAX_CONTEXT_REFERENCES_LENGTH = 4000 51 | 52 | # Batch size you wish the evaluators will use to call the `batch_generate_answer` function 53 | SUBMISSION_BATCH_SIZE = 8 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 54 | 55 | # VLLM Parameters 56 | VLLM_TENSOR_PARALLEL_SIZE = 4 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 57 | VLLM_GPU_MEMORY_UTILIZATION = 0.85 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 58 | 59 | # Sentence Transformer Parameters 60 | SENTENTENCE_TRANSFORMER_BATCH_SIZE = 128 # TUNE THIS VARIABLE depending on the size of your embedding model and GPU mem available 61 | 62 | # entity extraction template 63 | Entity_Extract_TEMPLATE = """ 64 | You are given a Query and Query Time. Do the following: 65 | 66 | 1) Determine the domain the query is about. The domain should be one of the following: "finance", "sports", "music", "movie", "encyclopedia". 
If none of the domain applies, use "other". Use "domain" as the key in the result json. 67 | 68 | 2) Extract structured information from the query. Include different keys into the result json depending on the domains, amd put them DIRECTLY in the result json. Here are the rules: 69 | 70 | For `encyclopedia` and `other` queries, these are possible keys: 71 | - `main_entity`: extract the main entity of the query. 72 | 73 | For `finance` queries, these are possible keys: 74 | - `market_identifier`: stock identifiers including individual company names, stock symbols. 75 | - `metric`: financial metrics that the query is asking about. This must be one of the following: `price`, `dividend`, `P/E ratio`, `EPS`, `marketCap`, and `other`. 76 | - `datetime`: time frame that query asks about. When datetime is not explicitly mentioned, use `Query Time` as default. 77 | 78 | For `movie` queries, these are possible keys: 79 | - `movie_name`: name of the movie 80 | - `movie_aspect`: if the query is about a movie, which movie aspect the query asks. This must be one of the following: `budget`, `genres`, `original_language`, `original_title`, `release_date`, `revenue`, `title`, `cast`, `crew`, `rating`, `length`. 81 | - `person`: person name related to moves 82 | - `person_aspect`: if the query is about a person, which person aspect the query asks. This must be one of the following: `acted_movies`, `directed_movies`, `oscar_awards`, `birthday`. 83 | - `year`: if the query is about movies released in a specific year, extract the year 84 | 85 | For `music` queries, these are possible keys: 86 | - `artist_name`: name of the artist 87 | - `artist_aspect`: if the query is about an artist, extract the aspect of the artist. This must be one of the following: `member`, `birth place`, `birth date`, `lifespan`, `artist work`, `grammy award count`, `grammy award date`. 88 | - `song_name`: name of the song 89 | - `song_aspect`: if the query is about a song, extract the aspect of the song. 
This must be one of the following: `auther`, `grammy award count`, `release country`, `release date`. 90 | 91 | For `sports` queries, these are possible keys: 92 | - `sport_type`: one of `basketball`, `soccer`, `other` 93 | - `tournament`: such as NBA, World Cup, Olympic. 94 | - `team`: teams that user interested in. 95 | - `datetime`: time frame that user interested in. When datetime is not explicitly mentioned, use `Query Time` as default. 96 | 97 | Return the results in a FLAT json. 98 | 99 | *NEVER include ANY EXPLANATION or NOTE in the output, ONLY OUTPUT JSON* 100 | """ 101 | 102 | #### CONFIG PARAMETERS END--- 103 | 104 | class ChunkExtractor: 105 | 106 | @ray.remote 107 | def _extract_chunks(self, interaction_id, html_source): 108 | """ 109 | Extracts and returns chunks from given HTML source. 110 | 111 | Note: This function is for demonstration purposes only. 112 | We are treating an independent sentence as a chunk here, 113 | but you could choose to chunk your text more cleverly than this. 114 | 115 | Parameters: 116 | interaction_id (str): Interaction ID that this HTML source belongs to. 117 | html_source (str): HTML content from which to extract text. 118 | 119 | Returns: 120 | Tuple[str, List[str]]: A tuple containing the interaction ID and a list of sentences extracted from the HTML content. 
121 | """ 122 | # Parse the HTML content using BeautifulSoup 123 | soup = BeautifulSoup(html_source, "lxml") 124 | text = soup.get_text(" ", strip=True) # Use space as a separator, strip whitespaces 125 | 126 | if not text: 127 | # Return a list with empty string when no text is extracted 128 | return interaction_id, [""] 129 | 130 | # Extract offsets of sentences from the text 131 | _, offsets = text_to_sentences_and_offsets(text) 132 | 133 | # Initialize a list to store sentences 134 | chunks = [] 135 | 136 | # Iterate through the list of offsets and extract sentences 137 | for start, end in offsets: 138 | # Extract the sentence and limit its length 139 | sentence = text[start:end][:MAX_CONTEXT_SENTENCE_LENGTH] 140 | chunks.append(sentence) 141 | 142 | return interaction_id, chunks 143 | 144 | def extract_chunks(self, batch_interaction_ids, batch_search_results): 145 | """ 146 | Extracts chunks from given batch search results using parallel processing with Ray. 147 | 148 | Parameters: 149 | batch_interaction_ids (List[str]): List of interaction IDs. 150 | batch_search_results (List[List[Dict]]): List of search results batches, each containing HTML text. 151 | 152 | Returns: 153 | Tuple[np.ndarray, np.ndarray]: A tuple containing an array of chunks and an array of corresponding interaction IDs. 
154 | """ 155 | # Setup parallel chunk extraction using ray remote 156 | ray_response_refs = [ 157 | self._extract_chunks.remote( 158 | self, 159 | interaction_id=batch_interaction_ids[idx], 160 | html_source=html_text["page_result"] 161 | ) 162 | for idx, search_results in enumerate(batch_search_results) 163 | for html_text in search_results 164 | ] 165 | 166 | # Wait until all sentence extractions are complete 167 | # and collect chunks for every interaction_id separately 168 | chunk_dictionary = defaultdict(list) 169 | 170 | for response_ref in ray_response_refs: 171 | interaction_id, _chunks = ray.get(response_ref) # Blocking call until parallel execution is complete 172 | chunk_dictionary[interaction_id].extend(_chunks) 173 | 174 | # Flatten chunks and keep a map of corresponding interaction_ids 175 | chunks, chunk_interaction_ids = self._flatten_chunks(chunk_dictionary) 176 | 177 | return chunks, chunk_interaction_ids 178 | 179 | def _flatten_chunks(self, chunk_dictionary): 180 | """ 181 | Flattens the chunk dictionary into separate lists for chunks and their corresponding interaction IDs. 182 | 183 | Parameters: 184 | chunk_dictionary (defaultdict): Dictionary with interaction IDs as keys and lists of chunks as values. 185 | 186 | Returns: 187 | Tuple[np.ndarray, np.ndarray]: A tuple containing an array of chunks and an array of corresponding interaction IDs. 
188 | """ 189 | chunks = [] 190 | chunk_interaction_ids = [] 191 | 192 | for interaction_id, _chunks in chunk_dictionary.items(): 193 | # De-duplicate chunks within the scope of an interaction ID 194 | unique_chunks = list(set(_chunks)) 195 | chunks.extend(unique_chunks) 196 | chunk_interaction_ids.extend([interaction_id] * len(unique_chunks)) 197 | 198 | # Convert to numpy arrays for convenient slicing/masking operations later 199 | chunks = np.array(chunks) 200 | chunk_interaction_ids = np.array(chunk_interaction_ids) 201 | 202 | return chunks, chunk_interaction_ids 203 | 204 | def extract_json_objects(text, decoder=JSONDecoder()): 205 | """Find JSON objects in text, and yield the decoded JSON data 206 | """ 207 | pos = 0 208 | results = [] 209 | while True: 210 | match = text.find("{", pos) 211 | if match == -1: 212 | break 213 | try: 214 | result, index = decoder.raw_decode(text[match:]) 215 | results.append(result) 216 | pos = match + index 217 | except ValueError: 218 | pos = match + 1 219 | return results 220 | 221 | class RAG_KG_Model: 222 | """ 223 | An example RAGModel 224 | """ 225 | def __init__(self): 226 | self.initialize_models() 227 | self.chunk_extractor = ChunkExtractor() 228 | 229 | def initialize_models(self): 230 | # Initialize Meta Llama 3 - 8B Instruct Model 231 | self.model_name = "models/meta-llama/Meta-Llama-3-8B-Instruct" 232 | 233 | if not os.path.exists(self.model_name): 234 | raise Exception( 235 | f""" 236 | The evaluators expect the model weights to be checked into the repository, 237 | but we could not find the model weights at {self.model_name} 238 | 239 | Please follow the instructions in the docs below to download and check in the model weights. 
240 | 241 | https://github.com/facebookresearch/CRAG/blob/main/docs/download_baseline_model_weights.md 242 | """ 243 | ) 244 | 245 | # Initialize the model with vllm 246 | self.llm = vllm.LLM( 247 | self.model_name, 248 | worker_use_ray=True, 249 | tensor_parallel_size=VLLM_TENSOR_PARALLEL_SIZE, 250 | gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION, 251 | trust_remote_code=True, 252 | dtype= "half", # note: bfloat16 is not supported on nvidia-T4 GPUs 253 | enforce_eager=True 254 | ) 255 | self.tokenizer = self.llm.get_tokenizer() 256 | 257 | # Load a sentence transformer model optimized for sentence embeddings, using CUDA if available. 258 | self.sentence_model = SentenceTransformer( 259 | "models/sentence-transformers/all-MiniLM-L6-v2", 260 | device=torch.device( 261 | "cuda" if torch.cuda.is_available() else "cpu" 262 | ), 263 | ) 264 | 265 | def calculate_embeddings(self, sentences): 266 | """ 267 | Compute normalized embeddings for a list of sentences using a sentence encoding model. 268 | 269 | This function leverages multiprocessing to encode the sentences, which can enhance the 270 | processing speed on multi-core machines. 271 | 272 | Args: 273 | sentences (List[str]): A list of sentences for which embeddings are to be computed. 274 | 275 | Returns: 276 | np.ndarray: An array of normalized embeddings for the given sentences. 277 | 278 | """ 279 | embeddings = self.sentence_model.encode( 280 | sentences=sentences, 281 | normalize_embeddings=True, 282 | batch_size=SENTENTENCE_TRANSFORMER_BATCH_SIZE, 283 | ) 284 | # Note: There is an opportunity to parallelize the embedding generation across 4 GPUs 285 | # but sentence_model.encode_multi_process seems to interefere with Ray 286 | # on the evaluation servers. 287 | # todo: this can also be done in a Ray native approach. 
288 | # 289 | return embeddings 290 | 291 | def get_batch_size(self) -> int: 292 | """ 293 | Determines the batch size that is used by the evaluator when calling the `batch_generate_answer` function. 294 | 295 | The evaluation timeouts linearly scale with the batch size. 296 | i.e.: time out for the `batch_generate_answer` call = batch_size * per_sample_timeout 297 | 298 | 299 | Returns: 300 | int: The batch size, an integer between 1 and 16. It can be dynamic 301 | across different batch_generate_answer calls, or stay a static value. 302 | """ 303 | self.batch_size = SUBMISSION_BATCH_SIZE 304 | return self.batch_size 305 | 306 | def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]: 307 | """ 308 | Generates answers for a batch of queries using associated (pre-cached) search results and query times. 309 | 310 | Parameters: 311 | batch (Dict[str, Any]): A dictionary containing a batch of input queries with the following keys: 312 | - 'interaction_id; (List[str]): List of interaction_ids for the associated queries 313 | - 'query' (List[str]): List of user queries. 314 | - 'search_results' (List[List[Dict]]): List of search result lists, each corresponding to a query. 315 | - 'query_time' (List[str]): List of timestamps (represented as a string), each corresponding to when a query was made. 316 | 317 | Returns: 318 | List[str]: A list of plain text responses for each query in the batch. Each response is limited to 75 tokens. 319 | If the generated response exceeds 75 tokens, it will be truncated to fit within this limit. 320 | 321 | Notes: 322 | - If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid 323 | the penalty for hallucination. 324 | - Response Time: Ensure that your model processes and responds to each query within 30 seconds. 325 | Failing to adhere to this time constraint **will** result in a timeout during evaluation. 
326 | """ 327 | batch_interaction_ids = batch["interaction_id"] 328 | queries = batch["query"] 329 | batch_search_results = batch["search_results"] 330 | query_times = batch["query_time"] 331 | 332 | # Chunk all search results using ChunkExtractor 333 | chunks, chunk_interaction_ids = self.chunk_extractor.extract_chunks( 334 | batch_interaction_ids, batch_search_results 335 | ) 336 | 337 | # Calculate all chunk embeddings 338 | chunk_embeddings = self.calculate_embeddings(chunks) 339 | 340 | # Calculate embeddings for queries 341 | query_embeddings = self.calculate_embeddings(queries) 342 | 343 | # Retrieve top matches for the whole batch 344 | batch_retrieval_results = [] 345 | for _idx, interaction_id in enumerate(batch_interaction_ids): 346 | query = queries[_idx] 347 | query_time = query_times[_idx] 348 | query_embedding = query_embeddings[_idx] 349 | 350 | # Identify chunks that belong to this interaction_id 351 | relevant_chunks_mask = chunk_interaction_ids == interaction_id 352 | 353 | # Filter out the said chunks and corresponding embeddings 354 | relevant_chunks = chunks[relevant_chunks_mask] 355 | relevant_chunks_embeddings = chunk_embeddings[relevant_chunks_mask] 356 | 357 | # Calculate cosine similarity between query and chunk embeddings, 358 | cosine_scores = (relevant_chunks_embeddings * query_embedding).sum(1) 359 | 360 | # and retrieve top-N results. 361 | retrieval_results = relevant_chunks[ 362 | (-cosine_scores).argsort()[:NUM_CONTEXT_SENTENCES] 363 | ] 364 | 365 | # You might also choose to skip the steps above and 366 | # use a vectorDB directly. 
367 | batch_retrieval_results.append(retrieval_results) 368 | 369 | # Retrieve knowledge graph results 370 | entities = self.extract_entity(batch) 371 | batch_kg_results = self.get_kg_results(entities) 372 | # Prepare formatted prompts from the LLM 373 | formatted_prompts = self.format_prompts(queries, query_times, batch_retrieval_results, batch_kg_results) 374 | # Generate responses via vllm 375 | responses = self.llm.generate( 376 | formatted_prompts, 377 | vllm.SamplingParams( 378 | n=1, # Number of output sequences to return for each prompt. 379 | top_p=0.9, # Float that controls the cumulative probability of the top tokens to consider. 380 | temperature=0.1, # Randomness of the sampling 381 | skip_special_tokens=True, # Whether to skip special tokens in the output. 382 | max_tokens=50, # Maximum number of tokens to generate per output sequence. 383 | 384 | # Note: We are using 50 max new tokens instead of 75, 385 | # because the 75 max token limit for the competition is checked using the Llama2 tokenizer. 386 | # Llama3 instead uses a different tokenizer with a larger vocabulary 387 | # This allows the Llama3 tokenizer to represent the same content more efficiently, 388 | # while using fewer tokens. 389 | ), 390 | use_tqdm=False # you might consider setting this to True during local development 391 | ) 392 | 393 | # Aggregate answers into List[str] 394 | answers = [] 395 | for response in responses: 396 | answers.append(response.outputs[0].text) 397 | 398 | return answers 399 | 400 | def format_prompts(self, queries, query_times, batch_retrieval_results=[], batch_kg_results=[]): 401 | """ 402 | Formats queries, corresponding query_times and retrieval results using the chat_template of the model. 403 | 404 | Parameters: 405 | - queries (List[str]): A list of queries to be formatted into prompts. 406 | - query_times (List[str]): A list of query_time strings corresponding to each query. 
407 | - batch_retrieval_results (List[str]) 408 | - batch_kg_results (List[str]) 409 | """ 410 | system_prompt = "You are provided with a question and various references. Your task is to answer the question succinctly, using the fewest words possible. If the references do not contain the necessary information to answer the question, respond with 'I don't know'. There is no need to explain the reasoning behind your answers." 411 | formatted_prompts = [] 412 | 413 | for _idx, query in enumerate(queries): 414 | query_time = query_times[_idx] 415 | retrieval_results = batch_retrieval_results[_idx] 416 | kg_results = batch_kg_results[_idx] 417 | 418 | user_message = "" 419 | retrieval_references = "" 420 | if len(retrieval_results) > 0: 421 | # Format the top sentences as references in the model's prompt template. 422 | for _snippet_idx, snippet in enumerate(retrieval_results): 423 | retrieval_references += f"- {snippet.strip()}\n" 424 | # Limit the length of references to fit the model's input size. 
425 | retrieval_references = retrieval_references[: int(MAX_CONTEXT_REFERENCES_LENGTH / 2)] 426 | kg_results = kg_results[: int(MAX_CONTEXT_REFERENCES_LENGTH / 2)] 427 | 428 | references = "### References\n" + \ 429 | "# Web\n" + \ 430 | retrieval_references + \ 431 | "# Knowledge Graph\n" + \ 432 | kg_results 433 | 434 | user_message += f"{references}\n------\n\n" 435 | user_message 436 | user_message += f"Using only the references listed above, answer the following question: \n" 437 | user_message += f"Current Time: {query_time}\n" 438 | user_message += f"Question: {query}\n" 439 | 440 | formatted_prompts.append( 441 | self.tokenizer.apply_chat_template( 442 | [ 443 | {"role": "system", "content": system_prompt}, 444 | {"role": "user", "content": user_message}, 445 | ], 446 | tokenize=False, 447 | add_generation_prompt=True, 448 | ) 449 | ) 450 | 451 | return formatted_prompts 452 | 453 | def extract_entity(self, batch): 454 | queries = batch["query"] 455 | query_times = batch["query_time"] 456 | formatted_prompts = self.format_prompts_for_entity_extraction(queries, query_times) 457 | responses = self.llm.generate( 458 | formatted_prompts, 459 | vllm.SamplingParams( 460 | n=1, # Number of output sequences to return for each prompt. 461 | top_p=0.9, # Float that controls the cumulative probability of the top tokens to consider. 462 | temperature=0.1, # Randomness of the sampling 463 | skip_special_tokens=True, # Whether to skip special tokens in the output. 464 | max_tokens=4096, # Maximum number of tokens to generate per output sequence. 
465 | ), 466 | use_tqdm=False # you might consider setting this to True during local development 467 | ) 468 | 469 | entities = [] 470 | for response in responses: 471 | res = response.outputs[0].text 472 | try: 473 | res = json.loads(res) 474 | except: 475 | res = extract_json_objects(res) 476 | entities.append(res) 477 | return entities 478 | 479 | def get_kg_results(self, entities): 480 | # examples for "open" (encyclopedia), "movie" or "other" domains 481 | api = CRAG(server=CRAG_MOCK_API_URL) 482 | batch_kg_results = [] 483 | for entity in entities: 484 | kg_results = [] 485 | res = "" 486 | if "domain" in entity.keys(): 487 | domain = entity["domain"] 488 | if domain in ["encyclopedia", "other"]: 489 | if "main_entity" in entity.keys(): 490 | try: 491 | top_entity_name = api.open_search_entity_by_name(entity["main_entity"])["result"][0] 492 | res = api.open_get_entity(top_entity_name)["result"] 493 | kg_results.append({top_entity_name: res}) 494 | except Exception as e: 495 | logger.warning(f"Error in open_get_entity: {e}") 496 | pass 497 | if domain == "movie": 498 | if "movie_name" in entity.keys() and entity["movie_name"] is not None: 499 | if isinstance(entity["movie_name"], str): 500 | movie_names = entity["movie_name"].split(",") 501 | else: 502 | movie_names = entity["movie_name"] 503 | for movie_name in movie_names: 504 | try: 505 | res = api.movie_get_movie_info(movie_name)["result"][0] 506 | res = res[entity["movie_aspect"]] 507 | kg_results.append({movie_name + "_" + entity["movie_aspect"]: res}) 508 | except Exception as e: 509 | logger.warning(f"Error in movie_get_movie_info: {e}") 510 | pass 511 | if "person" in entity.keys() and entity["person"] is not None: 512 | if isinstance(entity["person"], str): 513 | person_list = entity["person"].split(",") 514 | else: 515 | person_list = entity["person"] 516 | for person in person_list: 517 | try: 518 | res = api.movie_get_person_info(person)["result"][0] 519 | aspect = entity["person_aspect"] 520 | if 
aspect in ["oscar_awards", "birthday"]: 521 | res = res[aspect] 522 | kg_results.append({person + "_" + aspect: res}) 523 | if aspect in ["acted_movies", "directed_movies"]: 524 | movie_info = [] 525 | for movie_id in res[aspect]: 526 | movie_info.append(api.movie_get_movie_info_by_id(movie_id)) 527 | kg_results.append({person + "_" + aspect: movie_info}) 528 | except Exception as e: 529 | logger.warning(f"Error in movie_get_person_info: {e}") 530 | pass 531 | if "year" in entity.keys() and entity["year"] is not None: 532 | if isinstance(entity["year"], str) or isinstance(entity["year"], int): 533 | years = str(entity["year"]).split(",") 534 | else: 535 | years = entity["year"] 536 | for year in years: 537 | try: 538 | res = api.movie_get_year_info(year)["result"] 539 | all_movies = [] 540 | oscar_movies = [] 541 | for movie_id in res["movie_list"]: 542 | all_movies.append(api.movie_get_movie_info_by_id(movie_id)["result"]["title"]) 543 | for movie_id in res["oscar_awards"]: 544 | oscar_movies.append(api.movie_get_movie_info_by_id(movie_id)["result"]["title"]) 545 | kg_results.append({year + "_all_movies": all_movies}) 546 | kg_results.append({year + "_oscar_movies": oscar_movies}) 547 | except Exception as e: 548 | logger.warning(f"Error in movie_get_year_info: {e}") 549 | pass 550 | batch_kg_results.append("\n".join([str(res) for res in kg_results]) if len(kg_results) > 0 else "") 551 | return batch_kg_results 552 | 553 | def format_prompts_for_entity_extraction(self, queries, query_times): 554 | formatted_prompts = [] 555 | for _idx, query in enumerate(queries): 556 | query_time = query_times[_idx] 557 | user_message = "" 558 | user_message += f"Query: {query}\n" 559 | user_message += f"Query Time: {query_time}\n" 560 | 561 | formatted_prompts.append( 562 | self.tokenizer.apply_chat_template( 563 | [ 564 | {"role": "system", "content": Entity_Extract_TEMPLATE}, 565 | {"role": "user", "content": user_message}, 566 | ], 567 | tokenize=False, 568 | 
add_generation_prompt=True, 569 | ) 570 | ) 571 | return formatted_prompts 572 | 573 | -------------------------------------------------------------------------------- /models/rag_llama_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from collections import defaultdict 9 | from typing import Any, Dict, List 10 | 11 | import numpy as np 12 | import ray 13 | import torch 14 | import vllm 15 | from blingfire import text_to_sentences_and_offsets 16 | from bs4 import BeautifulSoup 17 | from sentence_transformers import SentenceTransformer 18 | 19 | ###################################################################################################### 20 | ###################################################################################################### 21 | ### 22 | ### Please pay special attention to the comments that start with "TUNE THIS VARIABLE" 23 | ### as they depend on your model and the available GPU resources. 24 | ### 25 | ### DISCLAIMER: This baseline has NOT been tuned for performance 26 | ### or efficiency, and is provided as is for demonstration. 27 | ###################################################################################################### 28 | 29 | 30 | # Load the environment variable that specifies the URL of the MockAPI. This URL is essential 31 | # for accessing the correct API endpoint in Task 2 and Task 3. The value of this environment variable 32 | # may vary across different evaluation settings, emphasizing the importance of dynamically obtaining 33 | # the API URL to ensure accurate endpoint communication. 
34 | 35 | CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000") 36 | 37 | 38 | #### CONFIG PARAMETERS --- 39 | 40 | # Define the number of context sentences to consider for generating an answer. 41 | NUM_CONTEXT_SENTENCES = 20 42 | # Set the maximum length for each context sentence (in characters). 43 | MAX_CONTEXT_SENTENCE_LENGTH = 1000 44 | # Set the maximum context references length (in characters). 45 | MAX_CONTEXT_REFERENCES_LENGTH = 4000 46 | 47 | # Batch size you wish the evaluators will use to call the `batch_generate_answer` function 48 | SUBMISSION_BATCH_SIZE = 8 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 49 | 50 | # VLLM Parameters 51 | VLLM_TENSOR_PARALLEL_SIZE = 4 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 52 | VLLM_GPU_MEMORY_UTILIZATION = 0.85 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 53 | 54 | # Sentence Transformer Parameters 55 | SENTENTENCE_TRANSFORMER_BATCH_SIZE = 128 # TUNE THIS VARIABLE depending on the size of your embedding model and GPU mem available 56 | 57 | #### CONFIG PARAMETERS END--- 58 | 59 | class ChunkExtractor: 60 | 61 | @ray.remote 62 | def _extract_chunks(self, interaction_id, html_source): 63 | """ 64 | Extracts and returns chunks from given HTML source. 65 | 66 | Note: This function is for demonstration purposes only. 67 | We are treating an independent sentence as a chunk here, 68 | but you could choose to chunk your text more cleverly than this. 69 | 70 | Parameters: 71 | interaction_id (str): Interaction ID that this HTML source belongs to. 72 | html_source (str): HTML content from which to extract text. 73 | 74 | Returns: 75 | Tuple[str, List[str]]: A tuple containing the interaction ID and a list of sentences extracted from the HTML content. 
76 | """ 77 | # Parse the HTML content using BeautifulSoup 78 | soup = BeautifulSoup(html_source, "lxml") 79 | text = soup.get_text(" ", strip=True) # Use space as a separator, strip whitespaces 80 | 81 | if not text: 82 | # Return a list with empty string when no text is extracted 83 | return interaction_id, [""] 84 | 85 | # Extract offsets of sentences from the text 86 | _, offsets = text_to_sentences_and_offsets(text) 87 | 88 | # Initialize a list to store sentences 89 | chunks = [] 90 | 91 | # Iterate through the list of offsets and extract sentences 92 | for start, end in offsets: 93 | # Extract the sentence and limit its length 94 | sentence = text[start:end][:MAX_CONTEXT_SENTENCE_LENGTH] 95 | chunks.append(sentence) 96 | 97 | return interaction_id, chunks 98 | 99 | def extract_chunks(self, batch_interaction_ids, batch_search_results): 100 | """ 101 | Extracts chunks from given batch search results using parallel processing with Ray. 102 | 103 | Parameters: 104 | batch_interaction_ids (List[str]): List of interaction IDs. 105 | batch_search_results (List[List[Dict]]): List of search results batches, each containing HTML text. 106 | 107 | Returns: 108 | Tuple[np.ndarray, np.ndarray]: A tuple containing an array of chunks and an array of corresponding interaction IDs. 
109 | """ 110 | # Setup parallel chunk extraction using ray remote 111 | ray_response_refs = [ 112 | self._extract_chunks.remote( 113 | self, 114 | interaction_id=batch_interaction_ids[idx], 115 | html_source=html_text["page_result"] 116 | ) 117 | for idx, search_results in enumerate(batch_search_results) 118 | for html_text in search_results 119 | ] 120 | 121 | # Wait until all sentence extractions are complete 122 | # and collect chunks for every interaction_id separately 123 | chunk_dictionary = defaultdict(list) 124 | 125 | for response_ref in ray_response_refs: 126 | interaction_id, _chunks = ray.get(response_ref) # Blocking call until parallel execution is complete 127 | chunk_dictionary[interaction_id].extend(_chunks) 128 | 129 | # Flatten chunks and keep a map of corresponding interaction_ids 130 | chunks, chunk_interaction_ids = self._flatten_chunks(chunk_dictionary) 131 | 132 | return chunks, chunk_interaction_ids 133 | 134 | def _flatten_chunks(self, chunk_dictionary): 135 | """ 136 | Flattens the chunk dictionary into separate lists for chunks and their corresponding interaction IDs. 137 | 138 | Parameters: 139 | chunk_dictionary (defaultdict): Dictionary with interaction IDs as keys and lists of chunks as values. 140 | 141 | Returns: 142 | Tuple[np.ndarray, np.ndarray]: A tuple containing an array of chunks and an array of corresponding interaction IDs. 
143 | """ 144 | chunks = [] 145 | chunk_interaction_ids = [] 146 | 147 | for interaction_id, _chunks in chunk_dictionary.items(): 148 | # De-duplicate chunks within the scope of an interaction ID 149 | unique_chunks = list(set(_chunks)) 150 | chunks.extend(unique_chunks) 151 | chunk_interaction_ids.extend([interaction_id] * len(unique_chunks)) 152 | 153 | # Convert to numpy arrays for convenient slicing/masking operations later 154 | chunks = np.array(chunks) 155 | chunk_interaction_ids = np.array(chunk_interaction_ids) 156 | 157 | return chunks, chunk_interaction_ids 158 | 159 | class RAGModel: 160 | """ 161 | An example RAGModel 162 | """ 163 | def __init__(self): 164 | self.initialize_models() 165 | self.chunk_extractor = ChunkExtractor() 166 | 167 | def initialize_models(self): 168 | # Initialize Meta Llama 3 - 8B Instruct Model 169 | self.model_name = "models/meta-llama/Meta-Llama-3-8B-Instruct" 170 | 171 | if not os.path.exists(self.model_name): 172 | raise Exception( 173 | f""" 174 | The evaluators expect the model weights to be checked into the repository, 175 | but we could not find the model weights at {self.model_name} 176 | """ 177 | ) 178 | 179 | # Initialize the model with vllm 180 | self.llm = vllm.LLM( 181 | self.model_name, 182 | worker_use_ray=True, 183 | tensor_parallel_size=VLLM_TENSOR_PARALLEL_SIZE, 184 | gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION, 185 | trust_remote_code=True, 186 | dtype="half", # note: update the dtype based on the available GPU 187 | enforce_eager=True 188 | ) 189 | self.tokenizer = self.llm.get_tokenizer() 190 | 191 | # Load a sentence transformer model optimized for sentence embeddings, using CUDA if available. 
192 | self.sentence_model = SentenceTransformer( 193 | "models/sentence-transformers/all-MiniLM-L6-v2", 194 | device=torch.device( 195 | "cuda" if torch.cuda.is_available() else "cpu" 196 | ), 197 | ) 198 | 199 | def calculate_embeddings(self, sentences): 200 | """ 201 | Compute normalized embeddings for a list of sentences using a sentence encoding model. 202 | 203 | This function leverages multiprocessing to encode the sentences, which can enhance the 204 | processing speed on multi-core machines. 205 | 206 | Args: 207 | sentences (List[str]): A list of sentences for which embeddings are to be computed. 208 | 209 | Returns: 210 | np.ndarray: An array of normalized embeddings for the given sentences. 211 | 212 | """ 213 | embeddings = self.sentence_model.encode( 214 | sentences=sentences, 215 | normalize_embeddings=True, 216 | batch_size=SENTENTENCE_TRANSFORMER_BATCH_SIZE, 217 | ) 218 | # Note: There is an opportunity to parallelize the embedding generation across 4 GPUs 219 | # but sentence_model.encode_multi_process seems to interefere with Ray 220 | # on the evaluation servers. 221 | # todo: this can also be done in a Ray native approach. 222 | # 223 | return embeddings 224 | 225 | def get_batch_size(self) -> int: 226 | """ 227 | Determines the batch size that is used by the evaluator when calling the `batch_generate_answer` function. 228 | 229 | The evaluation timeouts linearly scale with the batch size. 230 | i.e.: time out for the `batch_generate_answer` call = batch_size * per_sample_timeout 231 | 232 | 233 | Returns: 234 | int: The batch size, an integer between 1 and 16. It can be dynamic 235 | across different batch_generate_answer calls, or stay a static value. 236 | """ 237 | self.batch_size = SUBMISSION_BATCH_SIZE 238 | return self.batch_size 239 | 240 | def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]: 241 | """ 242 | Generates answers for a batch of queries using associated (pre-cached) search results and query times. 
243 | 244 | Parameters: 245 | batch (Dict[str, Any]): A dictionary containing a batch of input queries with the following keys: 246 | - 'interaction_id; (List[str]): List of interaction_ids for the associated queries 247 | - 'query' (List[str]): List of user queries. 248 | - 'search_results' (List[List[Dict]]): List of search result lists, each corresponding 249 | to a query. 250 | - 'query_time' (List[str]): List of timestamps (represented as a string), each corresponding to when a query was made. 251 | 252 | Returns: 253 | List[str]: A list of plain text responses for each query in the batch. Each response is limited to 75 tokens. 254 | If the generated response exceeds 75 tokens, it will be truncated to fit within this limit. 255 | 256 | Notes: 257 | - If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid 258 | the penalty for hallucination. 259 | - Response Time: Ensure that your model processes and responds to each query within 30 seconds. 260 | Failing to adhere to this time constraint **will** result in a timeout during evaluation. 
261 | """ 262 | batch_interaction_ids = batch["interaction_id"] 263 | queries = batch["query"] 264 | batch_search_results = batch["search_results"] 265 | query_times = batch["query_time"] 266 | 267 | # Chunk all search results using ChunkExtractor 268 | chunks, chunk_interaction_ids = self.chunk_extractor.extract_chunks( 269 | batch_interaction_ids, batch_search_results 270 | ) 271 | 272 | # Calculate all chunk embeddings 273 | chunk_embeddings = self.calculate_embeddings(chunks) 274 | 275 | # Calculate embeddings for queries 276 | query_embeddings = self.calculate_embeddings(queries) 277 | 278 | # Retrieve top matches for the whole batch 279 | batch_retrieval_results = [] 280 | for _idx, interaction_id in enumerate(batch_interaction_ids): 281 | query = queries[_idx] 282 | query_time = query_times[_idx] 283 | query_embedding = query_embeddings[_idx] 284 | 285 | # Identify chunks that belong to this interaction_id 286 | relevant_chunks_mask = chunk_interaction_ids == interaction_id 287 | 288 | # Filter out the said chunks and corresponding embeddings 289 | relevant_chunks = chunks[relevant_chunks_mask] 290 | relevant_chunks_embeddings = chunk_embeddings[relevant_chunks_mask] 291 | 292 | # Calculate cosine similarity between query and chunk embeddings, 293 | cosine_scores = (relevant_chunks_embeddings * query_embedding).sum(1) 294 | 295 | # and retrieve top-N results. 296 | retrieval_results = relevant_chunks[ 297 | (-cosine_scores).argsort()[:NUM_CONTEXT_SENTENCES] 298 | ] 299 | 300 | # You might also choose to skip the steps above and 301 | # use a vectorDB directly. 302 | batch_retrieval_results.append(retrieval_results) 303 | 304 | # Prepare formatted prompts from the LLM 305 | formatted_prompts = self.format_prompts(queries, query_times, batch_retrieval_results) 306 | 307 | # Generate responses via vllm 308 | responses = self.llm.generate( 309 | formatted_prompts, 310 | vllm.SamplingParams( 311 | n=1, # Number of output sequences to return for each prompt. 
312 | top_p=0.9, # Float that controls the cumulative probability of the top tokens to consider. 313 | temperature=0.1, # Randomness of the sampling 314 | skip_special_tokens=True, # Whether to skip special tokens in the output. 315 | max_tokens=50, # Maximum number of tokens to generate per output sequence. 316 | 317 | # Note: We are using 50 max new tokens instead of 75, 318 | # because the 75 max token limit for the competition is checked using the Llama2 tokenizer. 319 | # Llama3 instead uses a different tokenizer with a larger vocabulary 320 | # This allows the Llama3 tokenizer to represent the same content more efficiently, 321 | # while using fewer tokens. 322 | ), 323 | use_tqdm=False # you might consider setting this to True during local development 324 | ) 325 | 326 | # Aggregate answers into List[str] 327 | answers = [] 328 | for response in responses: 329 | answers.append(response.outputs[0].text) 330 | 331 | return answers 332 | 333 | def format_prompts(self, queries, query_times, batch_retrieval_results=[]): 334 | """ 335 | Formats queries, corresponding query_times and retrieval results using the chat_template of the model. 336 | 337 | Parameters: 338 | - queries (List[str]): A list of queries to be formatted into prompts. 339 | - query_times (List[str]): A list of query_time strings corresponding to each query. 340 | - batch_retrieval_results (List[str]) 341 | """ 342 | system_prompt = "You are provided with a question and various references. Your task is to answer the question succinctly, using the fewest words possible. If the references do not contain the necessary information to answer the question, respond with 'I don't know'. There is no need to explain the reasoning behind your answers." 
343 | formatted_prompts = [] 344 | 345 | for _idx, query in enumerate(queries): 346 | query_time = query_times[_idx] 347 | retrieval_results = batch_retrieval_results[_idx] 348 | 349 | user_message = "" 350 | references = "" 351 | 352 | if len(retrieval_results) > 0: 353 | references += "# References \n" 354 | # Format the top sentences as references in the model's prompt template. 355 | for _snippet_idx, snippet in enumerate(retrieval_results): 356 | references += f"- {snippet.strip()}\n" 357 | 358 | references = references[:MAX_CONTEXT_REFERENCES_LENGTH] 359 | # Limit the length of references to fit the model's input size. 360 | 361 | user_message += f"{references}\n------\n\n" 362 | user_message 363 | user_message += f"Using only the references listed above, answer the following question: \n" 364 | user_message += f"Current Time: {query_time}\n" 365 | user_message += f"Question: {query}\n" 366 | 367 | formatted_prompts.append( 368 | self.tokenizer.apply_chat_template( 369 | [ 370 | {"role": "system", "content": system_prompt}, 371 | {"role": "user", "content": user_message}, 372 | ], 373 | tokenize=False, 374 | add_generation_prompt=True, 375 | ) 376 | ) 377 | 378 | return formatted_prompts 379 | -------------------------------------------------------------------------------- /models/user_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
def trim_predictions_to_max_token_length(prediction):
    """Trim a prediction string so it spans at most 75 tokens.

    The competition's answer-length limit is measured with the (module-level)
    Llama tokenizer; re-decoding the kept token ids yields the truncated text.
    """
    max_token_length = 75
    token_ids = tokenizer.encode(prediction)
    # Index 0 is the BOS token added by the tokenizer; keep the next
    # max_token_length content tokens.
    kept_ids = token_ids[1 : max_token_length + 1]
    return tokenizer.decode(kept_ids)
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | from typing import Any, Dict, List 9 | 10 | import numpy as np 11 | import torch 12 | import vllm 13 | from models.utils import trim_predictions_to_max_token_length 14 | 15 | ###################################################################################################### 16 | ###################################################################################################### 17 | ### 18 | ### Please pay special attention to the comments that start with "TUNE THIS VARIABLE" 19 | ### as they depend on your model and the available GPU resources. 20 | ### 21 | ### DISCLAIMER: This baseline has NOT been tuned for performance 22 | ### or efficiency, and is provided as is for demonstration. 23 | ###################################################################################################### 24 | 25 | 26 | # Load the environment variable that specifies the URL of the MockAPI. This URL is essential 27 | # for accessing the correct API endpoint in Task 2 and Task 3. The value of this environment variable 28 | # may vary across different evaluation settings, emphasizing the importance of dynamically obtaining 29 | # the API URL to ensure accurate endpoint communication. 30 | 31 | CRAG_MOCK_API_URL = os.getenv("CRAG_MOCK_API_URL", "http://localhost:8000") 32 | 33 | 34 | #### CONFIG PARAMETERS --- 35 | 36 | # Batch size you wish the evaluators will use to call the `batch_generate_answer` function 37 | BATCH_SIZE = 8 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 38 | 39 | # VLLM Parameters 40 | VLLM_TENSOR_PARALLEL_SIZE = 4 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 41 | VLLM_GPU_MEMORY_UTILIZATION = 0.85 # TUNE THIS VARIABLE depending on the number of GPUs you are requesting and the size of your model. 
class InstructModel:
    """Vanilla (no-retrieval) Llama 3 baseline: prompts the model with only
    the query and its timestamp, ignoring the cached search results."""

    def __init__(self):
        """
        Initialize your model(s) here if necessary.
        This is the constructor for your model class, where you can set up any
        required initialization steps for your model(s) to function correctly.
        """
        self.initialize_models()

    def initialize_models(self):
        """Load the Llama 3 8B Instruct weights through vLLM."""
        # Initialize Meta Llama 3 - 8B Instruct Model
        self.model_name = "models/meta-llama/Meta-Llama-3-8B-Instruct"

        if not os.path.exists(self.model_name):
            raise Exception(
                f"""
            The evaluators expect the model weights to be checked into the repository,
            but we could not find the model weights at {self.model_name}
            """
            )

        # initialize the model with vllm
        self.llm = vllm.LLM(
            self.model_name,
            worker_use_ray=True,
            tensor_parallel_size=VLLM_TENSOR_PARALLEL_SIZE,
            gpu_memory_utilization=VLLM_GPU_MEMORY_UTILIZATION,
            trust_remote_code=True,
            dtype="half",  # note: update the dtype based on the available GPU
            enforce_eager=True,
        )
        self.tokenizer = self.llm.get_tokenizer()

    def get_batch_size(self) -> int:
        """
        Determines the batch size that is used by the evaluator when calling the `batch_generate_answer` function.

        Returns:
            int: The batch size, an integer between 1 and 16. This value indicates how many
                 queries should be processed together in a single batch. It can be dynamic
                 across different batch_generate_answer calls, or stay a static value.
        """
        self.batch_size = BATCH_SIZE
        return self.batch_size

    def batch_generate_answer(self, batch: Dict[str, Any]) -> List[str]:
        """
        Generates answers for a batch of queries using associated (pre-cached) search results and query times.

        Parameters:
            batch (Dict[str, Any]): A dictionary containing a batch of input queries with the following keys:
                - 'interaction_id' (List[str]): List of interaction_ids for the associated queries
                - 'query' (List[str]): List of user queries.
                - 'search_results' (List[List[Dict]]): List of search result lists, each corresponding
                                                      to a query.
                - 'query_time' (List[str]): List of timestamps (represented as a string), each corresponding to when a query was made.

        Returns:
            List[str]: A list of plain text responses for each query in the batch. Each response is limited to 75 tokens.
            If the generated response exceeds 75 tokens, it will be truncated to fit within this limit.

        Notes:
        - If the correct answer is uncertain, it's preferable to respond with "I don't know" to avoid
          the penalty for hallucination.
        - Response Time: Ensure that your model processes and responds to each query within 30 seconds.
          Failing to adhere to this time constraint **will** result in a timeout during evaluation.
        """
        # This baseline deliberately ignores batch["interaction_id"] and
        # batch["search_results"]; it answers from the model's own knowledge.
        queries = batch["query"]
        query_times = batch["query_time"]

        formatted_prompts = self.format_prompts(queries, query_times)

        # Fix: removed a leftover debug `breakpoint()` call here, which would
        # stall every batch during evaluation.
        # Generate responses via vllm
        responses = self.llm.generate(
            formatted_prompts,
            vllm.SamplingParams(
                n=1,  # Number of output sequences to return for each prompt.
                top_p=0.9,  # Float that controls the cumulative probability of the top tokens to consider.
                temperature=0.1,  # randomness of the sampling
                skip_special_tokens=True,  # Whether to skip special tokens in the output.
                max_tokens=50,  # Maximum number of tokens to generate per output sequence.
                # Note: We are using 50 max new tokens instead of 75,
                # because the 75 max token limit is checked using the Llama2 tokenizer.
                # The Llama3 model instead uses a different tokenizer with a larger vocabulary
                # This allows it to represent the same content more efficiently, using fewer tokens.
            ),
            use_tqdm=False,
        )

        # Aggregate answers into List[str]
        answers = [response.outputs[0].text for response in responses]
        return answers

    def format_prompts(self, queries, query_times):
        """
        Formats queries and corresponding query_times using the chat_template of the model.

        Parameters:
        - queries (list of str): A list of queries to be formatted into prompts.
        - query_times (list of str): A list of query_time strings corresponding to each query.

        Returns:
            list of str: One chat-templated prompt string per query.
        """
        system_prompt = "You are provided with a question and various references. Your task is to answer the question succinctly, using the fewest words possible. If the references do not contain the necessary information to answer the question, respond with 'I don't know'."
        formatted_prompts = []

        for _idx, query in enumerate(queries):
            query_time = query_times[_idx]
            user_message = ""
            user_message += f"Current Time: {query_time}\n"
            user_message += f"Question: {query}\n"

            formatted_prompts.append(
                self.tokenizer.apply_chat_template(
                    [
                        {"role": "system", "content": system_prompt},
                        {"role": "user", "content": user_message},
                    ],
                    tokenize=False,
                    add_generation_prompt=True,
                )
            )

        return formatted_prompts

    # Backward-compatible alias for the original misspelled method name, so
    # any external caller using the old name keeps working.
    format_prommpts = format_prompts
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | INSTRUCTIONS = """Assume you are a human expert in grading predictions given by a model. You are given a question and a model prediction. Judge if the prediction matches the ground truth answer by following these steps: 8 | 1: Take it as granted that the Ground Truth is always correct. 9 | 2: If the Prediction indicates it is not sure about the answer, "score" should be "0"; otherwise, go the next step. 10 | 3: If the Prediction exactly matches the Ground Truth, "score" is 1. 11 | 4: If the Prediction does not exactly match the Ground Truth, go through the following steps and likely give a score as 0. 12 | 5: If the Ground Truth is a number, "score" is 1 if and only if the Prediction gives a number that almost exactly matches the ground truth. 13 | 6: If the Prediction is self-contradictory, "score" must be 0. 14 | 7: If the prediction is not answering the question, "score" must be 0. 15 | 8: If the prediction is a concise and correct summary of the ground truth, "score" is 1. 16 | 9: If ground truth contains a set of items, prediction must contain exactly same items for the score to be 1. 17 | 10: Otherwise, "score" is 0. 18 | 19 | ### Output a JSON blob with an "explanation" field explaining your answer as short as possible and an "score" field with value 1 or 0.""" 20 | 21 | IN_CONTEXT_EXAMPLES = """You should make the judgment based on provided examples. 22 | Examples: 23 | Question: "which company has higher eps, btu or cma?" 24 | Ground Truth: "cma" 25 | Prediction: "it is not possible to determine which company has a higher eps." 26 | Output: {"score": 0, "explanation": "The prediction is not sure about the answer."} 27 | 28 | Question: "who is taller, a or b?" 29 | Ground Truth: "a" 30 | Prediction: "The answer is a. a is of height 1.75 m and b is of height 1.82 m. So b is taller." 
31 | Output: {"score": 0, "explanation": "The prediction is self-contradictory."} 32 | 33 | Question: "who among patricia place and lana clarkson is younger?" 34 | Ground Truth: "lana clarkson" 35 | Prediction: "lana clarkson (born april 5, 1962) is younger than patricia clarkson (born december 29, 1959)." 36 | Output: {"score": 1, "explanation": "The prediction gives the correct answer."} 37 | 38 | Question: "did lakers win more games than warriors in 2024-4?" 39 | Ground Truth: "yes" 40 | Prediction: "yes, lakers won more games than warriors in 2023-2024 season." 41 | Output: {"score": 0, "explanation": "The prediction is not answering the question."} 42 | 43 | Question: "what is the height of the building where citigroup is headquartered?" 44 | Ground Truth: "151 m" 45 | Prediction: "the height is 915 feet (279 m)." 46 | Output: {"score": 0, "explanation": "The prediction, 151 m, does not match the ground truth, 279 m."} 47 | 48 | Question: "who were the last 3 prime ministers of uk?" 49 | Ground Truth: "rishi sunak, liz truss, and boris johnson" 50 | Prediction: "1. liz truss, 2. boris johnson, 3.theresa may" 51 | Output: {"score": 0, "explanation": "The prediction does not contain item, rishi sunak, that is in the grount truth."} 52 | 53 | Question: "who authored the taming of the shrew?" 54 | Ground Truth: "william shakespeare" 55 | Prediction: "william shakespeare" 56 | Output: {"score": 1, "explanation": "The prediction exactly matches the ground truth."} 57 | 58 | Question: "who authored The Taming of the Shrew?" 59 | Ground Truth: "william shakespeare" 60 | Prediction: "w shakespeare" 61 | Output: {"score": 1, "explanation": "The prediction matches the ground truth as w is the abbreviation of william."} 62 | 63 | Question: "what is the state bird of california?" 
64 | Ground Truth: "california quail" 65 | Prediction: "california valley quail" 66 | Output: {"score": 1, "explanation": "The prediction matches the ground truth as california quail is also known as california valley quail."} 67 | 68 | Question: "what country is maharashtra metro rail corporation limited located in?" 69 | Ground truth: "india" 70 | Prediction: "indianapolis" 71 | Output: {"score": 0, "explanation": "The prediction doesn't match the ground truth answers."} 72 | 73 | Question: "what was wolves's record in their last game in eng-premier league?" 74 | Ground Truth: "w" 75 | Prediction: "wolves's record in their last game in the eng-premier League was a 2-1 win over fulham." 76 | Output: {"score": 1, "explanation": "The prediction correctly answers the result is a win."} 77 | 78 | Question: "how many seconds is 3 minutes 15 seconds?" 79 | Ground Truth: "195 seconds" 80 | Prediction: "3 minutes 15 seconds is 195 seconds." 81 | Output: {"score": 1, "explanation": "The prediction matches ground truth as it correctly states the number of seconds asked in the question."} 82 | 83 | Question: "on which days did am distribute dividends in the last year?" 84 | Ground Truth: "2023-01-24, 2023-04-25, 2023-07-25, 2023-10-24" 85 | Prediction: "the company distributed dividends on october 24, 2023 and january 24, 2023." 86 | Output: {"score": 0, "explanation": "The prediction partially matches the ground truth, but it is incomplete."} 87 | 88 | Question: "on which days did xxx distribute dividends in the last year?" 89 | Ground Truth: "2023-01-13, 2023-03-25, 2023-11-21" 90 | Prediction: "xxx distributed dividends on 1. 2023-01-13, 2. 2023-03-25, 3. 2023-10-21." 91 | Output: {"score": 0, "explanation": "The prediction contains an item, 2023-10-21, which is not in ground truth."} 92 | 93 | Question: "who are current band members of maroon 5?" 
94 | Ground Truth: "adam Levine, jesse carmichael, james valentine, matt flynn, pj morton, sam farrar" 95 | Prediction: "they are jesse carmichael, matt flynn, adam levine, pj morton, sam farrar, and james valentine." 96 | Output: {"score": 1, "explanation": "The prediction exactly matches the ground truth."} 97 | 98 | Question: "which movies comprise the matrix franchise?" 99 | Ground Truth: "the matrix, the matrix reloaded, the matrix revolutions, the matrix resurrections" 100 | Prediction: "the matrix, the matrix reloaded, the matrix revolutions, the animatrix, and the matrix resurrections." 101 | Output: {"score": 0, "explanation": "The prediction covers more items than what are given by the ground truth."} 102 | 103 | Question: "how deep is the deepest lake of new york?" 104 | Ground Truth: "618 ft" 105 | Prediction: "the deepest lake in new york is seneca lake, with a depth of 618.23 feet." 106 | Output: {"score": 1, "explanation": "The prediction exactly matches the number in ground truth after rounding."} 107 | 108 | Question: "what is the closing price of meta yesterday?" 109 | Ground Truth: "$310.17" 110 | Prediction: "310.2" 111 | Output: {"score": 1, "explanation": "The prediction exactly matches the number in ground truth after rounding."} 112 | 113 | Question: "what is the current market cap of appl?" 114 | Ground Truth: "2.81 trillion" 115 | Prediction: "2.667 trillion" 116 | Output: {"score": 0, "explanation": "The prediction does not match the number in ground truth."} 117 | 118 | Question: "what is the current pe ratio of appl?" 119 | Ground Truth: "28.3" 120 | Prediction: "the current pe ratio of apple is 26.66" 121 | Output: {"score": 0, "explanation": "The prediction does not match the number in ground truth."} 122 | 123 | Question: "how much is tesla's stock price down from its all-time high?" 
124 | Ground Truth: "$221.83" 125 | Prediction: "209.52" 126 | Output: {"score": 0, "explanation": "The prediction does not match the number in ground truth."} 127 | 128 | Question: "what is the length of amazon river?" 129 | Ground Truth: "over 4000 miles" 130 | Prediction: "the length of amazon river is 4,000 miles" 131 | Output: {"score": 0, "explanation": "The prediction does not say Amazon River is longer than 4000 miles."} 132 | 133 | Question: "how many copies x were sold?" 134 | Ground Truth: "2 million." 135 | Prediction: "it is over 2 million" 136 | Output: {"score": 0, "explanation": "The prediction does not match the ground truth."} 137 | 138 | Question: "what is the population of country x?" 139 | Ground Truth: "3,576,873" 140 | Prediction: "the population of country x is 3.3 million." 141 | Output: {"score": 0, "explanation": "The prediction, 3.3 M, does not match the number, 3.6 M, in ground truth."} 142 | 143 | Question: "what is the current market value of stock x?" 144 | Ground Truth: "$2,237,578,268" 145 | Prediction: "$2.16 billion." 146 | Output: {"score": 0, "explanation": "The prediction, 2.16 B, does not match the number, 2.2 B, in ground truth."} 147 | 148 | Question: "how old is michael jordan?" 149 | Ground Truth: "61" 150 | Prediction: "mj was born on February 17, 1963." 151 | Output: {"score": 0, "explanation": "The prediction does not explicitly say the age."} 152 | 153 | Question: "how long did amy teach as a teacher?" 154 | Ground Truth: "24 years and 255 days" 155 | Prediction: "25 years." 156 | Output: {"score": 0, "explanation": "The prediction does not exactly match the ground truth."} 157 | 158 | Question: "at what age did usher perform in the half time show of superbowl?" 159 | Ground Truth: "45 years old" 160 | Prediction: "usher has not performed at a super bowl halftime show." 
161 | Output: {"score": 0, "explanation": "The prediction does not match the ground truth."} 162 | 163 | Question: "what year did olivia rodrigo win the grammys best new artist award?" 164 | Ground Truth: "2022" 165 | Prediction: "sure! here's an answer in three words or fewer: 2022 grammy awards." 166 | Output: {"score": 1, "explanation": "The prediction answers the correct number of the year."} 167 | 168 | Question: "how many times have zendaya and tom holland starred in a film together?" 169 | Ground Truth: "3" 170 | Prediction: "they have starred in 3 films together: spider-man: homecoming, spider-man: far from home, and spider-man: no way home." 171 | Output: {"score": 1, "explanation": "The prediction gives the right number and it is reasonable."} 172 | 173 | Question: "who is younger, timothee chalamet or tom holland?" 174 | Ground Truth: "tom holland" 175 | Prediction: "timothée chalamet is younger than tom holland." 176 | Output: {"score": 0, "explanation": "The prediction does not match the ground truth."} 177 | 178 | Question: "who had more number one hits on the us billboard, a or b?" 179 | Ground Truth: "a had more number one hits on the us billboard than b, with 20 number one hits compared to b's 15." 180 | Prediction: "a" 181 | Output: {"score": 1, "explanation": "The prediction is a concise and correct summary of the ground truth."} 182 | 183 | Question: "what is xxx's birthdate?" 184 | Ground Truth: "1996-01-01." 185 | Prediction: "02/01/1996" 186 | Output: {"score": 0, "explanation": "The prediction does not match the ground truth."} 187 | 188 | Question: "what was the worldwide box office haul for movie x?" 189 | Ground Truth: "101756123." 190 | Prediction: "102 million" 191 | Output: {"score": 1, "explanation": "The prediction exactly matches the number in ground truth after rounding."} 192 | 193 | Question: "how much has spotify's user base increased by since 2020 in na?" 
194 | Ground Truth: "spotify's user base increased by 34 million since 2020." 195 | Prediction: "spotify's north american user base increased from 36 million in 2020 to 85 million by 2021" 196 | Output: {"score": 0, "explanation": "The prediction is not answering the question as it only gives the increase from 2020 to 2021."} 197 | """ 198 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | beautifulsoup4 3 | bitsandbytes 4 | blingfire 5 | hf-transfer 6 | huggingface-hub 7 | loguru 8 | lxml 9 | openai==1.13.3 10 | sentence_transformers 11 | torch 12 | transformers 13 | vllm>=0.4.2 -------------------------------------------------------------------------------- /tokenizer/README.md: -------------------------------------------------------------------------------- 1 | # hf-internal-testing/llama-tokenizer 2 | 3 | This tokenizer has been obtained from: https://huggingface.co/hf-internal-testing/llama-tokenizer -------------------------------------------------------------------------------- /tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "unk_token": { 17 | "content": "", 18 | "lstrip": false, 19 | "normalized": true, 20 | "rstrip": false, 21 | "single_word": false 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /tokenizer/tokenizer.model: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/facebookresearch/CRAG/c71ad61ea4f18ab0cc4b5b009932bc76e21be394/tokenizer/tokenizer.model -------------------------------------------------------------------------------- /tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "bos_token": { 5 | "__type": "AddedToken", 6 | "content": "", 7 | "lstrip": false, 8 | "normalized": true, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "clean_up_tokenization_spaces": false, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "", 16 | "lstrip": false, 17 | "normalized": true, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "model_max_length": 2048, 22 | "pad_token": null, 23 | "sp_model_kwargs": {}, 24 | "tokenizer_class": "LlamaTokenizer", 25 | "unk_token": { 26 | "__type": "AddedToken", 27 | "content": "", 28 | "lstrip": false, 29 | "normalized": true, 30 | "rstrip": false, 31 | "single_word": false 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /utils/cragapi_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import os 9 | from typing import List 10 | 11 | import requests 12 | 13 | """ 14 | Borrowed from: https://github.com/facebookresearch/CRAG/blob/main/mock_api/apiwrapper/pycragapi.py 15 | """ 16 | 17 | 18 | class CRAG(object): 19 | """ 20 | A client for interacting with the CRAG server, offering methods to query various domains such as Open, Movie, Finance, Music, and Sports. Each method corresponds to an API endpoint on the CRAG server. 
21 | 22 | Attributes: 23 | server (str): The base URL of the CRAG server. Defaults to "http://127.0.0.1:8000". 24 | 25 | Methods: 26 | open_search_entity_by_name(query: str) -> dict: Search for entities by name in the Open domain. 27 | open_get_entity(entity: str) -> dict: Retrieve detailed information about an entity in the Open domain. 28 | movie_get_person_info(person_name: str) -> dict: Get information about a person related to movies. 29 | movie_get_movie_info(movie_name: str) -> dict: Get information about a movie. 30 | movie_get_year_info(year: str) -> dict: Get information about movies released in a specific year. 31 | movie_get_movie_info_by_id(movie_id: int) -> dict: Get movie information by its unique ID. 32 | movie_get_person_info_by_id(person_id: int) -> dict: Get person information by their unique ID. 33 | finance_get_company_name(query: str) -> dict: Search for company names in the finance domain. 34 | finance_get_ticker_by_name(query: str) -> dict: Retrieve the ticker symbol for a given company name. 35 | finance_get_price_history(ticker_name: str) -> dict: Get the price history for a given ticker symbol. 36 | finance_get_detailed_price_history(ticker_name: str) -> dict: Get detailed price history for a ticker symbol. 37 | finance_get_dividends_history(ticker_name: str) -> dict: Get dividend history for a ticker symbol. 38 | finance_get_market_capitalization(ticker_name: str) -> dict: Retrieve market capitalization for a ticker symbol. 39 | finance_get_eps(ticker_name: str) -> dict: Get earnings per share (EPS) for a ticker symbol. 40 | finance_get_pe_ratio(ticker_name: str) -> dict: Get the price-to-earnings (PE) ratio for a ticker symbol. 41 | finance_get_info(ticker_name: str) -> dict: Get financial information for a ticker symbol. 42 | music_search_artist_entity_by_name(artist_name: str) -> dict: Search for music artists by name. 43 | music_search_song_entity_by_name(song_name: str) -> dict: Search for songs by name. 
44 | music_get_billboard_rank_date(rank: int, date: str = None) -> dict: Get Billboard ranking for a specific rank and date. 45 | music_get_billboard_attributes(date: str, attribute: str, song_name: str) -> dict: Get attributes of a song from Billboard rankings. 46 | music_grammy_get_best_artist_by_year(year: int) -> dict: Get the Grammy Best New Artist for a specific year. 47 | music_grammy_get_award_count_by_artist(artist_name: str) -> dict: Get the total Grammy awards won by an artist. 48 | music_grammy_get_award_count_by_song(song_name: str) -> dict: Get the total Grammy awards won by a song. 49 | music_grammy_get_best_song_by_year(year: int) -> dict: Get the Grammy Song of the Year for a specific year. 50 | music_grammy_get_award_date_by_artist(artist_name: str) -> dict: Get the years an artist won a Grammy award. 51 | music_grammy_get_best_album_by_year(year: int) -> dict: Get the Grammy Album of the Year for a specific year. 52 | music_grammy_get_all_awarded_artists() -> dict: Get all artists awarded the Grammy Best New Artist. 53 | music_get_artist_birth_place(artist_name: str) -> dict: Get the birthplace of an artist. 54 | music_get_artist_birth_date(artist_name: str) -> dict: Get the birth date of an artist. 55 | music_get_members(band_name: str) -> dict: Get the member list of a band. 56 | music_get_lifespan(artist_name: str) -> dict: Get the lifespan of an artist. 57 | music_get_song_author(song_name: str) -> dict: Get the author of a song. 58 | music_get_song_release_country(song_name: str) -> dict: Get the release country of a song. 59 | music_get_song_release_date(song_name: str) -> dict: Get the release date of a song. 60 | music_get_artist_all_works(artist_name: str) -> dict: Get all works by an artist. 61 | sports_soccer_get_games_on_date(team_name: str, date: str) -> dict: Get soccer games on a specific date. 62 | sports_nba_get_games_on_date(team_name: str, date: str) -> dict: Get NBA games on a specific date. 
63 | sports_nba_get_play_by_play_data_by_game_ids(game_ids: List[str]) -> dict: Get NBA play by play data for a set of game ids. 64 | 65 | Note: 66 | Each method performs a POST request to the corresponding API endpoint and returns the response as a JSON dictionary. 67 | """ 68 | def __init__(self, server = None): 69 | if server is None: 70 | self.server = os.environ.get('CRAG_SERVER', "http://127.0.0.1:8000") 71 | else: 72 | self.server = server 73 | 74 | def open_search_entity_by_name(self, query:str): 75 | url = self.server + '/open/search_entity_by_name' 76 | headers={'accept': "application/json"} 77 | data = {'query': query} 78 | result = requests.post(url, json=data, headers=headers) 79 | return json.loads(result.text) 80 | 81 | def open_get_entity(self, entity:str): 82 | url = self.server + '/open/get_entity' 83 | headers={'accept': "application/json"} 84 | data = {'query': entity} 85 | result = requests.post(url, json=data, headers=headers) 86 | return json.loads(result.text) 87 | 88 | def movie_get_person_info(self, person_name:str): 89 | url = self.server + '/movie/get_person_info' 90 | headers={'accept': "application/json"} 91 | data = {'query': person_name} 92 | result = requests.post(url, json=data, headers=headers) 93 | return json.loads(result.text) 94 | 95 | def movie_get_movie_info(self, movie_name:str): 96 | url = self.server + '/movie/get_movie_info' 97 | headers={'accept': "application/json"} 98 | data = {'query': movie_name} 99 | result = requests.post(url, json=data, headers=headers) 100 | return json.loads(result.text) 101 | 102 | def movie_get_year_info(self, year:str): 103 | url = self.server + '/movie/get_year_info' 104 | headers={'accept': "application/json"} 105 | data = {'query': year} 106 | result = requests.post(url, json=data, headers=headers) 107 | return json.loads(result.text) 108 | 109 | def movie_get_movie_info_by_id(self, movid_id:int): 110 | url = self.server + '/movie/get_movie_info_by_id' 111 | headers={'accept': 
"application/json"} 112 | data = {'query': movid_id} 113 | result = requests.post(url, json=data, headers=headers) 114 | return json.loads(result.text) 115 | 116 | def movie_get_person_info_by_id(self, person_id:int): 117 | url = self.server + '/movie/get_person_info_by_id' 118 | headers={'accept': "application/json"} 119 | data = {'query': person_id} 120 | result = requests.post(url, json=data, headers=headers) 121 | return json.loads(result.text) 122 | 123 | def finance_get_company_name(self, query:str): 124 | url = self.server + '/finance/get_company_name' 125 | headers={'accept': "application/json"} 126 | data = {'query': query} 127 | result = requests.post(url, json=data, headers=headers) 128 | return json.loads(result.text) 129 | 130 | def finance_get_ticker_by_name(self, query:str): 131 | url = self.server + '/finance/get_ticker_by_name' 132 | headers={'accept': "application/json"} 133 | data = {'query': query} 134 | result = requests.post(url, json=data, headers=headers) 135 | return json.loads(result.text) 136 | 137 | def finance_get_price_history(self, ticker_name:str): 138 | url = self.server + '/finance/get_price_history' 139 | headers={'accept': "application/json"} 140 | data = {'query': ticker_name} 141 | result = requests.post(url, json=data, headers=headers) 142 | return json.loads(result.text) 143 | 144 | def finance_get_detailed_price_history(self, ticker_name:str): 145 | url = self.server + '/finance/get_detailed_price_history' 146 | headers={'accept': "application/json"} 147 | data = {'query': ticker_name} 148 | result = requests.post(url, json=data, headers=headers) 149 | return json.loads(result.text) 150 | 151 | def finance_get_dividends_history(self, ticker_name:str): 152 | url = self.server + '/finance/get_dividends_history' 153 | headers={'accept': "application/json"} 154 | data = {'query': ticker_name} 155 | result = requests.post(url, json=data, headers=headers) 156 | return json.loads(result.text) 157 | 158 | def 
finance_get_market_capitalization(self, ticker_name:str): 159 | url = self.server + '/finance/get_market_capitalization' 160 | headers={'accept': "application/json"} 161 | data = {'query': ticker_name} 162 | result = requests.post(url, json=data, headers=headers) 163 | return json.loads(result.text) 164 | 165 | def finance_get_eps(self, ticker_name:str): 166 | url = self.server + '/finance/get_eps' 167 | headers={'accept': "application/json"} 168 | data = {'query': ticker_name} 169 | result = requests.post(url, json=data, headers=headers) 170 | return json.loads(result.text) 171 | 172 | def finance_get_pe_ratio(self, ticker_name:str): 173 | url = self.server + '/finance/get_pe_ratio' 174 | headers={'accept': "application/json"} 175 | data = {'query': ticker_name} 176 | result = requests.post(url, json=data, headers=headers) 177 | return json.loads(result.text) 178 | 179 | def finance_get_info(self, ticker_name:str): 180 | url = self.server + '/finance/get_info' 181 | headers={'accept': "application/json"} 182 | data = {'query': ticker_name} 183 | result = requests.post(url, json=data, headers=headers) 184 | return json.loads(result.text) 185 | 186 | def music_search_artist_entity_by_name(self, artist_name:str): 187 | url = self.server + '/music/search_artist_entity_by_name' 188 | headers={'accept': "application/json"} 189 | data = {'query': artist_name} 190 | result = requests.post(url, json=data, headers=headers) 191 | return json.loads(result.text) 192 | 193 | def music_search_song_entity_by_name(self, song_name:str): 194 | url = self.server + '/music/search_song_entity_by_name' 195 | headers={'accept': "application/json"} 196 | data = {'query': song_name} 197 | result = requests.post(url, json=data, headers=headers) 198 | return json.loads(result.text) 199 | 200 | def music_get_billboard_rank_date(self, rank:int, date:str=None): 201 | url = self.server + '/music/get_billboard_rank_date' 202 | headers={'accept': "application/json"} 203 | data = {'rank': rank, 
'date': date} 204 | result = requests.post(url, json=data, headers=headers) 205 | return json.loads(result.text) 206 | 207 | def music_get_billboard_attributes(self, date:str, attribute:str, song_name:str): 208 | url = self.server + '/music/get_billboard_attributes' 209 | headers={'accept': "application/json"} 210 | data = {'date': date, 'attribute': attribute, 'song_name': song_name} 211 | result = requests.post(url, json=data, headers=headers) 212 | return json.loads(result.text) 213 | 214 | def music_grammy_get_best_artist_by_year(self, year:int): 215 | url = self.server + '/music/grammy_get_best_artist_by_year' 216 | headers={'accept': "application/json"} 217 | data = {'query': year} 218 | result = requests.post(url, json=data, headers=headers) 219 | return json.loads(result.text) 220 | 221 | def music_grammy_get_award_count_by_artist(self, artist_name:str): 222 | url = self.server + '/music/grammy_get_award_count_by_artist' 223 | headers={'accept': "application/json"} 224 | data = {'query': artist_name} 225 | result = requests.post(url, json=data, headers=headers) 226 | return json.loads(result.text) 227 | 228 | def music_grammy_get_award_count_by_song(self, song_name:str): 229 | url = self.server + '/music/grammy_get_award_count_by_song' 230 | headers={'accept': "application/json"} 231 | data = {'query': song_name} 232 | result = requests.post(url, json=data, headers=headers) 233 | return json.loads(result.text) 234 | 235 | def music_grammy_get_best_song_by_year(self, year:int): 236 | url = self.server + '/music/grammy_get_best_song_by_year' 237 | headers={'accept': "application/json"} 238 | data = {'query': year} 239 | result = requests.post(url, json=data, headers=headers) 240 | return json.loads(result.text) 241 | 242 | def music_grammy_get_award_date_by_artist(self, artist_name:str): 243 | url = self.server + '/music/grammy_get_award_date_by_artist' 244 | headers={'accept': "application/json"} 245 | data = {'query': artist_name} 246 | result = 
requests.post(url, json=data, headers=headers) 247 | return json.loads(result.text) 248 | 249 | def music_grammy_get_best_album_by_year(self, year:int): 250 | url = self.server + '/music/grammy_get_best_album_by_year' 251 | headers={'accept': "application/json"} 252 | data = {'query': year} 253 | result = requests.post(url, json=data, headers=headers) 254 | return json.loads(result.text) 255 | 256 | def music_grammy_get_all_awarded_artists(self): 257 | url = self.server + '/music/grammy_get_all_awarded_artists' 258 | headers={'accept': "application/json"} 259 | result = requests.post(url, headers=headers) 260 | return json.loads(result.text) 261 | 262 | def music_get_artist_birth_place(self, artist_name:str): 263 | url = self.server + '/music/get_artist_birth_place' 264 | headers={'accept': "application/json"} 265 | data = {'query': artist_name} 266 | result = requests.post(url, json=data, headers=headers) 267 | return json.loads(result.text) 268 | 269 | def music_get_artist_birth_date(self, artist_name:str): 270 | url = self.server + '/music/get_artist_birth_date' 271 | headers={'accept': "application/json"} 272 | data = {'query': artist_name} 273 | result = requests.post(url, json=data, headers=headers) 274 | return json.loads(result.text) 275 | 276 | def music_get_members(self, band_name:str): 277 | url = self.server + '/music/get_members' 278 | headers={'accept': "application/json"} 279 | data = {'query': band_name} 280 | result = requests.post(url, json=data, headers=headers) 281 | return json.loads(result.text) 282 | 283 | def music_get_lifespan(self, artist_name:str): 284 | url = self.server + '/music/get_lifespan' 285 | headers={'accept': "application/json"} 286 | data = {'query': artist_name} 287 | result = requests.post(url, json=data, headers=headers) 288 | return json.loads(result.text) 289 | 290 | def music_get_song_author(self, song_name:str): 291 | url = self.server + '/music/get_song_author' 292 | headers={'accept': "application/json"} 293 | data = 
{'query': song_name} 294 | result = requests.post(url, json=data, headers=headers) 295 | return json.loads(result.text) 296 | 297 | def music_get_song_release_country(self, song_name:str): 298 | url = self.server + '/music/get_song_release_country' 299 | headers={'accept': "application/json"} 300 | data = {'query': song_name} 301 | result = requests.post(url, json=data, headers=headers) 302 | return json.loads(result.text) 303 | 304 | def music_get_song_release_date(self, song_name:str): 305 | url = self.server + '/music/get_song_release_date' 306 | headers={'accept': "application/json"} 307 | data = {'query': song_name} 308 | result = requests.post(url, json=data, headers=headers) 309 | return json.loads(result.text) 310 | 311 | def music_get_artist_all_works(self, song_name:str): 312 | url = self.server + '/music/get_artist_all_works' 313 | headers={'accept': "application/json"} 314 | data = {'query': song_name} 315 | result = requests.post(url, json=data, headers=headers) 316 | return json.loads(result.text) 317 | 318 | def sports_soccer_get_games_on_date(self, date:str, team_name:str=None): 319 | url = self.server + '/sports/soccer/get_games_on_date' 320 | headers={'accept': "application/json"} 321 | data = {'team_name': team_name, 'date': date} 322 | result = requests.post(url, json=data, headers=headers) 323 | return json.loads(result.text) 324 | 325 | def sports_nba_get_games_on_date(self, date:str, team_name:str=None): 326 | url = self.server + '/sports/nba/get_games_on_date' 327 | headers={'accept': "application/json"} 328 | data = {'team_name': team_name, 'date': date} 329 | result = requests.post(url, json=data, headers=headers) 330 | return json.loads(result.text) 331 | 332 | def sports_nba_get_play_by_play_data_by_game_ids(self, game_ids:List[str]): 333 | url = self.server + '/sports/nba/get_play_by_play_data_by_game_ids' 334 | headers={'accept': "application/json"} 335 | data = {'game_ids': game_ids} 336 | result = requests.post(url, json=data, 
headers=headers) 337 | return json.loads(result.text) 338 | 339 | 340 | --------------------------------------------------------------------------------