├── .gitignore ├── LICENSE ├── README.md ├── baseline ├── configs │ └── fsdp.json ├── dataset.py ├── predict_and_evaluate.py ├── prepare_dataset.py ├── readme.md └── train.py ├── dataset ├── response.jsonl └── source_info.jsonl └── image └── fig1.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | baseline/*.jsonl 2 | baseline/exp/ 3 | baseline/wandb/ 4 | version.txt 5 | .DS_Store 6 | eda.ipynb 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/#use-with-ide 116 | .pdm.toml 117 | 118 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 119 | __pypackages__/ 120 | 121 | # Celery stuff 122 | celerybeat-schedule 123 | celerybeat.pid 124 | 125 | # SageMath parsed files 126 | *.sage.py 127 | 128 | # Environments 129 | .env 130 | .venv 131 | env/ 132 | venv/ 133 | ENV/ 134 | env.bak/ 135 | venv.bak/ 136 | 137 | # Spyder project settings 138 | .spyderproject 139 | .spyproject 140 | 141 | # Rope project settings 142 | .ropeproject 143 | 144 | # mkdocs documentation 145 | /site 146 | 147 | # mypy 148 | .mypy_cache/ 149 | .dmypy.json 150 | dmypy.json 151 | 152 | # Pyre type checker 153 | .pyre/ 154 | 155 | # pytype static type analyzer 156 | .pytype/ 157 | 158 | # Cython debug symbols 159 | cython_debug/ 160 | 161 | # PyCharm 162 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 163 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 164 | # and can be added to the global gitignore or merged into this file. For a more nuclear 165 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
166 | #.idea/ 167 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Particle Media 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RAGTruth 2 | ![](image/fig1.jpg) 3 | 4 | RAGTruth is a word-level hallucination corpus in various tasks within the Retrieval-augmented generation (RAG) setting both for **training** and **evaluating**. 5 | 6 | RAG has become a main technique for alleviating hallucinations in large language models (LLMs). Despite the integration of RAG, LLMs may still present unsupported or contradictory claims to the retrieved contents. 
In order to develop effective hallucination prevention strategies under RAG, it is important to create benchmark datasets that can measure the extent of hallucination. RAGTruth comprises nearly 18,000 naturally generated responses from diverse LLMs using RAG. These responses have undergone meticulous manual annotations at both the individual cases and word levels, incorporating evaluations of hallucination intensity. 7 | 8 | ## Updates 9 | 1. [2024/06] We released our training and evaluation code. Model weight can be found [here](https://github.com/CodingLL/RAGTruth_Eval/tree/master) 10 | 2. [2024/02] We updated the data: we included more annotated hallucinations and added one new meta, `implicit_true`. 11 | 3. [2024/01] We released the RAGTruth corpus. 12 | ## Dataset 13 | Given that each piece of source information elicits six distinct responses from various Language Models, we publish the source information and its corresponding responses individually. 14 | 15 | 1. **response.jsonl** 16 | 17 | | Field Name | Field Value | Description | 18 | | ----------- | ----------- | ------------------------------------------- | 19 | | id | String | The index of the response. | 20 | | source_id | String | The index of the source information. | 21 | | model | String | The model that generates the response. | 22 | | temperature | Float | The temperature when generating response. | 23 | | labels |List[Dict] | The hallucination spans. `text` is the hallucination itself; `start` & `end` are the position of hallucination span in the response; `label_type` is the type of the hallucination span. `due_to_null` means whether the hallucination is caused by null value. `implicit_true` means this span is correct while the info is not mentioned in the context. `meta` is the comment of the hallucination given by annotators. | 24 | | split | String | `train` or `test` | 25 | | quality | String | Indicate whether this answer has quality issues. 
`incorrect_refusal` means the model mistakenly refuses to answer despite the presence of the relevant context. `truncated` means the response is truncated unexpectedly. | 26 | | response | String | The response of LLMs given a certain instruction. | 27 | 28 | Sample: 29 | ```json 30 | { 31 | "id": "1472", 32 | "source_id": "11316", 33 | "model": "mistral-7B-instruct", 34 | "temperature": 0.925, 35 | "labels": [ 36 | { 37 | "start": 219, 38 | "end": 229, 39 | "text": "Gaza Strip", 40 | "meta": "HIGH INTRO OF NEW INFO\nIt is not mentioned in the original source that Gaza Strip is included in the occupied Palestinian territory.", 41 | "label_type": "Evident Baseless Info" 42 | } 43 | ], 44 | "split": "train", 45 | "quality": "good", 46 | "response": "The Palestinian Authority has officially become the 123rd member of the International Criminal Court (ICC), giving the court jurisdiction over alleged crimes in Palestinian territories. This includes East Jerusalem and Gaza Strip, which are occupied by Israel. The signing of Rome Statute by Palestinians in January 2021 had already established ICC's jurisdiction over alleged crimes committed \"since June 13, 2014\" in these areas. Now, the court can open a preliminary investigation or formal investigation into the situation in Palestinian territories, potentially leading to war crimes probes against Israeli individuals. However, this could also lead to counter-charges against Palestinians. The ICC welcomed Palestine's accession, while Israel and the US, who are not ICC members, opposed the move." 47 | } 48 | ``` 49 | 50 | 2. **source_info.jsonl** 51 | 52 | | Field Name | Field Value | Description | 53 | | ----------- | ----------- | ------------------------------------------- | 54 | | source_id | String | The index of the source information. | 55 | | task_type | String | The task type of the data, including QA, Data2txt and Summary | 56 | | source | String | The source of the original content. 
| 57 | | source_info | String or Dict | Base content under RAG setting. For summarization tasks, the value of this field is of the string type; for data-to-text and QA (question-answering) tasks, the value of this field is of the dict type. | 58 | | prompt | String | The prompt we used to generate responses. For Llama and Mistral models, we used `[INST] {prompt} [/INST]`. | 59 | 60 | QA sample: 61 | ```json 62 | { 63 | "source_id": "14312", 64 | "task_type": "QA", 65 | "source": "MARCO", 66 | "source_info": { 67 | "question": "how to prepare beets and beet greens", 68 | "passages": "passage 1:Procedures: 1 Preheat oven to 350 degrees Fahrenheit. 2 Wash beets thoroughly, leaving skins on. 3 Place beets in a small baking dish or roasting pan, toss with 2 tablespoons of coconut oil, cover and bake for 45 to 60 minutes or until tender. For the greens: heat remaining coconut oil in a skillet over medium-low heat.\n\npassage 2:Serve with red wine vinegar or butter and salt and pepper. For the greens: heat remaining coconut oil in a skillet over medium-low heat. Add garlic and onion and cook for one minute. Tear the beet greens into 2 to 3 inch pieces, and add to skillet, stirring until wilted and tender. Season with salt and pepper.\n\npassage 3:Directions See How It's Made. 1 Wash the greens thoroughly several times in deep water. Cook in very little boiling salted water until just tender, a few minutes. 2 Submit a Correction.\n\n" 69 | }, 70 | "prompt": "Briefly answer the following question:\nhow to prepare beets and beet greens\nBear in mind that your response should be strictly based on the following ten passages:\npassage 1:Procedures: 1 Preheat oven to 350 degrees Fahrenheit. 2 Wash beets thoroughly, leaving skins on. 3 Place beets in a small baking dish or roasting pan, toss with 2 tablespoons of coconut oil, cover and bake for 45 to 60 minutes or until tender. 
For the greens: heat remaining coconut oil in a skillet over medium-low heat.\n\npassage 2:Serve with red wine vinegar or butter and salt and pepper. For the greens: heat remaining coconut oil in a skillet over medium-low heat. Add garlic and onion and cook for one minute. Tear the beet greens into 2 to 3 inch pieces, and add to skillet, stirring until wilted and tender. Season with salt and pepper.\n\npassage 3:Directions See How It's Made. 1 Wash the greens thoroughly several times in deep water. Cook in very little boiling salted water until just tender, a few minutes. 2 Submit a Correction.\n\nIn case the passages do not contain the necessary information to answer the question, please reply with: \"Unable to answer based on given passages.\"\noutput:" 71 | } 72 | ``` 73 | 74 | Data2txt sample: 75 | ```json 76 | { 77 | "source_id": "13661", 78 | "task_type": "Data2txt", 79 | "source": "Yelp", 80 | "source_info": { 81 | "name": "Subway", 82 | "address": "1940 Cliff Dr, Ste B-13", 83 | "city": "Santa Barbara", 84 | "state": "CA", 85 | "categories": "Restaurants, Sandwiches, Salad, Fast Food", 86 | "hours": { 87 | "Monday": "9:0-22:30", 88 | "Tuesday": "9:0-22:30", 89 | "Wednesday": "9:0-22:30", 90 | "Thursday": "9:0-22:30", 91 | "Friday": "9:0-22:30", 92 | "Saturday": "9:0-22:30", 93 | "Sunday": "11:0-22:0" 94 | }, 95 | "attributes": { 96 | "BusinessParking": { 97 | "garage": false, 98 | "street": false, 99 | "validated": false, 100 | "lot": true, 101 | "valet": false 102 | }, 103 | "RestaurantsReservations": false, 104 | "OutdoorSeating": true, 105 | "WiFi": "no", 106 | "RestaurantsTakeOut": true, 107 | "RestaurantsGoodForGroups": true, 108 | "Music": null, 109 | "Ambience": { 110 | "touristy": false, 111 | "hipster": false, 112 | "romantic": null, 113 | "divey": null, 114 | "intimate": null, 115 | "trendy": null, 116 | "upscale": null, 117 | "classy": null, 118 | "casual": null 119 | } 120 | }, 121 | "business_stars": 3.0, 122 | "review_info": [ 123 | { 124 | 
"review_stars": 1.0, 125 | "review_date": "2020-05-11 02:07:36", 126 | "review_text": "My husband and I came in earlier today for lunch after I ordered my sandwich my husband ordered a club and didn't think anything about it while the girl made it because he assumed she knew what went on a club. Once we got in the car I looked at the receipt and realized she made him a turkey sandwich so we went back in to ask her to add the other meat and to refund us and recharge the correct price since a club is a little more. She was very rude about it and told us she wasn't going to do anything about and proceeded to call us liars and say he asked for a turkey sub. I told her she didn't have to be so rude so she told me to \"get the f**k out b***h\" and if I had a problem with it I could \"call her f***ing manager\". Also she proceeded to cuss us as we walked out of the store. It was quite unacceptable and inappropriate of an employee to be this unprofessional and aggressive." 127 | }, 128 | { 129 | "review_stars": 3.0, 130 | "review_date": "2020-03-02 20:05:55", 131 | "review_text": "Small store, personnel not very well organized, store is only moderately clean. \nStaff is friendly sometimes, other times they will only barely recognize you." 132 | }, 133 | { 134 | "review_stars": 5.0, 135 | "review_date": "2019-07-10 01:49:07", 136 | "review_text": "Nice and clean location. Toppings look fresh and well stocked. Joaquin and Odalis are always helpful and friendly." 137 | } 138 | ] 139 | }, 140 | "prompt": "Instruction:\nWrite an objective overview about the following local business based only on the provided structured data in the JSON format. You should include details and cover the information mentioned in the customers' review. The overview should be 100 - 200 words. Don't make up information. 
Structured data:\n{'name': 'Subway', 'address': '1940 Cliff Dr, Ste B-13', 'city': 'Santa Barbara', 'state': 'CA', 'categories': 'Restaurants, Sandwiches, Salad, Fast Food', 'hours': {'Monday': '9:0-22:30', 'Tuesday': '9:0-22:30', 'Wednesday': '9:0-22:30', 'Thursday': '9:0-22:30', 'Friday': '9:0-22:30', 'Saturday': '9:0-22:30', 'Sunday': '11:0-22:0'}, 'attributes': {'BusinessParking': {'garage': False, 'street': False, 'validated': False, 'lot': True, 'valet': False}, 'RestaurantsReservations': False, 'OutdoorSeating': True, 'WiFi': 'no', 'RestaurantsTakeOut': True, 'RestaurantsGoodForGroups': True, 'Music': None, 'Ambience': {'touristy': False, 'hipster': False, 'romantic': None, 'divey': None, 'intimate': None, 'trendy': None, 'upscale': None, 'classy': None, 'casual': None}}, 'business_stars': 3.0, 'review_info': [{'review_stars': 1.0, 'review_date': '2020-05-11 02:07:36', 'review_text': 'My husband and I came in earlier today for lunch after I ordered my sandwich my husband ordered a club and didn\\'t think anything about it while the girl made it because he assumed she knew what went on a club. Once we got in the car I looked at the receipt and realized she made him a turkey sandwich so we went back in to ask her to add the other meat and to refund us and recharge the correct price since a club is a little more. She was very rude about it and told us she wasn\\'t going to do anything about and proceeded to call us liars and say he asked for a turkey sub. I told her she didn\\'t have to be so rude so she told me to \"get the f**k out b***h\" and if I had a problem with it I could \"call her f***ing manager\". Also she proceeded to cuss us as we walked out of the store. It was quite unacceptable and inappropriate of an employee to be this unprofessional and aggressive.'}, {'review_stars': 3.0, 'review_date': '2020-03-02 20:05:55', 'review_text': 'Small store, personnel not very well organized, store is only moderately clean. 
\\nStaff is friendly sometimes, other times they will only barely recognize you.'}, {'review_stars': 5.0, 'review_date': '2019-07-10 01:49:07', 'review_text': 'Nice and clean location. Toppings look fresh and well stocked. Joaquin and Odalis are always helpful and friendly.'}]}\nOverview:" 141 | } 142 | ``` 143 | 144 | Summary sample: 145 | ```json 146 | { 147 | "source_id": "11316", 148 | "task_type": "Summary", 149 | "source": "CNN/DM", 150 | "source_info": "The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed \"in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014.\" Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a move toward greater justice. \"As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice,\" he said, according to an ICC news release. \"Indeed, today brings us closer to our shared goals of justice and peace.\" Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. 
\"As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly,\" she said. Rights group Human Rights Watch welcomed the development. \"Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court's treaty should speak out to welcome its membership,\" said Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts to undermine international justice, not Palestine's decision to join a treaty to which over 100 countries around the world are members.\" In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it \"strongly\" disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC,\" the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. \"We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,\" it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as \"Palestine.\" While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would \"conduct its analysis in full independence and impartiality.\" The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. 
The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN's Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.\n", 151 | "prompt": "Summarize the following news within 141 words:\nThe Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed \"in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014.\" Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday's ceremony, said it was a move toward greater justice. \"As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice,\" he said, according to an ICC news release. \"Indeed, today brings us closer to our shared goals of justice and peace.\" Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. \"As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly,\" she said. 
Rights group Human Rights Watch welcomed the development. \"Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court's treaty should speak out to welcome its membership,\" said Balkees Jarrah, international justice counsel for the group. \"What's objectionable is the attempts to undermine international justice, not Palestine's decision to join a treaty to which over 100 countries around the world are members.\" In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it \"strongly\" disagreed with the court's decision. \"As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC,\" the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. \"We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace,\" it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as \"Palestine.\" While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would \"conduct its analysis in full independence and impartiality.\" The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. 
CNN's Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.\n\noutput:" 152 | } 153 | ``` 154 | 155 | ## Data Statistics 156 | #### Descriptive Statistics Divided by Task 157 | 158 | Task | Instances | Responses | Hallucination Responses | Hallucination Spans | 159 | --- | --- | --- | --- | --- | 160 | Summarization(CNN/DM) | 628 | 3768 | 1165 | 1474 | 161 | Summarization(Recent News) | 315 | 1890 | 521 | 598 | 162 | Question Answering | 989 | 5934 | 1724 | 2927 | 163 | Data-to-text | 1033 | 6198 | 4254 | 9290 | 164 | Overall | 2965 | 17790 | 7664 | 14289 | 165 | 166 | 167 | #### Descriptive Statistics Divided by LLMs 168 | 169 | Model | Hallucination Responses | Hallucination Spans | 170 | --- | --- | --- | 171 | GPT-3.5-turbo-0613 | 401 | 533 | 172 | GPT-4-0613 | 406 | 485 | 173 | Llama-2-7B-chat | 1832 | 3302 | 174 | Llama-2-13B-chat | 1677 | 3799 | 175 | Llama-2-70B-chat | 1395 | 2608 | 176 | Mistral-7B-Instruct | 1953 | 3562 | 177 | 178 | ## Citation 179 | 180 | Please cite our paper if you use our dataset: 181 | ```bibtex 182 | @misc{wu2023ragtruth, 183 | title={RAGTruth: A Hallucination Corpus for Developing Trustworthy Retrieval-Augmented Language Models}, 184 | author={Yuanhao Wu and Juno Zhu and Siliang Xu and Kashun Shum and Cheng Niu and Randy Zhong and Juntong Song and Tong Zhang}, 185 | year={2023}, 186 | eprint={2401.00396}, 187 | archivePrefix={arXiv}, 188 | primaryClass={cs.CL} 189 | } 190 | ``` 191 | 192 | ## Star History 193 | 194 | [![Star History Chart](https://api.star-history.com/svg?repos=ParticleMedia/RAGTruth&type=Date)](https://star-history.com/#ParticleMedia/RAGTruth&Date) 195 | -------------------------------------------------------------------------------- /baseline/configs/fsdp.json: -------------------------------------------------------------------------------- 1 | { 2 | "use_orig_params": "False", 3 | "activation_checkpointing": "True" 4 | }
-------------------------------------------------------------------------------- /baseline/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 3 | 4 | # For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html 5 | 6 | # references: 7 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L382 8 | # https://github.com/facebookresearch/llama-recipes/blob/03faba661f079ee1ecaeb66deaa6bdec920a7bab/inference/chat_utils.py#L19 9 | 10 | import json 11 | import random 12 | from datetime import datetime 13 | import re 14 | from torch.utils.data import Dataset 15 | 16 | 17 | TEMPLATES = { 18 | "QA": ( 19 | "Below is a question:\n" 20 | "{question}\n\n" 21 | "Below are related passages:\n" 22 | "{reference}\n\n" 23 | "Below is an answer:\n" 24 | "{response}\n\n" 25 | "Your task is to determine whether the summary contains either or both of the following two types of hallucinations:\n" 26 | "1. conflict: instances where the summary presents direct contraction or opposition to the original news;\n" 27 | "2. baseless info: instances where the generated summary includes information which is not substantiated by or inferred from the original news. \n" 28 | "Then, compile the labeled hallucinated spans into a JSON dict, with a key \"hallucination list\" and its value is a list of hallucinated spans. If there exist potential hallucinations, the output should be in the following JSON format: {{\"hallucination list\": [hallucination span1, hallucination span2, ...]}}. 
Otherwise, leave the value as a empty list as following: {{\"hallucination list\": []}}.\n" 29 | "Output:" 30 | ), 31 | "Summary": ( 32 | "Below is the original news:\n" 33 | "{reference}\n\n" 34 | "Below is a summary of the news:\n" 35 | "{response}\n" 36 | "Your task is to determine whether the summary contains either or both of the following two types of hallucinations:\n" 37 | "1. conflict: instances where the summary presents direct contraction or opposition to the original news;\n" 38 | "2. baseless info: instances where the generated summary includes information which is not substantiated by or inferred from the original news. \n" 39 | "Then, compile the labeled hallucinated spans into a JSON dict, with a key \"hallucination list\" and its value is a list of hallucinated spans. If there exist potential hallucinations, the output should be in the following JSON format: {{\"hallucination list\": [hallucination span1, hallucination span2, ...]}}. Otherwise, leave the value as a empty list as following: {{\"hallucination list\": []}}.\n" 40 | "Output:" 41 | ), 42 | "Data2txt": ( 43 | "Below is a structured data in the JSON format:\n" 44 | "{reference}\n\n" 45 | "Below is an overview article written in accordance with the structured data:\n" 46 | "{response}\n\n" 47 | "Your task is to determine whether the summary contains either or both of the following two types of hallucinations:\n" 48 | "1. conflict: instances where the summary presents direct contraction or opposition to the original news;\n" 49 | "2. baseless info: instances where the generated summary includes information which is not substantiated by or inferred from the original news. \n" 50 | "Then, compile the labeled hallucinated spans into a JSON dict, with a key \"hallucination list\" and its value is a list of hallucinated spans. If there exist potential hallucinations, the output should be in the following JSON format: {{\"hallucination list\": [hallucination span1, hallucination span2, ...]}}. 
Otherwise, leave the value as a empty list as following: {{\"hallucination list\": []}}.\n" 51 | "Output:" 52 | ), 53 | } 54 | 55 | B_INST, E_INST = "[INST]", "[/INST]" 56 | B_SYS, E_SYS = "<>\n", "\n<>\n\n" 57 | 58 | 59 | def process_dialog(dialog, tokenizer, min_turn_idx=0, return_prompt=False, train=False, train_on_context=-1): 60 | IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss 61 | assert len(dialog)>=2 62 | dialog = dialog[:2*len(dialog)//2] 63 | inputs = [] 64 | labels = [] 65 | total_turns = len(dialog)//2 66 | prompt = "" 67 | for turn_idx in range(total_turns): 68 | cur_turn_text = f"{B_INST} {dialog[turn_idx*2].strip()} {E_INST} {dialog[turn_idx*2+1].strip()}" 69 | 70 | turn_input = [tokenizer.bos_token_id]+ \ 71 | tokenizer.encode(cur_turn_text, 72 | add_special_tokens=False, 73 | truncation=False)+ \ 74 | [tokenizer.eos_token_id] 75 | if turn_idx>=min_turn_idx: 76 | cur_turn_only_input_text = f"{B_INST} {dialog[turn_idx*2].strip()} {E_INST}" 77 | turn_only_input = tokenizer.encode(cur_turn_only_input_text, 78 | add_special_tokens=False, 79 | truncation=False) 80 | turn_label = turn_input.copy() 81 | input_len = len(turn_only_input)+1 82 | for i in range(input_len): # plus one for bos 83 | turn_label[i] = IGNORE_INDEX 84 | prompt += cur_turn_only_input_text 85 | else: 86 | # for single turn training, we need to mask all history 87 | turn_label = [IGNORE_INDEX]*len(turn_input) 88 | prompt += cur_turn_text 89 | inputs.extend(turn_input) 90 | labels.extend(turn_label) 91 | if return_prompt: 92 | return prompt 93 | assert len(inputs)==len(labels) 94 | inputs = inputs[:tokenizer.model_max_length] 95 | labels = labels[:tokenizer.model_max_length] 96 | return inputs, labels 97 | 98 | def process_dialog_to_single_turn(data, tokenizer, return_prompt=False, meta=False, highlight=False, train=False): 99 | if data['task_type']=='QA': 100 | prompt = TEMPLATES[data['task_type']].format( 101 | question=data['question'], 102 | 
class CaseDetectDataset(Dataset):
    """JSONL-backed dataset of hallucination-detection examples.

    Reads one JSON object per line from ``args.train_file`` (train mode) or
    ``args.eval_file`` (eval mode) and tokenizes each record lazily in
    ``__getitem__`` via ``process_dialog_to_single_turn``.
    """

    def __init__(self, tokenizer, args, train=True):
        # Pick the split file up front; each line is one JSON record.
        source_path = args.train_file if train else args.eval_file
        with open(source_path, "r") as f:
            self.ann = [json.loads(line) for line in f]
        self.train = train
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.ann)

    def __getitem__(self, index):
        record = self.ann[index]
        input_ids, labels = process_dialog_to_single_turn(record, self.tokenizer, train=self.train)
        return {
            "input_ids": input_ids,
            "labels": labels,
        }
async def generate_response(data, sem, pbar):
    """Query one inference endpoint for hallucination spans in ``data``.

    Builds the detection prompt, wraps it in [INST] markers, and retries the
    generation up to 10 times until the completion parses as JSON.

    Returns a copy of ``data`` with a ``pred`` key holding the parsed JSON
    dict; if every attempt fails, ``pred`` falls back to an empty
    ``{'hallucination list': []}`` so downstream ``.get`` calls still work.

    Fixes vs. original: the bare ``except:`` (which also swallowed
    CancelledError) is narrowed to ``Exception``; ``answer`` could be
    referenced while unbound (NameError in the handler on the first failure,
    UnboundLocalError after retry exhaustion) or left as a raw string when
    ``json.loads`` failed on the last attempt; the unused ``start`` variable
    and dead ``global finished_count`` declaration are removed.
    """
    input_prompt = process_dialog_to_single_turn(data, tokenizer, return_prompt=True, train=False)
    input_prompt = f"{B_INST} {input_prompt.strip()} {E_INST}"
    client = random.choice(clients)
    answer = None  # stays None unless a completion parses as JSON
    for _ in range(10):
        raw = None
        try:
            async with sem:
                raw = await client.text_generation(input_prompt,
                                                   max_new_tokens=512,
                                                   stream=False,
                                                   do_sample=True,
                                                   temperature=0.05,
                                                   top_p=0.95,
                                                   top_k=40)
            answer = json.loads(raw.strip())
            break
        except Exception as exc:  # network error or non-JSON completion
            print(input_prompt)
            # Show the bad completion when we got one, otherwise the error.
            print(raw.strip() if raw is not None else exc)
            continue
    ret = dict(data)
    # Fall back to an empty prediction when no attempt produced valid JSON.
    ret['pred'] = answer if answer is not None else {'hallucination list': []}
    pbar.update(1)
    return ret
import pandas as pd
import random
import re
import json
from os import path

# Fixed seed so the train/dev source_id split is reproducible across runs.
random.seed(2024)

# NOTE(review): compiled but currently unused in this module; kept in case
# other code imports it from here — TODO confirm before removing.
pattern = re.compile(r'\bnull\b')


def get_json_data(data):
    """Turn a merged response/source-info DataFrame into JSON-serializable dicts.

    For every row: drops the raw prompt, sorts labeled spans by start offset,
    groups them into 'baseless info' vs 'conflict' buckets, and flattens the
    task-specific source info into a ``reference`` field (plus ``question``
    for QA rows).
    """
    to_save = []
    for _, row in data.iterrows():
        d = row.to_dict()
        d.pop('prompt')
        d['type'] = 'response'
        # Sort spans by start offset so output order follows the text.
        label = sorted(d['labels'], key=lambda x: x['start'])
        format_label = {'baseless info': [], 'conflict': []}
        for l in label:
            if l['label_type'].lower().find('baseless') >= 0:
                format_label['baseless info'].append(l['text'])
            else:
                format_label['conflict'].append(l['text'])
        d['format_label'] = format_label
        if row['task_type'] == 'QA':
            d['reference'] = row['source_info']['passages']
            d['question'] = row['source_info']['question']
        elif row['task_type'] == 'Summary':
            d['reference'] = row['source_info']
        else:
            # Data2txt: source_info is structured data; stringify it.
            d['reference'] = f"{row['source_info']}"
        to_save.append(d)
    return to_save


def read_ragtruth_split(ragtruth_dir, split):
    """Load one RAGTruth split, keeping only good-quality responses joined
    with their source info on ``source_id``."""
    resp = pd.read_json(path.join(ragtruth_dir, 'response.jsonl'), lines=True)
    test = resp[(resp['split'] == split) & (resp['quality'] == 'good')]
    oc = pd.read_json(path.join(ragtruth_dir, 'source_info.jsonl'), lines=True)
    test = test.merge(oc, on='source_id')
    print(test.shape)
    return test


def get_id(item):
    """Build a unique example id from source id, sentence id and model name."""
    return f"{item['id']}_{item['sentence_id']}_{item['model']}"


# do not split sentence
def get_data():
    """Write train/dev/test JSONL files.

    The RAGTruth train split is divided into train and dev by sampling 50
    source_ids per task type for dev, so responses sharing a source never
    leak between the two sets.
    """
    data = read_ragtruth_split('../dataset', 'train')
    dev_source_id = []
    for task in ['QA', 'Summary', 'Data2txt']:
        source_ids = data[data['task_type'] == task]['source_id'].unique().tolist()
        dev_source_id.extend(random.sample(source_ids, 50))

    train = data[~data['source_id'].isin(dev_source_id)].reset_index(drop=True)
    # Fix: take an explicit copy before adding the 'fold' column; assigning
    # into a boolean-mask slice raises SettingWithCopyWarning and may silently
    # fail to modify the frame.
    dev = data[data['source_id'].isin(dev_source_id)].copy()
    print(dev['task_type'].value_counts())
    train['fold'] = -1
    dev['fold'] = -1
    train = get_json_data(train)
    dev = get_json_data(dev)
    with open('./train.jsonl', 'w') as f:
        for d in train:
            f.write(json.dumps(d) + "\n")

    with open('./dev.jsonl', 'w') as f:
        for d in dev:
            f.write(json.dumps(d) + "\n")

    test = read_ragtruth_split('../dataset', 'test')
    test = get_json_data(test)
    with open('./test.jsonl', 'w') as f:
        for d in test:
            f.write(json.dumps(d) + "\n")


if __name__ == '__main__':
    # Fix: guard the entry point so importing this module (e.g. to reuse
    # get_json_data) does not trigger the full dataset build. Running the
    # script directly behaves exactly as before.
    get_data()
-------------------------------------------------------------------------------- 1 | # Baseline 2 | How to run: 3 | 4 | 1. Generate training/evaluation data 5 | ``` 6 | python prepare_dataset.py 7 | ``` 8 | 9 | 2. Train the model. Remove the leading two lines if you don't want to manually set the WANDB configuration. 10 | ``` 11 | WANDB_API_KEY={your key} \ 12 | WANDB_PROJECT={your project} \ 13 | CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nnodes 1 --nproc_per_node 4 train.py \ 14 | --model_name_or_path meta-llama/Llama-2-13b-hf \ 15 | --output_dir ./exp/baseline \ 16 | --do_train \ 17 | --dataset detect_yesno \ 18 | --num_train_epochs 1 \ 19 | --learning_rate 2e-5 \ 20 | --drop_neg_ratio -1 \ 21 | --train_file ./train.jsonl \ 22 | --eval_file ./dev.jsonl \ 23 | --bf16 True \ 24 | --tf32 True \ 25 | --use_flashatt_2 True \ 26 | --per_device_train_batch_size 8 \ 27 | --per_device_eval_batch_size 8 \ 28 | --gradient_accumulation_steps 1 \ 29 | --model_max_length 4096 \ 30 | --report_to wandb \ 31 | --ddp_find_unused_parameters False \ 32 | --logging_steps 1 \ 33 | --run_name baseline \ 34 | --lr_scheduler_type 'cosine' \ 35 | --warmup_ratio 0.1 \ 36 | --save_steps 10000 \ 37 | --save_total_limit 2 \ 38 | --overwrite_output_dir \ 39 | --eval_strategy steps \ 40 | --eval_steps 80 \ 41 | --fsdp "shard_grad_op auto_wrap" \ 42 | --fsdp_config ./configs/fsdp.json 43 | ``` 44 | 45 | 3. Evaluate the model. We use `text-generation-inference` to serve the model.
def merge_dataclasses(dc1, dc2):
    """Merge two dataclass instances into one plain dict.

    Fields of ``dc2`` take precedence on key collisions, since they are
    applied last.
    """
    merged = dict(asdict(dc1))
    merged.update(asdict(dc2))
    return merged
@dataclass
class TrainingArguments(TrainingArguments):
    # Extends transformers.TrainingArguments (intentionally shadowing the
    # imported name so HfArgumentParser picks up the extra fields below).
    dataset_config: Optional[str] = field(default=None)
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=4096,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    use_peft: Optional[bool] = field(default=False)
    peft_method: Optional[str] = field(default='lora')
    lora_r: Optional[int] = field(default=8)
    lora_alpha: Optional[int] = field(default=32)
    target_modules: str = field(default="q_proj,v_proj,k_proj,gate_proj,up_proj,down_proj")
    bias: Optional[str] = field(default="none")
    task_type: Optional[str] = field(default="CAUSAL_LM")
    lora_dropout: Optional[float] = field(default=0.05)
    inference_mode: Optional[bool] = field(default=False)


class ProfCallback(TrainerCallback):
    """Trainer callback that advances a torch profiler once per step."""

    def __init__(self, prof):
        self.prof = prof

    def on_step_end(self, args, state, control, **kwargs):
        self.prof.step()


def main():
    """Train or evaluate a causal-LM hallucination detector.

    All configuration comes from the command line via HfArgumentParser over
    (ModelArguments, DataArguments, TrainingArguments). With --do_train the
    base model is loaded from --model_name_or_path and fine-tuned; with
    --do_eval only, the model is loaded from --output_dir and evaluated.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Load the tokenizer and normalize its special/pad tokens.
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    tokenizer.model_max_length = training_args.model_max_length
    tokenizer.padding_side = "right"

    if tokenizer.pad_token is None:
        if model_args.model_name_or_path.find("Qwen") >= 0:
            # Qwen uses tiktoken, so its special tokens have different names.
            tokenizer.pad_token_id = tokenizer.eod_id
            tokenizer.bos_token_id = tokenizer.im_start_id
            tokenizer.eos_token_id = tokenizer.im_end_id
        elif model_args.model_name_or_path.find("Mistral") >= 0 or model_args.model_name_or_path.find("Llama-3") >= 0:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.pad_token_id = 0  # unk. we want this to be different from the eos token

    # merge dataset config into args for automatic logging
    training_args.dataset_config = str(data_args)

    # Load and preprocess the dataset for training and validation.
    dataset_train = CaseDetectDataset(tokenizer, data_args, train=True)
    dataset_val = CaseDetectDataset(tokenizer, data_args, train=False)

    if training_args.bf16:
        torch_dtype = torch.bfloat16
    elif training_args.fp16:
        torch_dtype = torch.float16
    else:
        torch_dtype = 'auto'

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if ddp:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)}

    if training_args.do_train:
        model_path = model_args.model_name_or_path
    elif training_args.do_eval:
        model_path = training_args.output_dir
    else:
        # Fix: previously `model_path` was left unbound on this path,
        # producing an opaque NameError below. Fail fast with a clear message.
        raise ValueError("Nothing to do: pass --do_train and/or --do_eval.")
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        cache_dir=training_args.cache_dir,
        load_in_8bit=True if model_args.quantization else None,
        # device_map=device_map,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation="flash_attention_2" if model_args.use_flashatt_2 else "sdpa"
    )

    if training_args.use_peft:
        lora_target_modules = training_args.target_modules.split(',')
        peft_config = LoraConfig(
            r=training_args.lora_r,
            lora_alpha=training_args.lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=training_args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM"
        )

        if model_args.quantization:
            model = prepare_model_for_kbit_training(model,
                                                    use_gradient_checkpointing=training_args.gradient_checkpointing)
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()

    model.config.use_cache = False
    if training_args.gradient_checkpointing:
        if training_args.use_peft:
            model.enable_input_require_grads()
    if not ddp and torch.cuda.device_count() > 1:
        # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = Trainer(
        model,
        args=training_args,
        data_collator=DataCollatorForSeq2Seq(tokenizer, padding='longest', max_length=training_args.model_max_length, pad_to_multiple_of=8),
        train_dataset=dataset_train,
        eval_dataset=dataset_val,
        tokenizer=tokenizer
    )
    torch.cuda.empty_cache()

    if training_args.do_train:
        trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        tokenizer.save_pretrained(training_args.output_dir)
        trainer.save_state()
        trainer.save_model(output_dir=training_args.output_dir)
    elif training_args.do_eval:
        result = trainer.evaluate()
        print(result)


if __name__ == "__main__":
    main()
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ParticleMedia/RAGTruth/c103204b9ce28d6bbad859304bf30de72b8ed8fe/image/fig1.jpg --------------------------------------------------------------------------------