├── .gitignore
├── LICENSE
├── README.md
├── data
│   ├── dev.tsv
│   ├── external_nodes_dev.txt
│   ├── internal_nodes_dev.txt
│   ├── preds.tsv
│   ├── refinement_graphs_train.tsv
│   ├── relations.txt
│   ├── test_no_labels.tsv
│   └── train.tsv
├── eval_scripts
│   ├── eval_ea.sh
│   ├── eval_first.sh
│   └── eval_seca.sh
├── metrics
│   ├── __init__.py
│   ├── compute_ea.py
│   ├── create_ea_data.py
│   ├── eval.py
│   ├── graph_matching.py
│   ├── run_ea.py
│   ├── run_seca.py
│   ├── utils_ea.py
│   └── utils_seqa.py
├── model_scripts
│   ├── test_stance_pred.sh
│   ├── test_structured_model.sh
│   ├── train_conceptnet_finetuning.sh
│   ├── train_stance_pred.sh
│   └── train_structured_model.sh
├── models
│   └── README.md
├── requirements.txt
├── src
│   ├── run_pl_stance_pred.py
│   └── utils_stance_pred.py
├── structured_model
│   ├── README.md
│   ├── inference.py
│   ├── joint_model.py
│   ├── relation_model.py
│   ├── run_commonsense_finetuning.py
│   ├── run_joint_model.py
│   ├── save_relation_embeddings.py
│   ├── utils_joint_model.py
│   └── utils_relation.py
└── tmp
    └── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
131 | .DS_Store
132 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Swarnadeep Saha
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ExplaGraphs
2 | Dataset and PyTorch code for our EMNLP 2021 paper:
3 | 
4 | [ExplaGraphs: An Explanation Graph Generation Task for Structured Commonsense Reasoning](https://arxiv.org/abs/2104.07644)
5 | 
6 | [Swarnadeep Saha](https://swarnahub.github.io/), [Prateek Yadav](https://prateek-yadav.github.io/), [Lisa Bauer](https://www.cs.unc.edu/~lbauer6/), and [Mohit Bansal](https://www.cs.unc.edu/~mbansal/)
7 | 
8 | ## Website and Leaderboard
9 | ExplaGraphs is hosted [here](https://explagraphs.github.io/).
10 | You can find the leaderboard, a brief discussion of our dataset, the evaluation metrics, and some notes about how to submit predictions on the test set.
11 | 
12 | ## Installation
13 | This repository is tested on Python 3.8.3.
14 | You should install ExplaGraphs in a virtual environment. All dependencies can be installed as follows:
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 | 
19 | ## Dataset
20 | The ExplaGraphs dataset can be found inside the ```data``` folder.
21 | 
22 | It contains the training data in ```train.tsv```, the validation samples in ```dev.tsv```, and the test samples (without labels) in ```test_no_labels.tsv```.
23 | 
24 | Each training sample contains four tab-separated entries -- the belief, the argument, the stance label, and the explanation graph.
25 | 
26 | The graph is organized as a bracketed string ```(edge_1)(edge_2)...(edge_n)```, where each edge is of the form ```concept_1; relation; concept_2```.
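Such a graph string can be parsed into edge triples with plain string splits, following the same convention as ```metrics/eval.py```. A minimal sketch (the example graph is invented for illustration; the relations come from ```data/relations.txt```):
```
# Parse a bracketed graph string into (concept, relation, concept) triples:
# strip the outer parentheses, split edges on ")(" and fields on "; ".
graph_string = "(guns; capable of; killing)(killing; not desires; people)"
edges = [tuple(edge.split("; ")) for edge in graph_string[1:-1].split(")(")]
for head, relation, tail in edges:
    print(head, "->", relation, "->", tail)
# guns -> capable of -> killing
# killing -> not desires -> people
```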
27 | 
28 | ## Evaluation Metrics
29 | ExplaGraphs is a joint task that requires predicting both the stance label and the corresponding explanation graph. Independent of how you choose to represent the graphs in your models, you must represent them as bracketed strings (as in our training data) in order to use our evaluation scripts.
30 | 
31 | We propose multiple evaluation metrics, as detailed in Section 6 of our paper. Below we provide the steps to use our evaluation scripts.
32 | 
33 | ### Step 1
34 | First, we evaluate the graphs against all the non-model-based metrics. This includes computing the Stance Accuracy (SA), Structural Correctness Accuracy for Graphs (StCA), G-BertScore (G-BS), and Graph Edit Distance (GED). Run the following script to get these:
35 | ```
36 | bash eval_scripts/eval_first.sh
37 | ```
38 | This takes as input the gold file, the predictions file, and the relations file, and outputs an intermediate file ```annotations.tsv```. In this intermediate file, each sample is annotated with one of three labels -- ```stance_incorrect```, ```struct_incorrect```, and ```struct_correct```. The first label denotes the samples where the predicted stance is incorrect, the second denotes the ones where the stance is correct but the graph is structurally incorrect, and the third denotes the ones where the stance is correct and the graph is also structurally correct.
39 | 
40 | Structural Correctness Evaluation requires satisfying all the constraints we define for the task: the graph must be a connected DAG with at least three edges, with at least two concepts exactly matching the belief and two matching the argument. You **SHOULD NOT** try to boost this accuracy through arbitrary post-hoc corrections of structurally incorrect graphs (like adding a random edge to make a disconnected graph connected).
41 | 
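For reference, these structural constraints can be checked with ```networkx```. The following is a condensed, unofficial sketch of the checks performed in ```metrics/eval.py``` (```edges``` is the list of ```concept; relation; concept``` strings from a parsed graph, ```relations``` the contents of ```data/relations.txt```); the official numbers always come from ```eval_scripts/eval_first.sh```:
```
import networkx as nx

def is_structurally_correct(edges, belief, argument, relations):
    g = nx.DiGraph()
    belief_concepts, argument_concepts = set(), set()
    for edge in edges:
        parts = edge.split("; ")
        # Every edge must be a (concept; relation; concept) triple with a
        # known relation and concepts of at most three words.
        if len(parts) != 3 or parts[1] not in relations:
            return False
        for concept in (parts[0], parts[2]):
            if concept == "" or len(concept.split(" ")) > 3:
                return False
            if concept in belief:
                belief_concepts.add(concept)
            if concept in argument:
                argument_concepts.add(concept)
        g.add_edge(parts[0], parts[2])
    # At least three edges, two concepts from the belief and two from the
    # argument, and the graph must be a connected DAG.
    return (len(edges) >= 3
            and len(belief_concepts) >= 2 and len(argument_concepts) >= 2
            and nx.is_weakly_connected(g) and nx.is_directed_acyclic_graph(g))
```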
42 | Note that our evaluation framework is a pipeline, so the G-BS and GED metrics are computed only on the fraction of samples with the annotation ```struct_correct```.
43 | 
44 | ### Step 2
45 | Given this intermediate annotation file, we now compute the Semantic Correctness Accuracy for Graphs (SeCA). Once again, this only evaluates the fraction of samples where the stance is correct and the graphs are structurally correct. It is a model-based metric, and we release our pre-trained model [here](https://drive.google.com/drive/folders/1omxJhM7XG_QxBcO0cddJwld1mEvdKKDd?usp=sharing). Once you download the model, run the following script to get SeCA:
46 | ```
47 | bash eval_scripts/eval_seca.sh
48 | ```
49 | 
50 | ### Step 3
51 | In the final step, we compute the Edge Importance Accuracy (EA). This is again a model-based metric, and you can download the pre-trained model [here](https://drive.google.com/drive/folders/1gVUGZRsIefFfRgg_EbtuckP-0vwZBgI9?usp=sharing). Once you download the model, run the following script:
52 | ```
53 | bash eval_scripts/eval_ea.sh
54 | ```
55 | This measures the importance of each edge by removing it from the predicted graph and checking for the difference in stance confidence (with and without it) according to the model. An increase denotes that the edge is important, while a decrease suggests otherwise.
56 | 
57 | ## Evaluating on the Test Set
58 | 
59 | The test samples are available inside the ```data``` folder. To evaluate your model on the test set, please email us your predictions at swarna@cs.unc.edu.
60 | 
61 | The predictions should be generated in a ```tsv``` file with each line containing two tab-separated entries, first the predicted stance (support/counter) followed by the predicted graph in the same bracketed format as in the train and validation files. A sample prediction file is shown inside the ```data``` folder.
62 | 
63 | For the latest results on ExplaGraphs, please refer to the leaderboard [here](https://explagraphs.github.io/).
64 | 
65 | ## Baseline Models
66 | 
67 | For training the stance prediction model, run the following script:
68 | ```
69 | bash model_scripts/train_stance_pred.sh
70 | ```
71 | Note that this belongs to the rationalizing model family, where the stance is predicted first and the explanation graph is then generated conditioned on the predicted stance. If you wish to work with the reasoning model family, append the generated linearized graph to the input by appropriately changing the ```src/utils_stance_pred.py``` file.
72 | 
73 | We also release the trained stance prediction model [here](https://drive.google.com/drive/folders/1THK-LxVpOY2G6VZp1bQbDlCVzynRXHGN?usp=sharing). You can test the model on the validation split by running the following script:
74 | ```
75 | bash model_scripts/test_stance_pred.sh
76 | ```
77 | 
78 | For training and testing our structured graph generation model, refer to the README inside ```structured_model```.
79 | 
80 | BART and T5-based graph generation models are coming soon!
81 | 
82 | ### Citation
83 | ```
84 | @inproceedings{saha2021explagraphs,
85 |   title={ExplaGraphs: An Explanation Graph Generation Task for Structured Commonsense Reasoning},
86 |   author={Saha, Swarnadeep and Yadav, Prateek and Bauer, Lisa and Bansal, Mohit},
87 |   booktitle={EMNLP},
88 |   year={2021}
89 | }
90 | ```
91 | 
--------------------------------------------------------------------------------
/data/external_nodes_dev.txt:
--------------------------------------------------------------------------------
1 | freedom of speech, freedom of religion, separation of religion
2 | good for family, good for health and happiness, best for family
3 | false situations, bad for society, put people in danger
4 | strong families, united states, strong states, united countries
5 | harm to citizens, extorting from them, abuse citizens
6 | equality, same as before, same thing as before
7 | godly love, shouldn't be abandoned, should be reconsidered
8 | setting up a crime, justified, not a crime
9 | catch criminals, catch criminals quicker, less crime, more crime
10 | personal relationships, godly relationships, good for society
11 | personal decision, no compromise, same sex couples, same gender couples
12 | manipulation, causing a disturbance, causing an emotional reaction
13 | get married, most people in society want to get married
14 | urban areas are less dangerous, less dangerous than suburban areas
15 | personal decision, not government, choose location, choose ceremony
16 | many benefits, good for society, benefits for society
17 | catch criminals sooner, keep people safe, catch criminals quicker
18 | good for society, and is a good union for people
19 | false pretense, committed crime, innocent person involved
20 | catch a wanted person, catch a crime quicker, catch criminals quicker
21 | catch criminals quickly, catch criminals quicker, prevent crimes
22 | stable home, important to families, important for families
23 | equality, good homes, united states, good families
24 | deep
seeded meaning, show love and trust, should be respected 25 | catch criminals first, innocents second, catch innocents quicker 26 | more people, less stress, more money, more jobs 27 | harm, banning it, harm to society, banned 28 | no longer necessary, no conscience, separation from husband 29 | no sex, unsexual, unsexifiable, unneeded 30 | healthy and stable relationships, good for society, bad for society 31 | deceptive, law abiding citizens, criminal behavior, bad for society 32 | freedom of religion, freedom of speech, legal system 33 | same sex, same sex couples, same gender couples 34 | catch criminals quickly, leads to arrests, good thing 35 | connection, no one believes it, no religious ceremony 36 | security and privileges, same as dating, same thing 37 | set someone up, causing crime, setting someone up 38 | difficult to capture, criminals evade capture, evade capture 39 | meaningful to the participants, no lasting effect, no meaningful purpose 40 | loss of natural habitats, reduce animals' lifespan, negative for society 41 | plentiful, not destroy too many, good for natural habitats 42 | increase in population, loss of natural habitats, increase in jobs 43 | poor people, bad for society, worse for society 44 | catch criminals quicker, protect citizens, uphold their rights 45 | peace and stability, no war, people feel safe 46 | manipulation, causing false situations, criminal behavior, bad for society 47 | unhealthy families, no longer needed, no good families 48 | strong families, united states, good homes, bad homes 49 | no strings attached, no financial obligations, no obligations 50 | increased war risk, negative effects, negative effect, negative consequences 51 | catch criminals quickly, catch criminals quicker, catching criminals quicker 52 | solve difficult environmental issues, people solve problems, solve problems 53 | immoral, immoral, legal, ethical grey area, immoral 54 | catch criminals quickly, catch criminals quicker, less offend 55 | equality, bad for society, bad thing for society 56 | different situations, different people, different families, different situations 57 | not real religion, many people are atheist, many are not religious 58 | harboring terrorists, bad people, dangerous people, no closure 59 | immoral, human cloning, human error, ethical dilemma 60 | negative effects, negative effect, harmful to society, harmful 61 | many beliefs, many people believe in god, good thing 62 | lack of opportunity, poor quality of life, underprivileged 63 | artificial intelligence, natural selection, human intelligence, artificial intelligence 64 | freedom of religion, freedom of expression, religion has a place 65 | plastic surgery, surgery can fix birth defects, better life 66 | not all life threatening diseases, can be tackled by human cloning 67 | increase pollution, bad for society, good for planet 68 | beliefs, not everyone believes, good and bad, good thing 69 | terror suspects, inhumane and degrading conditions, no humane treatment 70 | cosmetic surgery, not worth the risk, good for society 71 | saving lives, good for society, important for society 72 | war, no religion involved, peace, no conflict 73 | cruel and inhuman, innocent people in jail, inhumane conditions 74 | safe place, hold prisoners, safe places to stay 75 | security, help with prisoners, aiding in their recovery 76 | beliefs and values, not religious, not for everyone 77 | unknown information, reduce awareness, reduce knowledge, harm 78 | long lifespans, overpopulation, less overpopulation 79 | 
inhuman and degrading, inhuman and inhuman, humane and inhuman 80 | lots of space, more people living, less sprawl 81 | worship the same god, good thing, bad thing 82 | keep people safe, keep country safe, keeping people safe 83 | not know when, stop getting cosmetic surgery, good thing 84 | negative effects, poor self esteem, damaging self esteem 85 | improving self esteem, cosmetic surgery helps people feel better 86 | thinking more critically, religious beliefs, bad, bad 87 | security, people in prison, determining whether they are dangerous 88 | prevent diseases, not help fight diseases, no treatment 89 | less debt, less government debt, lower government debt 90 | keeping the world safe, keeping terrorists at bay, keeping world safe 91 | new innovations, new ideas, new treatments, new innovations 92 | bad for society, human cloning, harmful for society 93 | division, religious arguments, division, division among people 94 | more people, more places to live, more people can fit 95 | cruel and inhuman, inhuman and degrading, inhuman treatment 96 | life-saving treatment, good for society, important for society 97 | good things for humanity, funding for research, good thing for humanity 98 | personal decision, altering appearance, important to society, important for society 99 | not lead to abuse, open prison, no harm 100 | no breakthroughs, no human cloning, no new discoveries 101 | death, put innocent people at risk, freedom of speech 102 | efficiency, less debt, lower interest rates, good thing 103 | increase tourism, increase trade, increase tourist income, increases trade 104 | negative effects, poor people in poverty, more poverty 105 | expensive, need funding, good for society, bad for society 106 | economic growth, necessary for growth, cities have challenges 107 | increase urban population, negative effects on nature, reduce impacts 108 | negative effects, harmful to society, negative effects to society 109 | disfiguration, plastic surgery, cosmetic surgery can be dangerous 110 | increase population, good for society, good thing for society 111 | death, good thing, bad thing, no return 112 | not in us, not in the us, in us 113 | not real, no proof, no evidence, not real 114 | freedom of speech, freedom of expression, negative things 115 | economic sanctions, bad for society, negative for society 116 | cheapening life, cheapens life, makes copies 117 | disfiguration, cosmetic surgery, good, bad 118 | improving self esteem, plastic surgery, good self esteem 119 | bad, serving the purpose, mistreating prisoners badly 120 | beneficial to babies, not medical field, harmful to medical field 121 | negative effects, over taxing businesses, over taxed businesses 122 | plastic surgery, improving self esteem, bad for society 123 | life, living, not dying, good thing, living 124 | no god, natural selection, artificial intelligence, unnatural 125 | disfiguration, plastic surgery correct disfigurement 126 | war profiteering, good for the country, bad for society 127 | lack of resources, increase inequality, no money for others 128 | poor and sick, unfair to poor, sick and sick 129 | poor people, no money for food, remove programs 130 | poor people, more debt, less money, more poverty 131 | war, religious undertones, war, wariness 132 | god's will, new inventions, new innovations, new discoveries 133 | immoral, moral decline, moral grey area, immoral 134 | negative effects, human cloning, human death, negative effects 135 | small percentage of population, good thing, bad thing 136 | freedom of 
religion, freedom of speech, people should choose their religion 137 | new ideas, new inventions, new discoveries, new ideas 138 | lack of green space, destroy natural beauty, destroy cities 139 | change of heart, people change their mind, cosmetic surgery 140 | harsh, inhuman, inhuman and degrading treatment, unjustifiable 141 | prevent diseases, cure them, money in the bank 142 | security, danger to americans, military and intelligence, too physically close 143 | more people, moving near farms, causing more demand 144 | high crime, people migrate to cities, make money 145 | human error, damage to body, not recover quickly 146 | freedom of religion, unamerican constitution, legal system 147 | control, people's money, government control, society 148 | jobs, increase crime rates, more money in society 149 | saving someone's life, prevent death, cure diseases 150 | expensive, no money, not worth it, unnecessary 151 | positive for society, encourage economic growth, create jobs 152 | new ideas, new discoveries, new technologies, new ideas 153 | keep people safe, keep terrorists out, keep them safe 154 | funded by government, not by taxpayers, funded with their money 155 | improving health, important for society, good for society 156 | negative effects, not favor medical advancements, bad for society 157 | poor people, improve their circumstances, poor people's circumstances 158 | many diseases, good thing, bad thing, cure 159 | keep people in prison, keep people out of american soil 160 | destroy human life, ethical dilemma, ethical dilemmas 161 | self esteem, cosmetic surgery, negative effects, negative effect 162 | health benefits, human embryo, artificial embryo, good for babies 163 | new innovations, new treatments, better treatments, more people 164 | physical strength, people's mental health improved, physical strength improved 165 | remove green space, bad health effects, reduce green space 166 | more people, more money, more people want to invest 167 | inhuman and degrading, inhuman and inhuman, inhuman treatment 168 | treatments, prevent diseases, treat diseases, cure diseases 169 | negative effect, people who follow religion, negative effect 170 | sick people, bad for society, people get sick 171 | new innovations, new treatments, new innovations in medicine 172 | bad people, criminals need to be arrested, bad people 173 | war torn countries, good thing for the country, bad thing for us 174 | people feel whole, self esteem improves, self confidence 175 | new treatments, beneficial for society, new treatments for diseases 176 | plastic surgery, good for body, bad for society 177 | believe in fairy tales, teach lessons, good thing 178 | more people, more cars, less cars, more pollution 179 | open to the public, important for national security, no negative effects 180 | altruistic, playing god, altruistic and altruistic 181 | economic growth, economic recovery, social programs, positive effect 182 | new ideas, new innovations, new drugs, new treatments 183 | increase crime rates, pollution, can't get away 184 | bad for society, bad for democracy, negative effects 185 | negative effects, harmful to babies, harmful for society 186 | freedom of religion, people can choose, adhere to any religion 187 | destroy natural habitats, destroy natural environments, destroy animals 188 | greedy and out for themselves, game system, good thing 189 | negative effects, harmful to babies, harm to babies 190 | economic sanctions, cripple society economically, negative for society 191 | people like the 
countryside, not like the city, bad for country 192 | not popular, economic growth, less government funding, more debt 193 | bad men living there, public places, private places 194 | disfiguration, people can't maintain normal bodily functions 195 | kill babies, shouldn't be funded, should be funded 196 | immoral, people can be exploited, moral crisis can arise 197 | increase life expectancy, increase productivity, positive for society 198 | war, conflicts, holy lands and religions, conflicts 199 | less debt, more money in pockets, less debt 200 | cure for diseases, prevent diseases, cure for current diseases 201 | hurting the economy, negative effects, negative effect on society 202 | economic growth, increase productivity, create jobs, create more jobs 203 | poor and downtrodden, less government spending, better for society 204 | spread diseases, human cloning, human error, human mistakes 205 | cosmetic surgery, good for body, bad for society 206 | less division, less division in society, less harm 207 | new technology, important to society, important for society 208 | restriction, restriction, freedom of speech, restriction of speech 209 | death, replace loved ones, artificial birth, remove a person 210 | freedom of religion, freedom of speech, good thing 211 | people, less stress, more crime, less crime 212 | killing babies, government subsidize, harm to babies 213 | economic sanctions, negative effect on society, negative effects on society 214 | high interest rates, negative effects, negative effect on society 215 | reduction in debt, economic growth, economic recovery, reduce debt 216 | positive self esteem, people feel better, good self esteem 217 | no moral compass, no conscience, bad for society 218 | artificial intelligence, tampering with, natural selection, artificial intelligence 219 | reduce government funding, bad for society, cut government 220 | different beliefs, values and values, different beliefs and values 221 | good for society, good for mankind, good thing 222 | good for society, good for mankind, bad for society 223 | unnatural and unnatural, human beings, unnatural, unnatural 224 | lower debt, lower interest rates, more debt, less debt 225 | reduction in debt, reduce government debt, benefit society 226 | increase debt, increase interest rates, increases interest rates 227 | cruel and inhuman, torturous, inhuman, inhumane 228 | no wait, no need for transplants, eliminate wait 229 | increase economic growth, more people, less debt, more jobs 230 | artificial intelligence, not regulated, good thing, bad thing 231 | exploitation, poor people, exploitation, abuse, exploitation 232 | disfiguration, people with disfigurements, good thing 233 | loss of jobs, displaces people, benefits to society 234 | paid for by government, funded by private donors, private donors 235 | increase population, bad for society, good for society 236 | negative effects, poor health, poor quality of life 237 | many opinions, arguments, good for society, bad for society 238 | religious beliefs, bad for society, negative for society 239 | extra security, people in prison, people out of prison 240 | plastic surgery, not surgery, people need plastic surgery 241 | plastic surgery, people living normal lives, plastic surgery 242 | negative effects, human cloning, human error, negative effects 243 | new innovations, important to society, new innovations in medicine 244 | ethical issues, artificial birth, not natural birth, unnatural birth 245 | reduction in government spending, negative effects, 
negative effect 246 | damage self esteem, negative effects, damaging self esteem 247 | break up, negative effect, bad for society, bad thing for society 248 | traditional family values, religious devotion, devotion to god 249 | healthy relationships, parents can see, children can see 250 | same sex, same sex attraction, same gender attraction 251 | united states, united states ceremony, celebration of life 252 | catch criminals, catch criminals quicker, bad guys quicker 253 | catch criminals quicker, catch criminals quickly, good thing 254 | many people still value marriage, many people no longer value it 255 | bad parts, accepted by criminals, refusal to refuse 256 | catch criminals quickly, catch criminals quicker, catch them quicker 257 | strong families, united states, united countries, equality 258 | happy homes, parents and children, good homes, good environment 259 | manipulation, keep people honest, keep criminals honest, good thing 260 | traditional family values, no longer needed, society no longer needs 261 | honor, no longer relevant, outdated in modern society 262 | negative effects, people behave badly, harm others, negative effects 263 | catch criminals quicker, innocent people being charged, catch dangerous people 264 | strong families, important for society, important to families 265 | freedom of speech, equality, freedom of expression, equal treatment 266 | less trust, less evidence, less trust in police 267 | bad guys off street, good people off streets, bad criminals off streets 268 | criminals are off streets, get rid of, good thing 269 | still getting married, important for society, important to society 270 | legal quagmire, necessary for law enforcement, necessary 271 | many burdens, develop discipline, personal growth, emotional stability 272 | plastic surgery, damaging self esteem, harmful to self esteem 273 | disfiguration, cosmetic surgery for health reasons, unnecessary 274 | plastic surgery, not know when to stop, good thing 275 | freedom of religion, people become atheist, society benefits 276 | not religious, bad for christians, not good for society 277 | economic sanctions, poor people, no money for education 278 | safe harbor, terrorists in cuba, closure of prison 279 | artificial intelligence, artificial intelligence and human beings, only for curing 280 | negative effects, harmful to babies, not good for babies 281 | security, freedom of movement, people in prison, no wrongdoing 282 | funded by donors, altruistic donors, good thing 283 | recovery, not engage in austerity, for economic growth 284 | disfiguration, good for body, bad for body 285 | reduction in debt, economic growth quicker, less debt 286 | economic sanctions, economic hardship, negative effects, negative effect 287 | organs available, dying prematurely, good thing, dying quickly 288 | natural people, choose genes, natural people choose genes 289 | human beings, causing harm, human beings can't live without rules 290 | no cure, good thing to do, no cure 291 | expensive in theory, subsidize, expensive in practice 292 | confusion among people, artificial intelligence, human embryo, artificial embryo 293 | many different beliefs, many different opinions, different beliefs 294 | economic growth, economic recovery, positive effect, bad for society 295 | plastic surgery, good self esteem, people feel better 296 | less debt, lower interest rates, more debt, less debt 297 | cosmetic surgery, good for society, bad for society 298 | security and stability, war crimes and war crimes, closure 299 | 
disfiguration, negative effects, people questioning themselves 300 | safety and security, cosmetic surgery is always performed safely 301 | plastic surgery, good, people should get plastic surgery 302 | ethical and ethical issues, subsidize research, harm 303 | cosmetic surgery, cosmetic surgery serves a purpose, bad 304 | cruel and inhuman, inhuman and degrading, torturous 305 | lots of pollution, right to live, bad for society 306 | look the same, same as before, same thing as before 307 | exploration of new ideas, push moral boundaries, new discoveries 308 | good environment, better people, better environment, creating better people 309 | less money, more debt, less money for government 310 | different types of stem cells, not embryonic, cure diseases 311 | restricting beliefs, infringes on beliefs, violates human rights 312 | war torn countries, war torn nations, dangerous countries 313 | reducing debt, improving living standards, reduction in living standards 314 | cruel and inhuman, torturous, inhuman and degrading 315 | war crimes, reopening of prison, closure of prisons 316 | improving police morale, building trust, harm to society 317 | violate rights, solve crimes, less harm, less crime 318 | old fashioned, not right, should be reconsidered 319 | security, catch criminals quicker, put criminals in one spot 320 | many people, less conflicts, less fights, less conflict 321 | more people, crime level rise significantly, more crime 322 | crowded area, growth in population, hard to meet 323 | dangerous, not regulated, violation of christian beliefs 324 | inhuman and degrading treatment, justified, justified and justified 325 | survive, regardless of intelligence, the long run, the only way 326 | freedom of movement, held in prison, freedom of speech 327 | urbanization, pollution, health problems, urbanization 328 | negative effects, negative effects to society, negative effect to society 329 | opening up government's wallet, more funds, less debt 330 | inhuman and degrading treatment, long standing, inhuman treatment 331 | lack of religion, violates our rights, violate our rights 332 | cruel and inhuman, inhuman and degrading, inhuman treatment 333 | subjective, plastic surgery, bad for brain, poor for body 334 | security, people in prison, innocent people, not charged 335 | many unknowns, good for society, bad for society 336 | dangerous for babies, underprivileged, harmful for babies 337 | disfigured people, expensive cosmetic surgery, poor people 338 | people do things due to their religion, not due to religion 339 | beliefs, set by god, beliefs set by people 340 | negative effects, bad for society, good for society 341 | copying from copier, poor quality, bad quality 342 | positive effect, people feel comfortable, less stress, more money 343 | security, detection field, improved detection, improved security 344 | personal decision, autonomy, cosmetic surgery, good for society 345 | cosmetic surgery, bad for health, poor financial results 346 | alternative treatments, improving health, better treatments, better treatment 347 | plastic surgery, good for body, bad for mental health 348 | free will, forced to commit a crime, free will 349 | catch criminals quickly, catch criminals quicker, caught quickly 350 | freedom of religion, freedom of speech, equality of religion 351 | dangerous, violation of human rights, violate human rights 352 | same sex couples, same gender couples, no separation 353 | focus on criminals, focus on convictions, focus only on criminals 354 | many families, 
many branches, united states, many families 355 | pollution, bad for the environment, no good for society 356 | bad for society, end of date, bad for institution 357 | catch criminals quicker, easier to catch, catch criminals 358 | motivation, committed by criminals, not by government, not due to government 359 | catch criminals sooner, less crime, catch criminals later 360 | innocent people, guilty people commit crimes, innocent people wouldn't commit crimes 361 | catch criminals quicker, keep people honest, good thing 362 | freedom of religion, freedom of speech, religion in general 363 | religion, political party, no place in politics, religion in politics 364 | make sacrifices, see longer term benefits, get out of debt 365 | not in united states, under united states constitution, not united states 366 | plastic surgery, negative effects, self-image, negative effect 367 | people like living where there is culture, reduce impact 368 | less work, sacrifice earned income, less government funding 369 | spread goodwill, people in different countries, different cultures 370 | freedom of religion, freedom of speech, good thing 371 | belief, regardless of religion, belief in a god 372 | choice, not chosen by everyone, choice by everyone 373 | not religious beliefs, harm to babies, harmful to babies 374 | relaxation, lots of activities, people can do, relaxation 375 | bear the burden, repay debt, bear the debt 376 | beliefs, beliefs, freedom of religion, free speech 377 | alienate religous people, bad for society, good for society 378 | ethical and moral basis, funding of research, moral basis 379 | freedom of religion, including no god, no god 380 | murderers, innocent people, not guilty, not innocent people 381 | negative effects, reduce government debt, increase government debt 382 | urbanization, urban sprawl, pollution in suburbs 383 | artificial organs, failure in people, organs can fail 384 | negative effects, not good, should not be banned 385 | disfigured, feel more comfortable, look more normal 386 | less urbanization, better quality of life for citizens 387 | increase salary, take away income, increase salary level 388 | not natural, god has an issue, not natural 389 | reduction in moral standards, ethical dilemmas, good thing 390 | not as good as original, no guarantees, not as bad as original 391 | smaller number of prisoners, more effective, less collateral damage 392 | keep people safe, keep criminals safe, war on terror 393 | mistakes, causing disfiguration, permanent disfigurement 394 | not know enough about human cloning, dangerous for society 395 | based on facts, fun and entertaining, not based on stereotypes 396 | too high, too low crime, too crowded, too dangerous 397 | freedom of religion, not tied to beliefs, not linked to beliefs 398 | people with less money, solve problems, not hard working people -------------------------------------------------------------------------------- /data/preds.tsv: -------------------------------------------------------------------------------- 1 | support (concept1; relation1; concept2)(concept2; relation2; concept3) 2 | counter (concept1; relation1; concept2)(concept2; relation2; concept3)(concept1; relation3; concept4) 3 | -------------------------------------------------------------------------------- /data/relations.txt: -------------------------------------------------------------------------------- 1 | antonym of 2 | synonym of 3 | at location 4 | not at location 5 | capable of 6 | not capable of 7 | causes 8 | not causes 9 | created 
by 10 | not created by 11 | is a 12 | is not a 13 | desires 14 | not desires 15 | has subevent 16 | not has subevent 17 | part of 18 | not part of 19 | has context 20 | not has context 21 | has property 22 | not has property 23 | made of 24 | not made of 25 | receives action 26 | not receives action 27 | used for 28 | not used for -------------------------------------------------------------------------------- /eval_scripts/eval_ea.sh: -------------------------------------------------------------------------------- 1 | python metrics/create_ea_data.py \ 2 | --gold_file data/dev.tsv \ 3 | --eval_annotated_file data/annotations.tsv \ 4 | --output_EA_initial data/ea_initial.tsv \ 5 | --output_EA_final data/ea_final.tsv 6 | 7 | python metrics/run_ea.py \ 8 | --model_name_or_path ./models/ea-metric \ 9 | --task_name stance \ 10 | --do_eval \ 11 | --save_steps 1000000 \ 12 | --data_dir ./data/ea_initial.tsv \ 13 | --max_seq_length 128 \ 14 | --per_device_train_batch_size 32 \ 15 | --learning_rate 1e-5 \ 16 | --num_train_epochs 6.0 \ 17 | --output_dir ./models/ea-metric \ 18 | --cache_dir ./models/ \ 19 | --logging_steps 500 \ 20 | --evaluation_strategy="epoch" \ 21 | --overwrite_cache 22 | 23 | mv ./models/ea-metric/probs.txt ./models/ea-metric/probs_initial.txt 24 | 25 | python metrics/run_ea.py \ 26 | --model_name_or_path ./models/ea-metric \ 27 | --task_name stance \ 28 | --do_eval \ 29 | --save_steps 1000000 \ 30 | --data_dir ./data/ea_final.tsv \ 31 | --max_seq_length 128 \ 32 | --per_device_train_batch_size 32 \ 33 | --learning_rate 1e-5 \ 34 | --num_train_epochs 6.0 \ 35 | --output_dir ./models/ea-metric \ 36 | --cache_dir ./models/ \ 37 | --logging_steps 500 \ 38 | --evaluation_strategy="epoch" \ 39 | --overwrite_cache 40 | 41 | mv ./models/ea-metric/probs.txt ./models/ea-metric/probs_final.txt 42 | 43 | python metrics/compute_ea.py \ 44 | --initial_probs ./models/ea-metric/probs_initial.txt \ 45 | --final_probs ./models/ea-metric/probs_final.txt \ 46 | --initial_file ./data/ea_initial.tsv \ 47 | --final_file ./data/ea_final.tsv \ 48 | --gold_file data/dev.tsv \ 49 | 50 | -------------------------------------------------------------------------------- /eval_scripts/eval_first.sh: -------------------------------------------------------------------------------- 1 | python metrics/eval.py --pred_file data/preds.tsv --gold_file data/dev.tsv --relations_file data/relations.txt --eval_annotated_file data/annotations.tsv -------------------------------------------------------------------------------- /eval_scripts/eval_seca.sh: -------------------------------------------------------------------------------- 1 | python metrics/run_seca.py \ 2 | --model_name_or_path ./models/seca-metric \ 3 | --task_name stance \ 4 | --do_eval \ 5 | --save_steps 1000000 \ 6 | --data_dir ./data/annotations.tsv \ 7 | --max_seq_length 128 \ 8 | --per_device_train_batch_size 32 \ 9 | --learning_rate 1e-5 \ 10 | --num_train_epochs 3.0 \ 11 | --output_dir ./models/seca-metric \ 12 | --cache_dir ./models/ \ 13 | --logging_steps 500 \ 14 | --evaluation_strategy="epoch" \ 15 | --overwrite_cache 16 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swarnaHub/ExplaGraphs/67ecab19d9a13ab91e09e99bd94e0480b705c1b7/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/compute_ea.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | if __name__ == '__main__': 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument("--initial_probs", default=None, type=str, required=True) 6 | parser.add_argument("--final_probs", default=None, type=str, required=True) 7 | parser.add_argument("--initial_file", default=None, type=str, required=True) 8 | parser.add_argument("--final_file", default=None, type=str, required=True) 9 | parser.add_argument("--gold_file", default=None, type=str, required=True) 10 | 11 | args = parser.parse_args() 12 | 13 | golds = open(args.gold_file, "r", encoding="utf-8-sig").read().splitlines() 14 | initial_probs = open(args.initial_probs, "r", encoding="utf-8-sig").read().splitlines()[1:] 15 | final_probs = open(args.final_probs, "r", encoding="utf-8-sig").read().splitlines()[1:] 16 | initial_samples = open(args.initial_file, "r", encoding="utf-8-sig").read().splitlines() 17 | final_samples = open(args.final_file, "r", encoding="utf-8-sig").read().splitlines() 18 | 19 | index_to_line = {} 20 | gold_labels = {} 21 | for (i, sample) in enumerate(final_samples): 22 | index = int(sample.split("\t")[0]) 23 | if index not in index_to_line: 24 | index_to_line[index] = [i] 25 | else: 26 | index_to_line[index].append(i) 27 | gold_labels[index] = sample.split("\t")[3] 28 | 29 | initial_probs_converted = {} 30 | for (initial_sample, initial_prob) in zip(initial_samples, initial_probs): 31 | temp = [float(element) for element in initial_prob[1:-1].split(" ") if element != ""] 32 | assert len(temp) == 2 33 | index = initial_sample.split("\t")[0] 34 | initial_probs_converted[int(index)] = temp 35 | 36 | new_probs_converted = [] 37 | for final_prob in final_probs: 38 | temp = [float(element) for element in final_prob[1:-1].split(" ") if element != ""] 39 | assert len(temp) == 2 40 | new_probs_converted.append(temp) 41 | 42 | macro_increment_count = 0 43 | label_to_index = {"support": 0, "counter": 1} 44 | score_list = [] 45 | for index in index_to_line: 46 | lines = index_to_line[index] 47 | label = gold_labels[index] 48 | 49 | sample_increment_count = 0 50 | for line in lines: 51 | temp_final = new_probs_converted[line][label_to_index[label]] 52 | temp_initial = initial_probs_converted[index][label_to_index[label]] 53 | # An edge is important if it causes an increase in stance confidence 54 | if temp_initial > temp_final: 55 | sample_increment_count += 1 56 | 57 | # Average across all edges 58 | macro_increment_count += sample_increment_count / len(lines) 59 | 60 | # Average across all samples 61 | print(f'Edge Importance Accuracy (EA): {macro_increment_count / len(golds):.4f}') 62 | 63 | -------------------------------------------------------------------------------- /metrics/create_ea_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import networkx as nx 3 | 4 | 5 | def get_dfs_ordering(graph_string): 6 | graph = nx.DiGraph() 7 | nodes = [] 8 | relations = {} 9 | for edge in graph_string[1:-1].split(")("): 10 | parts = edge.split("; ") 11 | graph.add_edge(parts[0], parts[2]) 12 | if parts[0] not in nodes: 13 | nodes.append(parts[0]) 14 | if parts[2] not in nodes: 15 | nodes.append(parts[2]) 16 | relations[(parts[0], parts[2])] = parts[1] 17 | 18 | in_degrees = list(graph.in_degree(nodes)) 19 | 20 | start_nodes = [] 21 | for (i, node) in enumerate(nodes): 22 | if in_degrees[i][1] == 0: 23 | start_nodes.append(in_degrees[i][0]) 24 | 25 | 
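# Depth-first traversal from the root (in-degree 0) nodes yields a canonical
# edge ordering for linearizing the graph.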
dfs_edges = list(nx.edge_dfs(graph, source=start_nodes)) 26 | 27 | new_graph_string = "" 28 | for edge in dfs_edges: 29 | new_graph_string += "(" + edge[0] + "; " + relations[(edge[0], edge[1])] + "; " + edge[1] + ")" 30 | 31 | return new_graph_string 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument("--gold_file", default=None, type=str, required=True) 37 | parser.add_argument("--eval_annotated_file", default=None, type=str, required=True) 38 | parser.add_argument("--output_EA_initial", default=None, type=str, required=True) 39 | parser.add_argument("--output_EA_final", default=None, type=str, required=True) 40 | 41 | args = parser.parse_args() 42 | golds = open(args.gold_file, "r", encoding="utf-8-sig").read().splitlines() 43 | preds = open(args.eval_annotated_file, "r", encoding="utf-8-sig").read().splitlines() 44 | 45 | assert len(golds) == len(preds) 46 | 47 | output_initial = open(args.output_EA_initial, "w", encoding="utf-8-sig") 48 | output_final = open(args.output_EA_final, "w", encoding="utf-8-sig") 49 | 50 | for i, (gold, pred) in enumerate(zip(golds, preds)): 51 | gold_parts = gold.split("\t") 52 | belief, argument, stance = gold_parts[0], gold_parts[1], gold_parts[2] 53 | 54 | pred_parts = pred.split("\t") 55 | if pred_parts[3] != "struct_correct": 56 | continue 57 | 58 | graph = get_dfs_ordering(pred_parts[1]) 59 | for edge in graph[1:-1].split(")("): 60 | leave_one_out_graph = graph.replace("(" + edge + ")", "") 61 | leave_one_out_graph = leave_one_out_graph.replace("(", "").replace(";", "").replace(")", ". ") 62 | leave_one_out_argument = argument + " " + leave_one_out_graph 63 | output_final.write(str(i) + "\t" + belief + "\t" + leave_one_out_argument + "\t" + stance + "\n") 64 | 65 | whole_graph_argument = argument + " " + graph.replace("(", "").replace(";", "").replace(")", ". 
") 66 | output_initial.write(str(i) + "\t" + belief + "\t" + whole_graph_argument + "\t" + stance + "\n") 67 | -------------------------------------------------------------------------------- /metrics/eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import networkx as nx 3 | import numpy as np 4 | from graph_matching import split_to_edges, get_tokens, get_bleu_rouge, get_bert_score, get_ged 5 | 6 | 7 | def is_edge_count_correct(edges): 8 | if len(edges) < 3: 9 | return False 10 | else: 11 | return True 12 | 13 | 14 | def is_graph(edges): 15 | for edge in edges: 16 | components = edge.split("; ") 17 | if len(components) != 3: 18 | return False 19 | 20 | return True 21 | 22 | 23 | def is_edge_structure_correct(edges, relations): 24 | for edge in edges: 25 | components = edge.split("; ") 26 | if components[0] == "" or len(components[0].split(" ")) > 3: 27 | return False 28 | if components[1] not in relations: 29 | return False 30 | if components[2] == "" or len(components[2].split(" ")) > 3: 31 | return False 32 | 33 | return True 34 | 35 | 36 | def two_concepts_belief_argument(edges, belief, argument): 37 | belief_concepts = {} 38 | argument_concepts = {} 39 | for edge in edges: 40 | components = edge.split("; ") 41 | if components[0] in belief: 42 | belief_concepts[components[0]] = True 43 | 44 | if components[2] in belief: 45 | belief_concepts[components[2]] = True 46 | 47 | if components[0] in argument: 48 | argument_concepts[components[0]] = True 49 | 50 | if components[2] in argument: 51 | argument_concepts[components[2]] = True 52 | 53 | if len(belief_concepts) < 2 or len(argument_concepts) < 2: 54 | return False 55 | else: 56 | return True 57 | 58 | 59 | def is_connected_DAG(edges): 60 | g = nx.DiGraph() 61 | for edge in edges: 62 | components = edge.split("; ") 63 | g.add_edge(components[0], components[2]) 64 | 65 | return nx.is_weakly_connected(g) and nx.is_directed_acyclic_graph(g) 66 | 67 | 68 | def get_max(first_precisions, first_recalls, first_f1s, second_precisions, second_recalls, second_f1s): 69 | max_indices = np.argmax(np.concatenate((np.expand_dims(first_f1s, axis=1), 70 | np.expand_dims(second_f1s, axis=1)), axis=1), axis=1) 71 | 72 | precisions = np.concatenate((np.expand_dims(first_precisions, axis=1), 73 | np.expand_dims(second_precisions, axis=1)), axis=1) 74 | precisions = np.choose(max_indices, precisions.T) 75 | 76 | recalls = np.concatenate((np.expand_dims(first_recalls, axis=1), 77 | np.expand_dims(second_recalls, axis=1)), axis=1) 78 | recalls = np.choose(max_indices, recalls.T) 79 | 80 | f1s = np.maximum(first_f1s, second_f1s) 81 | 82 | return precisions, recalls, f1s 83 | 84 | 85 | if __name__ == '__main__': 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument("--pred_file", default=None, type=str, required=True) 88 | parser.add_argument("--gold_file", default=None, type=str, required=True) 89 | parser.add_argument("--relations_file", default=None, type=str, required=True) 90 | parser.add_argument("--eval_annotated_file", default=None, type=str, required=True) 91 | parser.add_argument("--test", action='store_true') 92 | 93 | args = parser.parse_args() 94 | 95 | preds = open(args.pred_file, "r", encoding="utf-8-sig").read().splitlines() 96 | golds = open(args.gold_file, "r", encoding="utf-8-sig").read().splitlines() 97 | relations = open(args.relations_file, "r", encoding="utf-8-sig").read().splitlines() 98 | eval_annotations = open(args.eval_annotated_file, "w", encoding="utf-8-sig") 
99 | 100 | assert len(preds) == len(golds) 101 | 102 | stance_correct_count = 0 103 | structurally_correct_graphs_count = 0 104 | structurally_correct_gold_graphs, structurally_correct_second_gold_graphs, structurally_correct_pred_graphs = [], [], [] 105 | overall_ged = 0. 106 | for (pred, gold) in zip(preds, golds): 107 | parts = pred.split("\t") 108 | 109 | assert len(parts) == 2 110 | 111 | pred_stance = parts[0] 112 | pred_graph = parts[1].lower() 113 | 114 | assert pred_stance in ["support", "counter"] 115 | 116 | parts = gold.split("\t") 117 | belief = parts[0].lower() 118 | argument = parts[1].lower() 119 | gold_stance = parts[2] 120 | gold_graph = parts[3].lower() 121 | if args.test: 122 | second_gold_graph = parts[4].lower() 123 | 124 | # Check for Stance Correctness first 125 | if pred_stance == gold_stance: 126 | stance_correct_count += 1 127 | edges = pred_graph[1:-1].split(")(") 128 | # Check for Structural Correctness of graphs 129 | if is_edge_count_correct(edges) and is_graph(edges) and is_edge_structure_correct(edges, 130 | relations) and two_concepts_belief_argument( 131 | edges, belief, argument) and is_connected_DAG(edges): 132 | structurally_correct_graphs_count += 1 133 | eval_annotations.write(belief + "\t" + pred_graph + "\t" + gold_stance + "\tstruct_correct\n") 134 | 135 | # Save the graphs for Graph Matching or Semantic Correctness Evaluation 136 | structurally_correct_gold_graphs.append(gold_graph) 137 | if args.test: 138 | structurally_correct_second_gold_graphs.append(second_gold_graph) 139 | 140 | structurally_correct_pred_graphs.append(pred_graph) 141 | 142 | # Compute GED 143 | ged = get_ged(gold_graph, pred_graph) 144 | if args.test: 145 | ged = min(ged, get_ged(second_gold_graph, pred_graph)) 146 | else: 147 | eval_annotations.write(belief + "\t" + pred_graph + "\t" + gold_stance + "\tstruct_incorrect\n") 148 | # GED needs to be computed as the upper bound for structurally incorrect graphs 149 | ged = get_ged(gold_graph) 150 | if args.test: 151 | ged = min(ged, get_ged(second_gold_graph)) 152 | else: 153 | # GED also needs to be computed as the upper bound for samples with incorrect stance 154 | ged = get_ged(gold_graph) 155 | if args.test: 156 | ged = min(ged, get_ged(second_gold_graph)) 157 | eval_annotations.write(belief + "\t" + pred_graph + "\t" + gold_stance + "\tstance_incorrect\n") 158 | 159 | overall_ged += ged 160 | 161 | 162 | # Evaluate for Graph Matching 163 | gold_edges = split_to_edges(structurally_correct_gold_graphs) 164 | second_gold_edges = split_to_edges(structurally_correct_second_gold_graphs) if args.test else None 165 | pred_edges = split_to_edges(structurally_correct_pred_graphs) 166 | 167 | gold_tokens, pred_tokens, second_gold_tokens = get_tokens(gold_edges, pred_edges, second_gold_edges) 168 | 169 | precisions_rouge, recalls_rouge, f1s_rouge, precisions_bleu, recalls_bleu, f1s_bleu = get_bleu_rouge( 170 | gold_tokens, pred_tokens, gold_edges, pred_edges) 171 | 172 | precisions_BS, recalls_BS, f1s_BS = get_bert_score(gold_edges, pred_edges) 173 | 174 | # Get max of two gold graphs 175 | if args.test: 176 | second_precisions_rouge, second_recalls_rouge, second_f1s_rouge, second_precisions_bleu, second_recalls_bleu, \ 177 | second_f1s_bleu = get_bleu_rouge(second_gold_tokens, pred_tokens, second_gold_edges, pred_edges) 178 | 179 | second_precisions_BS, second_recalls_BS, second_f1s_BS = get_bert_score(second_gold_edges, pred_edges) 180 | 181 | precisions_bleu, recalls_bleu, f1s_bleu = get_max(precisions_bleu, recalls_bleu, 
f1s_bleu, 182 | second_precisions_bleu, second_recalls_bleu, second_f1s_bleu) 183 | precisions_rouge, recalls_rouge, f1s_rouge = get_max(precisions_rouge, recalls_rouge, f1s_rouge, 184 | second_precisions_rouge, second_recalls_rouge, 185 | second_f1s_rouge) 186 | precisions_BS, recalls_BS, f1s_BS = get_max(precisions_BS, recalls_BS, f1s_BS, 187 | second_precisions_BS, second_recalls_BS, second_f1s_BS) 188 | 189 | 190 | print(f'Stance Accuracy (SA): {stance_correct_count / len(golds):.4f}') 191 | print(f'Structural Correctness Accuracy (StCA): {structurally_correct_graphs_count / len(golds):.4f}') 192 | 193 | print(f'G-BLEU Precision: {precisions_bleu.sum() / len(golds):.4f}') 194 | print(f'G-BLEU Recall: {recalls_bleu.sum() / len(golds):.4f}') 195 | print(f'G-BLEU F1: {f1s_bleu.sum() / len(golds):.4f}\n') 196 | 197 | print(f'G-Rouge Precision: {precisions_rouge.sum() / len(golds):.4f}') 198 | print(f'G-Rouge Recall Score: {recalls_rouge.sum() / len(golds):.4f}') 199 | print(f'G-Rouge F1 Score: {f1s_rouge.sum() / len(golds):.4f}') 200 | 201 | print(f'G-BertScore Precision Score: {precisions_BS.sum() / len(golds):.4f}') 202 | print(f'G-BertScore Recall Score: {recalls_BS.sum() / len(golds):.4f}') 203 | print(f'G-BertScore F1 Score: {f1s_BS.sum() / len(golds):.4f}\n') 204 | 205 | print(f'Graph Edit Distance (GED): {overall_ged / len(golds):.4f}\n') 206 | -------------------------------------------------------------------------------- /metrics/graph_matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rouge_score import rouge_scorer 3 | from bert_score import score as score_bert 4 | from nltk.translate.bleu_score import sentence_bleu 5 | from scipy.optimize import linear_sum_assignment 6 | from spacy.tokenizer import Tokenizer 7 | from spacy.lang.en import English 8 | import re 9 | import networkx as nx 10 | 11 | 12 | def get_tokens(gold_edges, pred_edges, second_gold_edges): 13 | nlp = English() 14 | tokenizer = Tokenizer(nlp.vocab, infix_finditer=re.compile(r'''[;]''').finditer) 15 | 16 | gold_tokens = [] 17 | pred_tokens = [] 18 | second_gold_tokens = [] 19 | 20 | for i in range(len(gold_edges)): 21 | gold_tokens_edges = [] 22 | pred_tokens_edges = [] 23 | 24 | for sample in tokenizer.pipe(gold_edges[i]): 25 | gold_tokens_edges.append([j.text for j in sample]) 26 | for sample in tokenizer.pipe(pred_edges[i]): 27 | pred_tokens_edges.append([j.text for j in sample]) 28 | gold_tokens.append(gold_tokens_edges) 29 | pred_tokens.append(pred_tokens_edges) 30 | 31 | if second_gold_edges is not None: 32 | second_gold_tokens_edges = [] 33 | for sample in tokenizer.pipe(second_gold_edges[i]): 34 | second_gold_tokens_edges.append([j.text for j in sample]) 35 | second_gold_tokens.append(second_gold_tokens_edges) 36 | 37 | return gold_tokens, pred_tokens, second_gold_tokens 38 | 39 | 40 | def split_to_edges(graphs): 41 | processed_graphs = [] 42 | for graph in graphs: 43 | processed_graphs.append([re.sub('[)(]', '', g.lower().strip()) for g in graph.split(')(')]) 44 | return processed_graphs 45 | 46 | 47 | def get_bert_score(all_gold_edges, all_pred_edges): 48 | references = [] 49 | candidates = [] 50 | 51 | ref_cand_index = {} 52 | for (gold_edges, pred_edges) in zip(all_gold_edges, all_pred_edges): 53 | for (i, gold_edge) in enumerate(gold_edges): 54 | for (j, pred_edge) in enumerate(pred_edges): 55 | references.append(gold_edge) 56 | candidates.append(pred_edge) 57 | ref_cand_index[(gold_edge, pred_edge)] = len(references) - 1 58 
| 59 | _, _, bs_F1 = score_bert(cands=candidates, refs=references, lang='en', idf=False) 60 | print("Computed bert scores for all pairs") 61 | 62 | precisions, recalls, f1s = [], [], [] 63 | for (gold_edges, pred_edges) in zip(all_gold_edges, all_pred_edges): 64 | score_matrix = np.zeros((len(gold_edges), len(pred_edges))) 65 | for (i, gold_edge) in enumerate(gold_edges): 66 | for (j, pred_edge) in enumerate(pred_edges): 67 | score_matrix[i][j] = bs_F1[ref_cand_index[(gold_edge, pred_edge)]] 68 | 69 | row_ind, col_ind = linear_sum_assignment(score_matrix, maximize=True) 70 | 71 | sample_precision = score_matrix[row_ind, col_ind].sum() / len(pred_edges) 72 | sample_recall = score_matrix[row_ind, col_ind].sum() / len(gold_edges) 73 | 74 | precisions.append(sample_precision) 75 | recalls.append(sample_recall) 76 | f1s.append(2 * sample_precision * sample_recall / (sample_precision + sample_recall) if sample_precision + sample_recall > 0 else 0) 77 | 78 | return np.array(precisions), np.array(recalls), np.array(f1s) 79 | 80 | 81 | # Note: These graph matching metrics are computed by considering each graph as a set of edges and each edge as a 82 | # sentence 83 | def get_bleu_rouge(gold_tokens, pred_tokens, gold_sent, pred_sent): 84 | scorer_rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rouge3', 'rougeL'], use_stemmer=True) 85 | 86 | precisions_bleu = [] 87 | recalls_bleu = [] 88 | f1s_bleu = [] 89 | 90 | precisions_rouge = [] 91 | recalls_rouge = [] 92 | f1s_rouge = [] 93 | 94 | for graph_idx in range(len(gold_tokens)): 95 | score_bleu = np.zeros((len(pred_tokens[graph_idx]), len(gold_tokens[graph_idx]))) 96 | score_rouge = np.zeros((len(pred_tokens[graph_idx]), len(gold_tokens[graph_idx]))) 97 | for p_idx in range(len(pred_tokens[graph_idx])): 98 | for g_idx in range(len(gold_tokens[graph_idx])): 99 | score_bleu[p_idx, g_idx] = sentence_bleu([gold_tokens[graph_idx][g_idx]], pred_tokens[graph_idx][p_idx]) 100 | score_rouge[p_idx, g_idx] = \ 101 | scorer_rouge.score(gold_sent[graph_idx][g_idx], pred_sent[graph_idx][p_idx])['rouge2'].precision 102 | 103 | def _scores(cost_matrix): 104 | row_ind, col_ind = linear_sum_assignment(cost_matrix, maximize=True) 105 | precision = cost_matrix[row_ind, col_ind].sum() / cost_matrix.shape[0] 106 | recall = cost_matrix[row_ind, col_ind].sum() / cost_matrix.shape[1] 107 | f1 = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0 108 | return precision, recall, f1 109 | 110 | precision_bleu, recall_bleu, f1_bleu = _scores(score_bleu) 111 | precisions_bleu.append(precision_bleu) 112 | recalls_bleu.append(recall_bleu) 113 | f1s_bleu.append(f1_bleu) 114 | 115 | precision_rouge, recall_rouge, f1_rouge = _scores(score_rouge) 116 | precisions_rouge.append(precision_rouge) 117 | recalls_rouge.append(recall_rouge) 118 | f1s_rouge.append(f1_rouge) 119 | 120 | return np.array(precisions_rouge), np.array(recalls_rouge), np.array(f1s_rouge), np.array( 121 | precisions_bleu), np.array(recalls_bleu), np.array(f1s_bleu) 122 | 123 | 124 | def return_eq_node(node1, node2): 125 | return node1['label'] == node2['label'] 126 | 127 | 128 | def return_eq_edge(edge1, edge2): 129 | return edge1['label'] == edge2['label'] 130 | 131 | 132 | def get_ged(gold_graph, pred_graph=None): 133 | g1 = nx.DiGraph() 134 | g2 = nx.DiGraph() 135 | 136 | for edge in gold_graph[1:-1].split(")("): 137 | parts = edge.split("; ") 138 | g1.add_node(parts[0], label=parts[0]) 139 | g1.add_node(parts[2], label=parts[2]) 140 | g1.add_edge(parts[0], parts[2], label=parts[1]) 141 | 142 | # The upper bound is defined with respect to
the graph for which GED is the worst. 143 | # Since ExplaGraphs (by construction) allows a maximum of 8 edges, the worst GED = gold_nodes + gold_edges + 8 + 9 (hence the +17 below). 144 | # This happens when the predicted graph is linear with 8 edges and 9 nodes. 145 | # In such a case, for GED to be the worst, we assume that all nodes and edges of the predicted graph are deleted and 146 | # then all nodes and edges of the gold graph are added. 147 | # Note that a stricter upper bound can be computed by considering some replacement operations, but we ignore that for convenience. 148 | normalizing_constant = g1.number_of_nodes() + g1.number_of_edges() + 17 149 | 150 | if pred_graph is None: 151 | return 1 152 | 153 | for edge in pred_graph[1:-1].split(")("): 154 | parts = edge.split("; ") 155 | g2.add_node(parts[0], label=parts[0]) 156 | g2.add_node(parts[2], label=parts[2]) 157 | g2.add_edge(parts[0], parts[2], label=parts[1]) 158 | 159 | ged = nx.graph_edit_distance(g1, g2, node_match=return_eq_node, edge_match=return_eq_edge) 160 | 161 | assert ged <= normalizing_constant 162 | 163 | return ged / normalizing_constant 164 | -------------------------------------------------------------------------------- /metrics/run_ea.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import os 4 | import sys 5 | from dataclasses import dataclass, field 6 | from typing import Callable, Dict, Optional 7 | 8 | import numpy as np 9 | import argparse 10 | 11 | from scipy.special import softmax 12 | 13 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction 14 | from transformers import GlueDataTrainingArguments as DataTrainingArguments 15 | from transformers import ( 16 | HfArgumentParser, 17 | Trainer, 18 | TrainingArguments, 19 | set_seed, 20 | ) 21 | 22 | from utils_ea import stance_output_modes, stance_num_labels, StanceDataset, compute_metrics 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | @dataclass 28 | class ModelArguments: 29 | """ 30 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 31 | """ 32 | 33 | model_name_or_path: str = field( 34 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 35 | ) 36 | config_name: Optional[str] = field( 37 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 38 | ) 39 | tokenizer_name: Optional[str] = field( 40 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 41 | ) 42 | cache_dir: Optional[str] = field( 43 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 44 | ) 45 | 46 | 47 | def main(): 48 | # See all possible arguments in src/transformers/training_args.py 49 | # or by passing the --help flag to this script. 50 | # We now keep distinct sets of args, for a cleaner separation of concerns. 51 | 52 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 53 | 54 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 55 | # If we pass only one argument to the script and it's the path to a json file, 56 | # let's parse it to get our arguments.
57 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 58 | else: 59 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 60 | 61 | if ( 62 | os.path.exists(training_args.output_dir) 63 | and os.listdir(training_args.output_dir) 64 | and training_args.do_train 65 | and not training_args.overwrite_output_dir 66 | ): 67 | raise ValueError( 68 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 69 | ) 70 | 71 | # Setup logging 72 | logging.basicConfig( 73 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 74 | datefmt="%m/%d/%Y %H:%M:%S", 75 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 76 | ) 77 | logger.warning( 78 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 79 | training_args.local_rank, 80 | training_args.device, 81 | training_args.n_gpu, 82 | bool(training_args.local_rank != -1), 83 | training_args.fp16, 84 | ) 85 | logger.info("Training/evaluation parameters %s", training_args) 86 | 87 | # Set seed 88 | set_seed(training_args.seed) 89 | 90 | try: 91 | num_labels = stance_num_labels[data_args.task_name] 92 | output_mode = stance_output_modes[data_args.task_name] 93 | except KeyError: 94 | raise ValueError("Task not found: %s" % (data_args.task_name)) 95 | 96 | 97 | 98 | 99 | # Load pretrained model and tokenizer 100 | # 101 | # Distributed training: 102 | # The .from_pretrained methods guarantee that only one local process can concurrently 103 | # download model & vocab. 104 | 105 | config = AutoConfig.from_pretrained( 106 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 107 | num_labels=num_labels, 108 | finetuning_task=data_args.task_name, 109 | cache_dir=model_args.cache_dir, 110 | ) 111 | tokenizer = AutoTokenizer.from_pretrained( 112 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 113 | cache_dir=model_args.cache_dir, 114 | ) 115 | model = AutoModelForSequenceClassification.from_pretrained( 116 | model_args.model_name_or_path, 117 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 118 | config=config, 119 | cache_dir=model_args.cache_dir, 120 | ) 121 | 122 | # Get datasets 123 | train_dataset = ( 124 | StanceDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None 125 | ) 126 | eval_dataset = ( 127 | StanceDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir) 128 | if training_args.do_eval 129 | else None 130 | ) 131 | test_dataset = ( 132 | StanceDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir) 133 | if training_args.do_predict 134 | else None 135 | ) 136 | 137 | def build_compute_metrics_fn(task_name: str, output_dir: str) -> Callable[[EvalPrediction], Dict]: 138 | def compute_metrics_fn(p: EvalPrediction): 139 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 140 | probs = softmax(preds, axis=1) 141 | with open(os.path.join(output_dir, "probs.txt"), "w", encoding="utf-8-sig") as f: 142 | f.write(str(eval_dataset.get_labels()) + "\n") 143 | for prob in probs: 144 | f.write(str(prob) + "\n") 145 | if output_mode == "classification": 146 | preds = np.argmax(preds, axis=1) 147 | else: # regression 148
| preds = np.squeeze(preds) 149 | return compute_metrics(task_name, output_dir, preds, p.label_ids) 150 | 151 | return compute_metrics_fn 152 | 153 | # Initialize our Trainer 154 | trainer = Trainer( 155 | model=model, 156 | args=training_args, 157 | train_dataset=train_dataset, 158 | eval_dataset=eval_dataset, 159 | compute_metrics=build_compute_metrics_fn(data_args.task_name, training_args.output_dir), 160 | ) 161 | 162 | # Training 163 | if training_args.do_train: 164 | trainer.train( 165 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 166 | ) 167 | trainer.save_model() 168 | # For convenience, we also re-save the tokenizer to the same directory, 169 | # so that you can share your model easily on huggingface.co/models =) 170 | if trainer.is_world_master(): 171 | tokenizer.save_pretrained(training_args.output_dir) 172 | 173 | # Evaluation 174 | eval_results = {} 175 | if training_args.do_eval: 176 | logger.info("*** Evaluate ***") 177 | 178 | eval_datasets = [eval_dataset] 179 | 180 | for eval_dataset in eval_datasets: 181 | trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name, training_args.output_dir) 182 | eval_result = trainer.evaluate(eval_dataset=eval_dataset) 183 | 184 | output_eval_file = os.path.join( 185 | training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" 186 | ) 187 | if trainer.is_world_master(): 188 | with open(output_eval_file, "w") as writer: 189 | logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) 190 | for key, value in eval_result.items(): 191 | logger.info(" %s = %s", key, value) 192 | writer.write("%s = %s\n" % (key, value)) 193 | 194 | eval_results.update(eval_result) 195 | 196 | if training_args.do_predict: 197 | logging.info("*** Test ***") 198 | test_datasets = [test_dataset] 199 | 200 | for test_dataset in test_datasets: 201 | predictions = trainer.predict(test_dataset=test_dataset).predictions 202 | if output_mode == "classification": 203 | predictions = np.argmax(predictions, axis=1) 204 | 205 | output_test_file = os.path.join( 206 | training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt" 207 | ) 208 | if trainer.is_world_master(): 209 | with open(output_test_file, "w") as writer: 210 | logger.info("***** Test results {} *****".format(test_dataset.args.task_name)) 211 | writer.write("index\tprediction\n") 212 | for index, item in enumerate(predictions): 213 | if output_mode == "regression": 214 | writer.write("%d\t%3.3f\n" % (index, item)) 215 | else: 216 | item = test_dataset.get_labels()[item] 217 | writer.write("%d\t%s\n" % (index, item)) 218 | return eval_results 219 | 220 | 221 | def _mp_fn(index): 222 | # For xla_spawn (TPUs) 223 | main() 224 | 225 | 226 | if __name__ == "__main__": 227 | main() -------------------------------------------------------------------------------- /metrics/run_seca.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import os 4 | import sys 5 | from dataclasses import dataclass, field 6 | from typing import Callable, Dict, Optional 7 | 8 | import numpy as np 9 | 10 | from scipy.special import softmax 11 | 12 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction 13 | from transformers import GlueDataTrainingArguments as DataTrainingArguments 14 | from transformers import ( 15 | HfArgumentParser, 16 | Trainer, 17 | TrainingArguments, 18 | set_seed, 19 
| ) 20 | 21 | from utils_seqa import stance_output_modes, stance_num_labels, StanceDataset, compute_metrics 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | @dataclass 27 | class ModelArguments: 28 | """ 29 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 30 | """ 31 | 32 | model_name_or_path: str = field( 33 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 34 | ) 35 | config_name: Optional[str] = field( 36 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 37 | ) 38 | tokenizer_name: Optional[str] = field( 39 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 40 | ) 41 | cache_dir: Optional[str] = field( 42 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 43 | ) 44 | 45 | 46 | def main(): 47 | # See all possible arguments in src/transformers/training_args.py 48 | # or by passing the --help flag to this script. 49 | # We now keep distinct sets of args, for a cleaner separation of concerns. 50 | 51 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 52 | 53 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 54 | # If we pass only one argument to the script and it's the path to a json file, 55 | # let's parse it to get our arguments. 56 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 57 | else: 58 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 59 | 60 | if ( 61 | os.path.exists(training_args.output_dir) 62 | and os.listdir(training_args.output_dir) 63 | and training_args.do_train 64 | and not training_args.overwrite_output_dir 65 | ): 66 | raise ValueError( 67 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 68 | ) 69 | 70 | # Setup logging 71 | logging.basicConfig( 72 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 73 | datefmt="%m/%d/%Y %H:%M:%S", 74 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 75 | ) 76 | logger.warning( 77 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 78 | training_args.local_rank, 79 | training_args.device, 80 | training_args.n_gpu, 81 | bool(training_args.local_rank != -1), 82 | training_args.fp16, 83 | ) 84 | logger.info("Training/evaluation parameters %s", training_args) 85 | 86 | # Set seed 87 | set_seed(training_args.seed) 88 | 89 | try: 90 | num_labels = stance_num_labels[data_args.task_name] 91 | output_mode = stance_output_modes[data_args.task_name] 92 | except KeyError: 93 | raise ValueError("Task not found: %s" % (data_args.task_name)) 94 | 95 | # Load pretrained model and tokenizer 96 | # 97 | # Distributed training: 98 | # The .from_pretrained methods guarantee that only one local process can concurrently 99 | # download model & vocab.
100 | 101 | config = AutoConfig.from_pretrained( 102 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 103 | num_labels=num_labels, 104 | finetuning_task=data_args.task_name, 105 | cache_dir=model_args.cache_dir, 106 | ) 107 | tokenizer = AutoTokenizer.from_pretrained( 108 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 109 | cache_dir=model_args.cache_dir, 110 | ) 111 | model = AutoModelForSequenceClassification.from_pretrained( 112 | model_args.model_name_or_path, 113 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 114 | config=config, 115 | cache_dir=model_args.cache_dir, 116 | ) 117 | 118 | # Get datasets 119 | train_dataset = ( 120 | StanceDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None 121 | ) 122 | eval_dataset = ( 123 | StanceDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir) 124 | if training_args.do_eval 125 | else None 126 | ) 127 | test_dataset = ( 128 | StanceDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir) 129 | if training_args.do_predict 130 | else None 131 | ) 132 | 133 | def build_compute_metrics_fn(data_dir: str, task_name: str, output_dir: str) -> Callable[[EvalPrediction], Dict]: 134 | def compute_metrics_fn(p: EvalPrediction): 135 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 136 | preds = np.argmax(preds, axis=1) 137 | return compute_metrics(data_dir, task_name, output_dir, preds, p.label_ids) 138 | 139 | return compute_metrics_fn 140 | 141 | # Initialize our Trainer 142 | trainer = Trainer( 143 | model=model, 144 | args=training_args, 145 | train_dataset=train_dataset, 146 | eval_dataset=eval_dataset, 147 | compute_metrics=build_compute_metrics_fn(data_args.data_dir, data_args.task_name, training_args.output_dir), 148 | ) 149 | 150 | # Training 151 | if training_args.do_train: 152 | trainer.train( 153 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 154 | ) 155 | trainer.save_model() 156 | # For convenience, we also re-save the tokenizer to the same directory, 157 | # so that you can share your model easily on huggingface.co/models =) 158 | if trainer.is_world_master(): 159 | tokenizer.save_pretrained(training_args.output_dir) 160 | 161 | # Evaluation 162 | eval_results = {} 163 | if training_args.do_eval: 164 | logger.info("*** Evaluate ***") 165 | 166 | eval_datasets = [eval_dataset] 167 | 168 | for eval_dataset in eval_datasets: 169 | trainer.compute_metrics = build_compute_metrics_fn(data_args.data_dir, eval_dataset.args.task_name, training_args.output_dir) 170 | eval_result = trainer.evaluate(eval_dataset=eval_dataset) 171 | 172 | output_eval_file = os.path.join( 173 | training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" 174 | ) 175 | if trainer.is_world_master(): 176 | with open(output_eval_file, "w") as writer: 177 | logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) 178 | for key, value in eval_result.items(): 179 | logger.info(" %s = %s", key, value) 180 | writer.write("%s = %s\n" % (key, value)) 181 | 182 | eval_results.update(eval_result) 183 | 184 | return eval_results 185 | 186 | 187 | def _mp_fn(index): 188 | # For xla_spawn (TPUs) 189 | main() 190 | 191 | 192 | if __name__ == "__main__": 193 | main() 194 | -------------------------------------------------------------------------------- 
/metrics/utils_ea.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import os 4 | from enum import Enum 5 | from filelock import FileLock 6 | from typing import List, Optional, Union 7 | import time 8 | 9 | from torch.utils.data.dataset import Dataset 10 | from dataclasses import dataclass, field 11 | from transformers.data.processors import DataProcessor 12 | from transformers.data.processors.utils import InputExample, InputFeatures 13 | from transformers.data.datasets import GlueDataTrainingArguments 14 | from transformers.tokenization_utils import PreTrainedTokenizer 15 | from transformers.data.processors.glue import glue_convert_examples_to_features 16 | from sklearn.metrics import f1_score 17 | import networkx as nx 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class StanceProcessor(DataProcessor): 23 | def get_example_from_tensor_dict(self, tensor_dict): 24 | """See base class.""" 25 | return InputExample( 26 | tensor_dict["idx"].numpy(), 27 | tensor_dict["premise"].numpy().decode("utf-8"), 28 | tensor_dict["hypothesis"].numpy().decode("utf-8"), 29 | str(tensor_dict["label"].numpy()), 30 | ) 31 | 32 | def get_train_examples(self, data_dir): 33 | """See base class.""" 34 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv"))) 35 | 36 | def get_dev_examples(self, data_dir): 37 | """See base class.""" 38 | return self._create_examples(self._read_tsv(os.path.join(data_dir))) 39 | 40 | def get_test_examples(self, data_dir): 41 | """See base class.""" 42 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv"))) 43 | 44 | def get_labels(self): 45 | """See base class.""" 46 | return ["support", "counter"] 47 | 48 | def _create_examples(self, lines): 49 | """Creates examples for the training, dev and test sets.""" 50 | examples = [] 51 | j = 0 52 | for (i, line) in enumerate(lines): 53 | id = i 54 | text_a = line[1] 55 | text_b = line[2] 56 | label = line[3] 57 | examples.append(InputExample(guid=id, text_a=text_a, text_b=text_b, label=label)) 58 | 59 | return examples 60 | 61 | 62 | stance_processor = { 63 | "stance": StanceProcessor 64 | } 65 | 66 | stance_output_modes = { 67 | "stance": "classification" 68 | } 69 | 70 | stance_num_labels = { 71 | "stance": 2 72 | } 73 | 74 | 75 | @dataclass 76 | class StanceDataTrainingArguments: 77 | """ 78 | Arguments pertaining to what data we are going to input our model for training and eval. 79 | 80 | Using `HfArgumentParser` we can turn this class 81 | into argparse arguments to be able to specify them on 82 | the command line. 83 | """ 84 | 85 | task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(stance_processor.keys())}) 86 | data_dir: str = field( 87 | metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."} 88 | ) 89 | max_seq_length: int = field( 90 | default=128, 91 | metadata={ 92 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 93 | "than this will be truncated, sequences shorter will be padded." 
94 | }, 95 | ) 96 | overwrite_cache: bool = field( 97 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 98 | ) 99 | 100 | def __post_init__(self): 101 | self.task_name = self.task_name.lower() 102 | 103 | 104 | class Split(Enum): 105 | train = "train" 106 | dev = "dev" 107 | test = "test" 108 | 109 | 110 | class StanceDataset(Dataset): 111 | """ 112 | This will be superseded by a framework-agnostic approach 113 | soon. 114 | """ 115 | 116 | args: GlueDataTrainingArguments 117 | output_mode: str 118 | features: List[InputFeatures] 119 | 120 | def __init__( 121 | self, 122 | args: GlueDataTrainingArguments, 123 | tokenizer: PreTrainedTokenizer, 124 | limit_length: Optional[int] = None, 125 | mode: Union[str, Split] = Split.train, 126 | cache_dir: Optional[str] = None 127 | ): 128 | self.args = args 129 | self.processor = stance_processor[args.task_name]() 130 | self.output_mode = stance_output_modes[args.task_name] 131 | if isinstance(mode, str): 132 | try: 133 | mode = Split[mode] 134 | except KeyError: 135 | raise KeyError("mode is not a valid split name") 136 | # Load data features from cache or dataset file 137 | cached_features_file = os.path.join( 138 | cache_dir if cache_dir is not None else args.data_dir, 139 | "cached_{}_{}_{}_{}".format( 140 | mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name, 141 | ), 142 | ) 143 | self.label_list = self.processor.get_labels() 144 | 145 | # Make sure only the first process in distributed training processes the dataset, 146 | # and the others will use the cache. 147 | lock_path = cached_features_file + ".lock" 148 | with FileLock(lock_path): 149 | 150 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 151 | start = time.time() 152 | self.features = torch.load(cached_features_file) 153 | logger.info( 154 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start 155 | ) 156 | else: 157 | logger.info(f"Creating features from dataset file at {args.data_dir}") 158 | 159 | if mode == Split.dev: 160 | examples = self.processor.get_dev_examples(args.data_dir) 161 | elif mode == Split.test: 162 | examples = self.processor.get_test_examples(args.data_dir) 163 | else: 164 | examples = self.processor.get_train_examples(args.data_dir) 165 | if limit_length is not None: 166 | examples = examples[:limit_length] 167 | self.features = glue_convert_examples_to_features( 168 | examples, 169 | tokenizer, 170 | max_length=args.max_seq_length, 171 | label_list=self.label_list, 172 | output_mode=self.output_mode, 173 | ) 174 | start = time.time() 175 | torch.save(self.features, cached_features_file) 176 | # ^ This seems to take a lot of time so I want to investigate why and how we can improve. 
177 | logger.info( 178 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start 179 | ) 180 | 181 | def __len__(self): 182 | return len(self.features) 183 | 184 | def __getitem__(self, i) -> InputFeatures: 185 | return self.features[i] 186 | 187 | def get_labels(self): 188 | return self.label_list 189 | 190 | 191 | def compute_metrics(task_name, output_dir, preds, labels): 192 | if task_name == "stance": 193 | return { 194 | "acc": (preds == labels).mean() 195 | } 196 | else: 197 | raise KeyError(task_name) 198 | -------------------------------------------------------------------------------- /metrics/utils_seqa.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import os 4 | from enum import Enum 5 | from filelock import FileLock 6 | from typing import List, Optional, Union 7 | import time 8 | 9 | from torch.utils.data.dataset import Dataset 10 | from dataclasses import dataclass, field 11 | from transformers.data.processors import DataProcessor 12 | from transformers.data.processors.utils import InputExample, InputFeatures 13 | from transformers.data.datasets import GlueDataTrainingArguments 14 | from transformers.tokenization_utils import PreTrainedTokenizer 15 | from transformers.data.processors.glue import glue_convert_examples_to_features 16 | from sklearn.metrics import f1_score 17 | import networkx as nx 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class StanceProcessor(DataProcessor): 23 | 24 | def get_dfs_ordering(self, graph_string): 25 | graph = nx.DiGraph() 26 | nodes = [] 27 | relations = {} 28 | for edge in graph_string[1:-1].split(")("): 29 | parts = edge.split("; ") 30 | graph.add_edge(parts[0], parts[2]) 31 | if parts[0] not in nodes: 32 | nodes.append(parts[0]) 33 | if parts[2] not in nodes: 34 | nodes.append(parts[2]) 35 | relations[(parts[0], parts[2])] = parts[1] 36 | 37 | in_degrees = list(graph.in_degree(nodes)) 38 | 39 | start_nodes = [] 40 | for (i, node) in enumerate(nodes): 41 | if in_degrees[i][1] == 0: 42 | start_nodes.append(in_degrees[i][0]) 43 | 44 | dfs_edges = list(nx.edge_dfs(graph, source=start_nodes)) 45 | 46 | new_graph_string = "" 47 | for edge in dfs_edges: 48 | new_graph_string += "(" + edge[0] + "; " + relations[(edge[0], edge[1])] + "; " + edge[1] + ")" 49 | 50 | return new_graph_string 51 | 52 | def create_leave_one_out_graphs(self, graph): 53 | leave_one_out_graphs = [] 54 | edges = graph[1:-1].split(")(") 55 | for edge in edges: 56 | leave_one_out_graph = graph.replace("(" + edge + ")", "") 57 | leave_one_out_graph = leave_one_out_graph.replace("(", "").replace(";", "").replace(")", ". 
") 58 | leave_one_out_graphs.append(leave_one_out_graph) 59 | 60 | return leave_one_out_graphs 61 | 62 | def get_train_examples(self, data_dir): 63 | """See base class.""" 64 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv"))) 65 | 66 | def get_dev_examples(self, data_dir): 67 | """See base class.""" 68 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "annotations.tsv"))) 69 | 70 | def get_test_examples(self, data_dir): 71 | """See base class.""" 72 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv"))) 73 | 74 | def get_test_size(self, data_dir): 75 | return len(self._read_tsv(os.path.join(data_dir, "annotations.tsv"))) 76 | 77 | def get_labels(self): 78 | """See base class.""" 79 | return ["support", "counter", "incorrect"] 80 | 81 | def _create_examples(self, lines): 82 | examples = [] 83 | j = 0 84 | for (i, line) in enumerate(lines): 85 | id = i 86 | if line[3] != "struct_correct": 87 | continue 88 | text_a = line[0] 89 | text_b = self.get_dfs_ordering(line[1]) 90 | label = line[2] 91 | examples.append(InputExample(guid=id, text_a=text_a, text_b=text_b, label=label)) 92 | return examples 93 | 94 | 95 | stance_processor = { 96 | "stance": StanceProcessor 97 | } 98 | 99 | stance_output_modes = { 100 | "stance": "classification" 101 | } 102 | 103 | stance_num_labels = { 104 | "stance": 3 105 | } 106 | 107 | 108 | @dataclass 109 | class StanceDataTrainingArguments: 110 | """ 111 | Arguments pertaining to what data we are going to input our model for training and eval. 112 | 113 | Using `HfArgumentParser` we can turn this class 114 | into argparse arguments to be able to specify them on 115 | the command line. 116 | """ 117 | 118 | task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(stance_processor.keys())}) 119 | data_dir: str = field( 120 | metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."} 121 | ) 122 | max_seq_length: int = field( 123 | default=128, 124 | metadata={ 125 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 126 | "than this will be truncated, sequences shorter will be padded." 127 | }, 128 | ) 129 | overwrite_cache: bool = field( 130 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 131 | ) 132 | 133 | def __post_init__(self): 134 | self.task_name = self.task_name.lower() 135 | 136 | 137 | class Split(Enum): 138 | train = "train" 139 | dev = "dev" 140 | test = "test" 141 | 142 | 143 | class StanceDataset(Dataset): 144 | """ 145 | This will be superseded by a framework-agnostic approach 146 | soon. 
147 | """ 148 | 149 | args: GlueDataTrainingArguments 150 | output_mode: str 151 | features: List[InputFeatures] 152 | 153 | def __init__( 154 | self, 155 | args: GlueDataTrainingArguments, 156 | tokenizer: PreTrainedTokenizer, 157 | limit_length: Optional[int] = None, 158 | mode: Union[str, Split] = Split.train, 159 | cache_dir: Optional[str] = None, 160 | ): 161 | self.args = args 162 | self.processor = stance_processor[args.task_name]() 163 | self.output_mode = stance_output_modes[args.task_name] 164 | if isinstance(mode, str): 165 | try: 166 | mode = Split[mode] 167 | except KeyError: 168 | raise KeyError("mode is not a valid split name") 169 | # Load data features from cache or dataset file 170 | cached_features_file = os.path.join( 171 | cache_dir if cache_dir is not None else args.data_dir, 172 | "cached_{}_{}_{}_{}".format( 173 | mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name, 174 | ), 175 | ) 176 | self.label_list = self.processor.get_labels() 177 | 178 | # Make sure only the first process in distributed training processes the dataset, 179 | # and the others will use the cache. 180 | lock_path = cached_features_file + ".lock" 181 | with FileLock(lock_path): 182 | 183 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 184 | start = time.time() 185 | self.features = torch.load(cached_features_file) 186 | logger.info( 187 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start 188 | ) 189 | else: 190 | logger.info(f"Creating features from dataset file at {args.data_dir}") 191 | 192 | if mode == Split.dev: 193 | examples = self.processor.get_dev_examples(args.data_dir) 194 | elif mode == Split.test: 195 | examples = self.processor.get_test_examples(args.data_dir) 196 | else: 197 | examples = self.processor.get_train_examples(args.data_dir) 198 | if limit_length is not None: 199 | examples = examples[:limit_length] 200 | self.features = glue_convert_examples_to_features( 201 | examples, 202 | tokenizer, 203 | max_length=args.max_seq_length, 204 | label_list=self.label_list, 205 | output_mode=self.output_mode, 206 | ) 207 | start = time.time() 208 | torch.save(self.features, cached_features_file) 209 | logger.info( 210 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start 211 | ) 212 | 213 | def __len__(self): 214 | return len(self.features) 215 | 216 | def __getitem__(self, i) -> InputFeatures: 217 | return self.features[i] 218 | 219 | def get_labels(self): 220 | return self.label_list 221 | 222 | 223 | 224 | def compute_metrics(data_dir, task_name, output_dir, preds, labels): 225 | if task_name == "stance": 226 | return { 227 | "SeCA": (preds == labels).sum()/stance_processor[task_name]().get_test_size(data_dir) 228 | } 229 | else: 230 | raise KeyError(task_name) 231 | -------------------------------------------------------------------------------- /model_scripts/test_stance_pred.sh: -------------------------------------------------------------------------------- 1 | python src/run_pl_stance_pred.py \ 2 | --model_name_or_path ./models/roberta-large-stance \ 3 | --task_name stance \ 4 | --do_eval \ 5 | --save_steps 10000 \ 6 | --data_dir ./data \ 7 | --max_seq_length 128 \ 8 | --per_device_train_batch_size 32 \ 9 | --learning_rate 1e-5 \ 10 | --num_train_epochs 10.0 \ 11 | --output_dir ./models/roberta-large-stance \ 12 | --cache_dir ./tmp \ 13 | --logging_steps 500 \ 14 | --evaluation_strategy="epoch" 
-------------------------------------------------------------------------------- /model_scripts/test_structured_model.sh: -------------------------------------------------------------------------------- 1 | python structured_model/run_joint_model.py \ 2 | --model_type roberta_eg \ 3 | --model_name_or_path ./models/sp_model \ 4 | --task_name eg \ 5 | --do_eval \ 6 | --do_eval_edge \ 7 | --do_lower_case \ 8 | --data_dir ./data \ 9 | --max_seq_length 128 \ 10 | --per_gpu_eval_batch_size 8 \ 11 | --per_gpu_train_batch_size 8 \ 12 | --learning_rate 1e-5 \ 13 | --num_train_epochs 10 \ 14 | --output_dir ./models/sp_model \ 15 | --seed 42 \ 16 | --data_cache_dir ./tmp \ 17 | --cache_dir ./tmp \ 18 | --evaluate_during_training -------------------------------------------------------------------------------- /model_scripts/train_conceptnet_finetuning.sh: -------------------------------------------------------------------------------- 1 | python structured_model/run_commonsense_finetuning.py \ 2 | --model_type roberta_relation \ 3 | --model_name_or_path roberta-large \ 4 | --task_name relation \ 5 | --do_train \ 6 | --do_eval \ 7 | --do_lower_case \ 8 | --data_dir ./conceptnet_data \ 9 | --max_seq_length 30 \ 10 | --per_gpu_eval_batch_size 32 \ 11 | --per_gpu_train_batch_size 32 \ 12 | --learning_rate 1e-5 \ 13 | --num_train_epochs 5 \ 14 | --output_dir ./models/relation_model \ 15 | --seed 42 \ 16 | --data_cache_dir ./tmp \ 17 | --cache_dir ./tmp \ 18 | --evaluate_during_training -------------------------------------------------------------------------------- /model_scripts/train_stance_pred.sh: -------------------------------------------------------------------------------- 1 | python src/run_pl_stance_pred.py \ 2 | --model_name_or_path roberta-large \ 3 | --task_name stance \ 4 | --do_train \ 5 | --do_eval \ 6 | --save_steps 10000 \ 7 | --data_dir ./data \ 8 | --max_seq_length 128 \ 9 | --per_device_train_batch_size 32 \ 10 | --learning_rate 1e-5 \ 11 | --num_train_epochs 10.0 \ 12 | --output_dir ./models/roberta-large-stance \ 13 | --cache_dir ./tmp \ 14 | --logging_steps 500 \ 15 | --evaluation_strategy="epoch" -------------------------------------------------------------------------------- /model_scripts/train_structured_model.sh: -------------------------------------------------------------------------------- 1 | python structured_model/run_joint_model.py \ 2 | --model_type roberta_eg \ 3 | --model_name_or_path ./models/relation_model \ 4 | --task_name eg \ 5 | --do_train \ 6 | --do_eval \ 7 | --do_lower_case \ 8 | --data_dir ./data \ 9 | --max_seq_length 128 \ 10 | --per_gpu_eval_batch_size 8 \ 11 | --per_gpu_train_batch_size 8 \ 12 | --learning_rate 1e-5 \ 13 | --num_train_epochs 10 \ 14 | --output_dir ./models/sp_model \ 15 | --seed 42 \ 16 | --data_cache_dir ./tmp \ 17 | --cache_dir ./tmp \ 18 | --evaluate_during_training -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | All trained models come here. 
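
If you run the training scripts with their default ```--output_dir``` values, this folder should end up with a layout like the following (directory names taken from the corresponding flags in ```model_scripts/```):

```
models/
├── roberta-large-stance    # from train_stance_pred.sh
├── relation_model          # from train_conceptnet_finetuning.sh
└── sp_model                # from train_structured_model.sh
```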
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | filelock==3.0.12 2 | bert_score==0.3.6 3 | scipy==1.4.1 4 | spacy==2.1.9 5 | nltk==3.4.5 6 | numpy==1.18.5 7 | rouge_score==0.0.3 8 | sacrebleu==1.4.3 9 | transformers==3.4.0 10 | networkx==2.4 11 | torch==1.9.0 12 | pytorch_lightning==0.8.5 13 | dataclasses==0.8 14 | fairseq==0.10.2 15 | GitPython==3.1.24 16 | scikit_learn==1.0 17 | -------------------------------------------------------------------------------- /src/run_pl_stance_pred.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import os 4 | import sys 5 | from dataclasses import dataclass, field 6 | from typing import Callable, Dict, Optional 7 | 8 | import numpy as np 9 | 10 | from scipy.special import softmax 11 | 12 | from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction 13 | from transformers import GlueDataTrainingArguments as DataTrainingArguments 14 | from transformers import ( 15 | HfArgumentParser, 16 | Trainer, 17 | TrainingArguments, 18 | set_seed, 19 | ) 20 | 21 | from utils_stance_pred import ( 22 | stance_output_modes, 23 | stance_num_labels, 24 | StanceDataset, 25 | compute_metrics 26 | ) 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | 31 | @dataclass 32 | class ModelArguments: 33 | """ 34 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 35 | """ 36 | 37 | model_name_or_path: str = field( 38 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 39 | ) 40 | config_name: Optional[str] = field( 41 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 42 | ) 43 | tokenizer_name: Optional[str] = field( 44 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 45 | ) 46 | cache_dir: Optional[str] = field( 47 | default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} 48 | ) 49 | 50 | 51 | def main(): 52 | # See all possible arguments in src/transformers/training_args.py 53 | # or by passing the --help flag to this script. 54 | # We now keep distinct sets of args, for a cleaner separation of concerns. 55 | 56 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) 57 | 58 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 59 | # If we pass only one argument to the script and it's the path to a json file, 60 | # let's parse it to get our arguments. 61 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 62 | else: 63 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 64 | 65 | if ( 66 | os.path.exists(training_args.output_dir) 67 | and os.listdir(training_args.output_dir) 68 | and training_args.do_train 69 | and not training_args.overwrite_output_dir 70 | ): 71 | raise ValueError( 72 | f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." 
73 | ) 74 | 75 | # Setup logging 76 | logging.basicConfig( 77 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 78 | datefmt="%m/%d/%Y %H:%M:%S", 79 | level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN, 80 | ) 81 | logger.warning( 82 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 83 | training_args.local_rank, 84 | training_args.device, 85 | training_args.n_gpu, 86 | bool(training_args.local_rank != -1), 87 | training_args.fp16, 88 | ) 89 | logger.info("Training/evaluation parameters %s", training_args) 90 | 91 | # Set seed 92 | set_seed(training_args.seed) 93 | 94 | try: 95 | num_labels = stance_num_labels[data_args.task_name] 96 | output_mode = stance_output_modes[data_args.task_name] 97 | except KeyError: 98 | raise ValueError("Task not found: %s" % (data_args.task_name)) 99 | 100 | # Load pretrained model and tokenizer 101 | # 102 | # Distributed training: 103 | # The .from_pretrained methods guarantee that only one local process can concurrently 104 | # download model & vocab. 105 | 106 | config = AutoConfig.from_pretrained( 107 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 108 | num_labels=num_labels, 109 | finetuning_task=data_args.task_name, 110 | cache_dir=model_args.cache_dir, 111 | ) 112 | tokenizer = AutoTokenizer.from_pretrained( 113 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 114 | cache_dir=model_args.cache_dir, 115 | ) 116 | model = AutoModelForSequenceClassification.from_pretrained( 117 | model_args.model_name_or_path, 118 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 119 | config=config, 120 | cache_dir=model_args.cache_dir, 121 | ) 122 | 123 | # Get datasets 124 | train_dataset = ( 125 | StanceDataset(data_args, tokenizer=tokenizer, cache_dir=model_args.cache_dir) if training_args.do_train else None 126 | ) 127 | eval_dataset = ( 128 | StanceDataset(data_args, tokenizer=tokenizer, mode="dev", cache_dir=model_args.cache_dir) 129 | if training_args.do_eval 130 | else None 131 | ) 132 | test_dataset = ( 133 | StanceDataset(data_args, tokenizer=tokenizer, mode="test", cache_dir=model_args.cache_dir) 134 | if training_args.do_predict 135 | else None 136 | ) 137 | 138 | def build_compute_metrics_fn(task_name: str, output_dir: str) -> Callable[[EvalPrediction], Dict]: 139 | def compute_metrics_fn(p: EvalPrediction): 140 | preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions 141 | probs = softmax(preds, axis=1) 142 | with open(os.path.join(output_dir, "probs.txt"), "w", encoding="utf-8-sig") as f: 143 | f.write(str(eval_dataset.get_labels()) + "\n") 144 | for prob in probs: 145 | f.write(str(prob) + "\n") 146 | if output_mode == "classification": 147 | preds = np.argmax(preds, axis=1) 148 | else: # regression 149 | preds = np.squeeze(preds) 150 | return compute_metrics(task_name, output_dir, preds, p.label_ids) 151 | 152 | return compute_metrics_fn 153 | 154 | # Initialize our Trainer 155 | trainer = Trainer( 156 | model=model, 157 | args=training_args, 158 | train_dataset=train_dataset, 159 | eval_dataset=eval_dataset, 160 | compute_metrics=build_compute_metrics_fn(data_args.task_name, training_args.output_dir), 161 | ) 162 | 163 | # Training 164 | if training_args.do_train: 165 | trainer.train( 166 | model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None 167 | ) 168 | trainer.save_model() 169 | # For convenience, we also 
re-save the tokenizer to the same directory, 170 | # so that you can share your model easily on huggingface.co/models =) 171 | if trainer.is_world_master(): 172 | tokenizer.save_pretrained(training_args.output_dir) 173 | 174 | # Evaluation 175 | eval_results = {} 176 | if training_args.do_eval: 177 | logger.info("*** Evaluate ***") 178 | 179 | eval_datasets = [eval_dataset] 180 | 181 | for eval_dataset in eval_datasets: 182 | trainer.compute_metrics = build_compute_metrics_fn(eval_dataset.args.task_name, training_args.output_dir) 183 | 184 | eval_result = trainer.evaluate(eval_dataset=eval_dataset) 185 | 186 | output_eval_file = os.path.join( 187 | training_args.output_dir, f"eval_results_{eval_dataset.args.task_name}.txt" 188 | ) 189 | if trainer.is_world_master(): 190 | with open(output_eval_file, "w") as writer: 191 | logger.info("***** Eval results {} *****".format(eval_dataset.args.task_name)) 192 | for key, value in eval_result.items(): 193 | logger.info(" %s = %s", key, value) 194 | writer.write("%s = %s\n" % (key, value)) 195 | 196 | eval_results.update(eval_result) 197 | 198 | if training_args.do_predict: 199 | logging.info("*** Test ***") 200 | test_datasets = [test_dataset] 201 | 202 | for test_dataset in test_datasets: 203 | predictions = trainer.predict(test_dataset=test_dataset).predictions 204 | if output_mode == "classification": 205 | predictions = np.argmax(predictions, axis=1) 206 | 207 | output_test_file = os.path.join( 208 | training_args.output_dir, f"test_results_{test_dataset.args.task_name}.txt" 209 | ) 210 | if trainer.is_world_master(): 211 | with open(output_test_file, "w") as writer: 212 | logger.info("***** Test results {} *****".format(test_dataset.args.task_name)) 213 | writer.write("index\tprediction\n") 214 | for index, item in enumerate(predictions): 215 | if output_mode == "regression": 216 | writer.write("%d\t%3.3f\n" % (index, item)) 217 | else: 218 | item = test_dataset.get_labels()[item] 219 | writer.write("%d\t%s\n" % (index, item)) 220 | return eval_results 221 | 222 | 223 | def _mp_fn(index): 224 | # For xla_spawn (TPUs) 225 | main() 226 | 227 | 228 | if __name__ == "__main__": 229 | main() -------------------------------------------------------------------------------- /src/utils_stance_pred.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | import os 4 | from enum import Enum 5 | from filelock import FileLock 6 | from typing import List, Optional, Union 7 | import time 8 | 9 | from torch.utils.data.dataset import Dataset 10 | from dataclasses import dataclass, field 11 | from transformers.data.processors import DataProcessor 12 | from transformers.data.processors.utils import InputExample, InputFeatures 13 | from transformers.data.datasets import GlueDataTrainingArguments 14 | from transformers.tokenization_utils import PreTrainedTokenizer 15 | from transformers.data.processors.glue import glue_convert_examples_to_features 16 | from sklearn.metrics import f1_score 17 | import networkx as nx 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class StanceProcessor(DataProcessor): 23 | """Processor for the Stance Prediction Task.""" 24 | 25 | def get_example_from_tensor_dict(self, tensor_dict): 26 | """See base class.""" 27 | return InputExample( 28 | tensor_dict["idx"].numpy(), 29 | tensor_dict["premise"].numpy().decode("utf-8"), 30 | tensor_dict["hypothesis"].numpy().decode("utf-8"), 31 | str(tensor_dict["label"].numpy()), 32 | ) 33
| 34 | def get_train_examples(self, data_dir): 35 | """See base class.""" 36 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv"))) 37 | 38 | def get_dev_examples(self, data_dir): 39 | """See base class.""" 40 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv"))) 41 | 42 | def get_test_examples(self, data_dir): 43 | """See base class.""" 44 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv"))) 45 | 46 | def get_labels(self): 47 | """See base class.""" 48 | return ["support", "counter"] 49 | 50 | def _create_examples(self, lines): 51 | """Creates examples for the training, dev and test sets.""" 52 | examples = [] 53 | j = 0 54 | for (i, line) in enumerate(lines): 55 | id = i 56 | text_a = line[0] 57 | text_b = line[1] 58 | label = line[2] 59 | examples.append(InputExample(guid=id, text_a=text_a, text_b=text_b, label=label)) 60 | return examples 61 | 62 | 63 | stance_processor = { 64 | "stance": StanceProcessor 65 | } 66 | 67 | stance_output_modes = { 68 | "stance": "classification" 69 | } 70 | 71 | stance_num_labels = { 72 | "stance": 2 73 | } 74 | 75 | 76 | @dataclass 77 | class StanceDataTrainingArguments: 78 | """ 79 | Arguments pertaining to what data we are going to input our model for training and eval. 80 | 81 | Using `HfArgumentParser` we can turn this class 82 | into argparse arguments to be able to specify them on 83 | the command line. 84 | """ 85 | 86 | task_name: str = field(metadata={"help": "The name of the task to train on: " + ", ".join(stance_processor.keys())}) 87 | data_dir: str = field( 88 | metadata={"help": "The input data dir. Should contain the .tsv files (or other data files) for the task."} 89 | ) 90 | max_seq_length: int = field( 91 | default=128, 92 | metadata={ 93 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 94 | "than this will be truncated, sequences shorter will be padded." 95 | }, 96 | ) 97 | overwrite_cache: bool = field( 98 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 99 | ) 100 | 101 | def __post_init__(self): 102 | self.task_name = self.task_name.lower() 103 | 104 | 105 | class Split(Enum): 106 | train = "train" 107 | dev = "dev" 108 | test = "test" 109 | 110 | 111 | class StanceDataset(Dataset): 112 | """ 113 | This will be superseded by a framework-agnostic approach 114 | soon. 
115 | """ 116 | 117 | args: GlueDataTrainingArguments 118 | output_mode: str 119 | features: List[InputFeatures] 120 | 121 | def __init__( 122 | self, 123 | args: GlueDataTrainingArguments, 124 | tokenizer: PreTrainedTokenizer, 125 | limit_length: Optional[int] = None, 126 | mode: Union[str, Split] = Split.train, 127 | cache_dir: Optional[str] = None, 128 | ): 129 | self.args = args 130 | self.processor = stance_processor[args.task_name]() 131 | self.output_mode = stance_output_modes[args.task_name] 132 | if isinstance(mode, str): 133 | try: 134 | mode = Split[mode] 135 | except KeyError: 136 | raise KeyError("mode is not a valid split name") 137 | # Load data features from cache or dataset file 138 | cached_features_file = os.path.join( 139 | cache_dir if cache_dir is not None else args.data_dir, 140 | "cached_{}_{}_{}_{}".format( 141 | mode.value, tokenizer.__class__.__name__, str(args.max_seq_length), args.task_name, 142 | ), 143 | ) 144 | self.label_list = self.processor.get_labels() 145 | 146 | # Make sure only the first process in distributed training processes the dataset, 147 | # and the others will use the cache. 148 | lock_path = cached_features_file + ".lock" 149 | with FileLock(lock_path): 150 | 151 | if os.path.exists(cached_features_file) and not args.overwrite_cache: 152 | start = time.time() 153 | self.features = torch.load(cached_features_file) 154 | logger.info( 155 | f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start 156 | ) 157 | else: 158 | logger.info(f"Creating features from dataset file at {args.data_dir}") 159 | 160 | if mode == Split.dev: 161 | examples = self.processor.get_dev_examples(args.data_dir) 162 | elif mode == Split.test: 163 | examples = self.processor.get_test_examples(args.data_dir) 164 | else: 165 | examples = self.processor.get_train_examples(args.data_dir) 166 | if limit_length is not None: 167 | examples = examples[:limit_length] 168 | self.features = glue_convert_examples_to_features( 169 | examples, 170 | tokenizer, 171 | max_length=args.max_seq_length, 172 | label_list=self.label_list, 173 | output_mode=self.output_mode, 174 | ) 175 | start = time.time() 176 | torch.save(self.features, cached_features_file) 177 | # ^ This seems to take a lot of time so I want to investigate why and how we can improve. 
178 | logger.info( 179 | "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start 180 | ) 181 | 182 | def __len__(self): 183 | return len(self.features) 184 | 185 | def __getitem__(self, i) -> InputFeatures: 186 | return self.features[i] 187 | 188 | def get_labels(self): 189 | return self.label_list 190 | 191 | 192 | def compute_metrics(task_name, output_dir, preds, labels): 193 | if task_name == "stance": 194 | with open(os.path.join(output_dir, "gold.txt"), "w", encoding="utf-8-sig") as f: 195 | for label in labels: 196 | f.write(str(label) + "\n") 197 | with open(os.path.join(output_dir, "pred.txt"), "w", encoding="utf-8-sig") as f: 198 | for pred in preds: 199 | f.write(str(pred) + "\n") 200 | return { 201 | "micro_f1": f1_score(labels, preds, average="micro"), 202 | "macro_f1": f1_score(labels, preds, average="macro"), 203 | "weighted_f1": f1_score(labels, preds, average="weighted"), 204 | "acc": (preds == labels).mean() 205 | } 206 | else: 207 | raise KeyError(task_name) 208 | -------------------------------------------------------------------------------- /structured_model/README.md: -------------------------------------------------------------------------------- 1 | ## Commonsense-augmented Structured Prediction Model 2 | 3 | The structured model for graph generation has multiple modules. Follow the steps below to train and test the model. Note that all these scripts should be executed from the root folder. 4 | 5 | ### Step 1: 6 | 7 | We'll first compute the embeddings of the relations using a pre-trained LM (RoBERTa) and save them inside the ```data``` folder. Alternatively, you can find these embeddings in ```data/relations.pt```. 8 | 9 | ``` 10 | python structured_model/save_relation_embeddings.py 11 | ``` 12 | 13 | ### Step 2: 14 | 15 | To leverage commonsense knowledge from ConceptNet, we next fine-tune RoBERTa on ConceptNet. You can do so using the script below; the training data can be found [here](https://drive.google.com/drive/folders/19faqrwXLM5EySeB4yQ3JRPzGsi68DF07?usp=sharing). Alternatively, directly download our pre-trained model [here](https://drive.google.com/drive/folders/14CnyJUQX8Z2rubwofDGvTLnh_3bLsjml?usp=sharing). 16 | 17 | ``` 18 | bash model_scripts/train_conceptnet_finetuning.sh 19 | ``` 20 | 21 | ### Step 3: 22 | 23 | Next, we'll use the previously fine-tuned model to train our graph generation model using the script below. A couple of things to keep in mind: (1) This model has a component that predicts the external nodes first, which we obtain using BART. These are comma-separated concepts, as provided in ```data/external_nodes_dev.txt```. (2) Once you have trained the model, it will save the internal node predictions, as provided in ```data/internal_nodes_dev.txt```. These are in BIO format, where each stretch from B-N through I-N denotes a node. 24 | 25 | ``` 26 | bash model_scripts/train_structured_model.sh 27 | ``` 28 | 29 | ### Step 4: 30 | 31 | You can directly download our trained model [here](https://drive.google.com/drive/folders/1fD0BqkigLdxXfR_tLrMnB7CTewsGx_HL?usp=sharing) and test it to generate the final graphs. Note that graph generation uses three predictions -- (1) internal node predictions, (2) external node predictions, and (3) the edge logits. All of these come together in an ILP to generate the final graphs. The script below handles all of this and will generate the graphs in ```prediction_edges_dev.lst``` inside the model folder.
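
For reference, every explanation graph is serialized as a sequence of ```(concept; relation; concept)``` edges, which is the format the metric scripts in ```metrics/``` parse. As a minimal sketch (run from the repo root with the requirements installed; the graphs below are made-up examples, not taken from the dataset), the normalized Graph Edit Distance from ```metrics/graph_matching.py``` can be computed for a single prediction like this:

```
import sys
sys.path.append("metrics")

from graph_matching import get_ged

# Hypothetical gold and predicted graphs in the (concept; relation; concept) edge format.
gold = "(school uniforms; capable of; hiding income differences)(hiding income differences; is a; good thing)"
pred = "(school uniforms; causes; uniformity)(uniformity; is a; good thing)"

# get_ged returns the edit distance normalized by the worst-case upper bound described in
# graph_matching.py, so 0 means an exact match and values near 1 mean maximally different graphs.
print(get_ged(gold, pred))
```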
32 | 33 | Once you download our pre-trained model, you'll also find our generated graphs, so you can use them to directly obtain the metrics. 34 | 35 | ``` 36 | bash model_scripts/test_structured_model.sh 37 | ``` 38 | -------------------------------------------------------------------------------- /structured_model/inference.py: -------------------------------------------------------------------------------- 1 | from pulp import * 2 | import numpy as np 3 | 4 | def merge_nodes(no_edge_prob, max_edge_prob, max_edge_index, ordered_nodes): 5 | new_ordered_nodes = [] 6 | indices_to_remove = [] 7 | for (i, ordered_node) in enumerate(ordered_nodes): 8 | if ordered_node not in new_ordered_nodes: 9 | new_ordered_nodes.append(ordered_node) 10 | else: 11 | indices_to_remove.append(i) 12 | 13 | assert len(new_ordered_nodes) <= 8 14 | 15 | if len(indices_to_remove) == 0: 16 | return no_edge_prob, max_edge_prob, max_edge_index, new_ordered_nodes 17 | 18 | new_no_edge_prob = np.delete(no_edge_prob, indices_to_remove, axis=0) 19 | new_no_edge_prob = np.delete(new_no_edge_prob, indices_to_remove, axis=1) 20 | 21 | new_max_edge_prob = np.delete(max_edge_prob, indices_to_remove, axis=0) 22 | new_max_edge_prob = np.delete(new_max_edge_prob, indices_to_remove, axis=1) 23 | 24 | new_max_edge_index = np.delete(max_edge_index, indices_to_remove, axis=0) 25 | new_max_edge_index = np.delete(new_max_edge_index, indices_to_remove, axis=1) 26 | 27 | return new_no_edge_prob, new_max_edge_prob, new_max_edge_index, new_ordered_nodes 28 | 29 | def solve_LP_no_connectivity(no_edge_prob, max_edge_prob, max_edge_index, ordered_nodes, edge_label_list): 30 | no_edge_prob, max_edge_prob, max_edge_index, ordered_nodes = merge_nodes(no_edge_prob, max_edge_prob, 31 | max_edge_index, ordered_nodes) 32 | prob = LpProblem("Node edge consistency ", LpMaximize) 33 | all_vars = {} 34 | 35 | # Optimization Problem 36 | opt_prob = None 37 | for i in range(len(no_edge_prob)): 38 | for j in range(len(no_edge_prob)): 39 | var0 = LpVariable("Edge_" + str(i + 1) + "_" + str(j + 1) + "_0", 0, 1, LpInteger) 40 | var1 = LpVariable("Edge_" + str(i + 1) + "_" + str(j + 1) + "_1", 0, 1, LpInteger) 41 | 42 | all_vars[(i, j, 0)] = var0 43 | all_vars[(i, j, 1)] = var1 44 | 45 | if opt_prob is None: 46 | opt_prob = no_edge_prob[i][j] * all_vars[(i, j, 0)] + max_edge_prob[i][j] * all_vars[(i, j, 1)] 47 | else: 48 | opt_prob += no_edge_prob[i][j] * all_vars[(i, j, 0)] + max_edge_prob[i][j] * all_vars[(i, j, 1)] 49 | 50 | prob += opt_prob, "Maximum Score" 51 | 52 | # An edge is either present or absent 53 | for i in range(len(no_edge_prob)): 54 | for j in range(len(no_edge_prob)): 55 | prob += all_vars[(i, j, 0)] + all_vars[(i, j, 1)] == 1, "Exist condition" + str(i) + "_" + str(j) 56 | 57 | prob.solve() 58 | 59 | edges = [] 60 | edges_dict = {} 61 | for v in prob.variables(): 62 | if v.varValue > 0 and v.name.endswith("1") and v.name.startswith("Edge"): 63 | name = v.name.split("_") 64 | n_i = int(name[1]) - 1 65 | n_j = int(name[2]) - 1 66 | assert ordered_nodes[n_i] != ordered_nodes[n_j] 67 | assert (ordered_nodes[n_i], ordered_nodes[n_j]) not in edges_dict 68 | edge = "(" + ordered_nodes[n_i] + "; " + edge_label_list[max_edge_index[n_i][n_j]] + "; " + ordered_nodes[ 69 | n_j] + ")" 70 | edges.append(edge) 71 | edges_dict[(ordered_nodes[n_i], ordered_nodes[n_j])] = True 72 | print("Max score = ", value(prob.objective)) 73 | 74 | return edges 75 | 76 | 77 | def solve_LP(no_edge_prob, max_edge_prob, max_edge_index, ordered_nodes, edge_label_list): 78 | 
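    """Pick, for every ordered node pair, either "no edge" or its best-scoring
    relation edge so as to maximize the summed probabilities, exactly as in
    solve_LP_no_connectivity above, but additionally require the selected edges
    to form a connected graph. Connectivity is encoded by the Flow_* variables
    below: a source pushes len(ordered_nodes) units of flow into one (arbitrary)
    node, every node must pass exactly one unit on to a sink, and flow may only
    traverse a pair (i, j) if an edge is selected in at least one direction
    between them.
    """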
no_edge_prob, max_edge_prob, max_edge_index, ordered_nodes = merge_nodes(no_edge_prob, max_edge_prob, max_edge_index, ordered_nodes)
79 | prob = LpProblem("Node edge consistency ", LpMaximize)
80 | all_vars = {}
81 | 
82 | all_flow_vars = {}
83 | 
84 | source_id = -1
85 | sink_id = -2
86 | 
87 | print(ordered_nodes)
88 | 
89 | node_ids_present = [i for i in range(len(ordered_nodes))]
90 | 
91 | # add flow from source to one node present
92 | # arbitrarily choosing that node to be the last node
93 | # 1000 stands in for infinity
94 | all_flow_vars[(source_id, node_ids_present[-1])] = \
95 | LpVariable("Flow_source_" + str(node_ids_present[-1] + 1), 0, 1000, LpInteger)
96 | 
97 | # add flow from all nodes present to sink
98 | for i in range(len(node_ids_present)):
99 | temp = node_ids_present[i]
100 | all_flow_vars[(temp, sink_id)] = LpVariable("Flow_" + str(temp + 1) + "_sink", 0, 1000, LpInteger)
101 | 
102 | # define capacities
103 | C = {}
104 | # capacity from source to 1st node is number of nodes in graph
105 | C[(source_id, node_ids_present[-1])] = len(node_ids_present)
106 | C[(node_ids_present[-1], source_id)] = 0
107 | 
108 | # capacity from nodes in graph to sink is 1
109 | for i in range(len(node_ids_present)):
110 | temp = node_ids_present[i]
111 | C[(temp, sink_id)] = 1
112 | C[(sink_id, temp)] = 0
113 | 
114 | # capacities inside the graph are effectively infinite (1000 here), except for self-loops and edges that are not possible
115 | arcs = set()
116 | for i in range(len(no_edge_prob)):
117 | for j in range(len(no_edge_prob)):
118 | if (i == j) or (i not in node_ids_present) or (j not in node_ids_present):
119 | C[(i, j)] = 0
120 | else:
121 | C[(i, j)] = 1000
122 | arcs.add((i, j))
123 | arcs.add((j, i))
124 | arcs = list(arcs)
125 | 
126 | # Optimization Problem
127 | opt_prob = None
128 | for i in range(len(no_edge_prob)):
129 | for j in range(len(no_edge_prob)):
130 | if i == j:
131 | continue
132 | var0 = LpVariable("Edge_" + str(i + 1) + "_" + str(j + 1) + "_0", 0, 1, LpInteger)
133 | var1 = LpVariable("Edge_" + str(i + 1) + "_" + str(j + 1) + "_1", 0, 1, LpInteger)
134 | 
135 | all_vars[(i, j, 0)] = var0
136 | all_vars[(i, j, 1)] = var1
137 | 
138 | f_var = LpVariable("Flow_" + str(i + 1) + "_" + str(j + 1), 0, 1000, LpInteger)
139 | all_flow_vars[(i, j)] = f_var
140 | 
141 | if opt_prob is None:
142 | opt_prob = no_edge_prob[i][j] * all_vars[(i, j, 0)] + max_edge_prob[i][j] * all_vars[(i, j, 1)]
143 | else:
144 | opt_prob += no_edge_prob[i][j] * all_vars[(i, j, 0)] + max_edge_prob[i][j] * all_vars[(i, j, 1)]
145 | 
146 | prob += opt_prob, "Maximum Score"
147 | 
148 | # Constraints
149 | for i in range(len(no_edge_prob)):
150 | for j in range(len(no_edge_prob)):
151 | if i == j:
152 | continue
153 | # An edge can either be present or not present
154 | prob += all_vars[(i, j, 0)] + all_vars[(i, j, 1)] == 1, "Exist condition" + str(i) + "_" + str(j)
155 | 
156 | # flow less than capacity
157 | prob += all_flow_vars[(i, j)] <= C[(i, j)], "Capacity constraint " + str(i) + " " + str(j)
158 | 
159 | # capacity constraint of source to 1st node
160 | prob += all_flow_vars[(source_id, node_ids_present[-1])] <= C[
161 | (source_id, node_ids_present[-1])], "Capacity constraint source " + str(node_ids_present[-1])
162 | 
163 | # capacity constraint of nodes to sink
164 | for i in range(len(node_ids_present)):
165 | temp = node_ids_present[i]
166 | prob += all_flow_vars[(temp, sink_id)] == C[(temp, sink_id)], "Capacity constraint " + str(temp) + " sink"
167 | 
168 | # node flow conservation constraint
169 | for n in
range(len(no_edge_prob)): 170 | if n == node_ids_present[-1]: 171 | prob += (all_flow_vars[(source_id, n)] + lpSum([all_flow_vars[(i, j)] for (i, j) in arcs if j == n]) == 172 | lpSum([all_flow_vars[(i, j)] for (i, j) in arcs if i == n]) + all_flow_vars[(n, sink_id)]), \ 173 | "Flow Conservation in Node " + str(n) 174 | else: 175 | prob += (lpSum([all_flow_vars[(i, j)] for (i, j) in arcs if j == n]) == 176 | lpSum([all_flow_vars[(i, j)] for (i, j) in arcs if i == n]) + all_flow_vars[(n, sink_id)]), \ 177 | "Flow Conservation in Node " + str(n) 178 | 179 | # Max flow should be equal to number of nodes in graph 180 | # to ensure this make the flow from source exactly equal to capacity 181 | # also ensure that the flow occurs only when the edge exists 182 | prob += all_flow_vars[(source_id, node_ids_present[-1])] == C[(source_id, node_ids_present[-1])] 183 | for i in range(len(no_edge_prob)): 184 | for j in range(len(no_edge_prob)): 185 | if i == j: 186 | continue 187 | prob += len(node_ids_present) * (all_vars[i, j, 1] + all_vars[j, i, 1]) - all_flow_vars[ 188 | (i, j)] >= 0, "Valid flow " + str( 189 | i + 1) + " " + str(j + 1) 190 | 191 | prob.solve() 192 | 193 | edges = [] 194 | edges_dict = {} 195 | for v in prob.variables(): 196 | if v.varValue > 0 and v.name.endswith("1") and v.name.startswith("Edge"): 197 | name = v.name.split("_") 198 | if name[1] == 'source' or name[2] == 'sink' or name[2] == 'source' or name[1] == 'sink': 199 | continue 200 | n_i = int(name[1]) - 1 201 | n_j = int(name[2]) - 1 202 | assert ordered_nodes[n_i] != ordered_nodes[n_j] 203 | assert (ordered_nodes[n_i], ordered_nodes[n_j]) not in edges_dict 204 | edge = "(" + ordered_nodes[n_i] + "; " + edge_label_list[max_edge_index[n_i][n_j]] + "; " + ordered_nodes[n_j] + ")" 205 | edges.append(edge) 206 | edges_dict[(ordered_nodes[n_i], ordered_nodes[n_j])] = True 207 | 208 | print("Max score = ", value(prob.objective)) 209 | 210 | print(edges) 211 | return edges 212 | -------------------------------------------------------------------------------- /structured_model/joint_model.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers import BertPreTrainedModel, RobertaConfig, \ 2 | ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaModel 3 | from pytorch_transformers.modeling_roberta import RobertaClassificationHead 4 | from torch.nn import CrossEntropyLoss 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class NodeClassificationHead(nn.Module): 10 | def __init__(self, config, num_labels_node): 11 | super(NodeClassificationHead, self).__init__() 12 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 13 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 14 | self.out_proj = nn.Linear(config.hidden_size, num_labels_node) 15 | 16 | def forward(self, features, **kwargs): 17 | x = self.dropout(features) 18 | x = self.dense(x) 19 | x = torch.tanh(x) 20 | x = self.dropout(x) 21 | x = self.out_proj(x) 22 | return x 23 | 24 | 25 | class EdgeClassificationHead(nn.Module): 26 | def __init__(self, config): 27 | super(EdgeClassificationHead, self).__init__() 28 | self.dense = nn.Linear(29 * 4 * config.hidden_size, config.hidden_size) 29 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 30 | self.out_proj = nn.Linear(config.hidden_size, 29) 31 | 32 | def forward(self, features, **kwargs): 33 | x = self.dropout(features) 34 | x = self.dense(x) 35 | x = torch.tanh(x) 36 | x = self.dropout(x) 37 | x = self.out_proj(x) 38 | return x 39 | 40 | 41 | class 
RobertaForEX(BertPreTrainedModel): 42 | config_class = RobertaConfig 43 | pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 44 | base_model_prefix = "roberta" 45 | 46 | def __init__(self, config): 47 | super(RobertaForEX, self).__init__(config) 48 | 49 | self.num_labels_node = 3 # 3-way classification for B-N, I-N, O 50 | self.num_labels_edge = 29 # 29-way classification for 28 relations and 1 no edge 51 | self.roberta = RobertaModel(config) 52 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 53 | self.classifier_node = NodeClassificationHead(config, self.num_labels_node) 54 | self.classifier_edge = EdgeClassificationHead(config) 55 | 56 | self.relation_embeddings = torch.load("./data/relations.pt").to("cuda") 57 | 58 | self.apply(self.init_weights) 59 | 60 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, node_start_index=None, node_end_index=None, 61 | node_label=None, 62 | edge_label=None, stance_label=None, position_ids=None, head_mask=None): 63 | outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 64 | attention_mask=attention_mask, head_mask=head_mask) 65 | 66 | loss_fct = CrossEntropyLoss() 67 | sequence_output = outputs[0] 68 | 69 | # Node sequence tagging loss 70 | node_logits = self.classifier_node(self.dropout(sequence_output)) 71 | node_loss = loss_fct(node_logits.view(-1, self.num_labels_node), node_label.view(-1)) 72 | 73 | # Edge embedding computation 74 | max_edges = edge_label.shape[1] 75 | batch_size = node_label.shape[0] 76 | embedding_dim = sequence_output.shape[2] 77 | 78 | batch_edge_embedding = torch.zeros((batch_size, max_edges, self.relation_embeddings.shape[0], 4 * embedding_dim)).to("cuda") 79 | 80 | for batch_index in range(batch_size): 81 | sample_node_embedding = None 82 | count = 0 83 | for (start_index, end_index) in zip(node_start_index[batch_index], node_end_index[batch_index]): 84 | if start_index == 0: 85 | break 86 | else: 87 | node_embedding = torch.mean(sequence_output[batch_index, start_index:(end_index+1), :], 88 | dim=0).unsqueeze(0) 89 | count += 1 90 | if sample_node_embedding is None: 91 | sample_node_embedding = node_embedding 92 | else: 93 | sample_node_embedding = torch.cat((sample_node_embedding, node_embedding), dim=0) 94 | 95 | repeat1 = sample_node_embedding.unsqueeze(0).repeat(len(sample_node_embedding), 1, 1) 96 | repeat2 = sample_node_embedding.unsqueeze(1).repeat(1, len(sample_node_embedding), 1) 97 | sample_edge_embedding = torch.cat((repeat1, repeat2, (repeat1 - repeat2)), dim=2) 98 | 99 | sample_edge_embedding = sample_edge_embedding.view(-1, sample_edge_embedding.shape[-1]) 100 | 101 | relation_embedding = self.relation_embeddings.unsqueeze(0).repeat(sample_edge_embedding.shape[0], 1, 1) 102 | sample_edge_embedding = sample_edge_embedding.unsqueeze(1).repeat(1, relation_embedding.shape[1], 1) 103 | 104 | sample_edge_embedding_with_relation = torch.cat((sample_edge_embedding, relation_embedding), dim=2) 105 | 106 | # Append 0s at the end (these will be ignored for loss) 107 | sample_edge_embedding_with_relation = torch.cat((sample_edge_embedding_with_relation, 108 | torch.zeros( 109 | (max_edges - len(sample_edge_embedding), relation_embedding.shape[1] 110 | , 4 * embedding_dim)).to("cuda")), dim=0) 111 | 112 | batch_edge_embedding[batch_index, :, :, :] = sample_edge_embedding_with_relation 113 | 114 | # Edge loss 115 | edge_logits = self.classifier_edge(batch_edge_embedding.view(batch_size, max_edges, -1)) 116 | edge_loss = 
loss_fct(edge_logits.view(-1, self.num_labels_edge), edge_label.view(-1)) 117 | total_loss = node_loss + edge_loss 118 | 119 | outputs = (node_logits, edge_logits) + outputs[2:] 120 | outputs = (total_loss, node_loss, edge_loss) + outputs 121 | 122 | return outputs # (total_loss), node_loss, edge_loss, node_logits, edge_logits, 123 | # (hidden_states), (attentions) 124 | -------------------------------------------------------------------------------- /structured_model/relation_model.py: -------------------------------------------------------------------------------- 1 | from pytorch_transformers import BertPreTrainedModel, RobertaConfig, \ 2 | ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP, RobertaModel 3 | from pytorch_transformers.modeling_roberta import RobertaClassificationHead 4 | from torch.nn import CrossEntropyLoss 5 | import torch 6 | import torch.nn as nn 7 | 8 | class EdgeClassificationHead(nn.Module): 9 | def __init__(self, config): 10 | super(EdgeClassificationHead, self).__init__() 11 | self.dense = nn.Linear(29 * 4 * config.hidden_size, config.hidden_size) 12 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 13 | self.out_proj = nn.Linear(config.hidden_size, 29) 14 | 15 | def forward(self, features, **kwargs): 16 | x = self.dropout(features) 17 | x = self.dense(x) 18 | x = torch.tanh(x) 19 | x = self.dropout(x) 20 | x = self.out_proj(x) 21 | return x 22 | 23 | 24 | class RobertaForRelationPrediction(BertPreTrainedModel): 25 | config_class = RobertaConfig 26 | pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP 27 | base_model_prefix = "roberta" 28 | 29 | def __init__(self, config): 30 | super(RobertaForRelationPrediction, self).__init__(config) 31 | 32 | self.num_labels = config.num_labels 33 | self.roberta = RobertaModel(config) 34 | self.classifier_edge = EdgeClassificationHead(config) 35 | 36 | self.relation_embeddings = torch.load("./data/relations.pt") 37 | 38 | self.apply(self.init_weights) 39 | 40 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_indices=None, end_indices=None, relation_label=None, 41 | position_ids=None, head_mask=None): 42 | outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids, 43 | attention_mask=attention_mask, head_mask=head_mask) 44 | 45 | loss_fct = CrossEntropyLoss() 46 | sequence_output = outputs[0] 47 | batch_size = relation_label.shape[0] 48 | embedding_dim = sequence_output.shape[2] 49 | self.relation_embeddings = self.relation_embeddings.to(sequence_output) 50 | batch_edge_embedding = torch.zeros((batch_size, self.relation_embeddings.shape[0], 4 * embedding_dim)).to(sequence_output) 51 | 52 | for batch_index in range(batch_size): 53 | concept1_start_index = start_indices[batch_index][0] 54 | concept1_end_index = end_indices[batch_index][0] 55 | concept1_embedding = torch.mean(sequence_output[batch_index, concept1_start_index:(concept1_end_index+1), :] 56 | , dim=0).unsqueeze(0) 57 | 58 | concept2_start_index = start_indices[batch_index][1] 59 | concept2_end_index = end_indices[batch_index][1] 60 | concept2_embedding = torch.mean(sequence_output[batch_index, concept2_start_index:(concept2_end_index+1), :] 61 | , dim=0).unsqueeze(0) 62 | 63 | edge_embedding = torch.cat((concept1_embedding, concept2_embedding, 64 | (concept1_embedding - concept2_embedding)), dim=1) 65 | edge_embedding_with_relation = torch.cat((edge_embedding.repeat(self.relation_embeddings.shape[0], 1), 66 | self.relation_embeddings), dim=1) 67 | batch_edge_embedding[batch_index, :, :] = 
edge_embedding_with_relation.unsqueeze(0) 68 | 69 | logits = self.classifier_edge(batch_edge_embedding.view(batch_size, -1)) 70 | loss = loss_fct(logits.view(-1, self.num_labels), relation_label.view(-1)) 71 | 72 | outputs = (loss, logits) + outputs 73 | 74 | return outputs 75 | -------------------------------------------------------------------------------- /structured_model/run_commonsense_finetuning.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import argparse 4 | import glob 5 | import json 6 | import logging 7 | import os 8 | import random 9 | import math 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, 14 | TensorDataset) 15 | from torch.utils.data.distributed import DistributedSampler 16 | from tensorboardX import SummaryWriter 17 | from tqdm import tqdm, trange 18 | from scipy.special import softmax 19 | import pathlib 20 | 21 | from pytorch_transformers import (WEIGHTS_NAME, RobertaConfig, RobertaTokenizer) 22 | 23 | from pytorch_transformers import AdamW, WarmupLinearSchedule 24 | 25 | from relation_model import RobertaForRelationPrediction 26 | from utils_relation import (compute_metrics, output_modes, processors, convert_examples_to_features) 27 | 28 | logger = logging.getLogger(__name__) 29 | 30 | MODEL_CLASSES = { 31 | 'roberta_relation': (RobertaConfig, RobertaForRelationPrediction, RobertaTokenizer) 32 | } 33 | 34 | 35 | def set_seed(args): 36 | random.seed(args.seed) 37 | np.random.seed(args.seed) 38 | torch.manual_seed(args.seed) 39 | if args.n_gpu > 0: 40 | torch.cuda.manual_seed_all(args.seed) 41 | 42 | 43 | def train(args, train_dataset, model, tokenizer): 44 | """ Train the model """ 45 | set_seed(args) # Added here for reproductibility (even between python 2 and 3) 46 | if args.local_rank in [-1, 0]: 47 | tb_writer = SummaryWriter() 48 | 49 | processor = processors[args.task_name]() 50 | 51 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 52 | train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) 53 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size) 54 | 55 | if args.max_steps > 0: 56 | t_total = args.max_steps 57 | args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 58 | else: 59 | t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs 60 | 61 | # Prepare optimizer and schedule (linear warmup and decay) 62 | no_decay = ['bias', 'LayerNorm.weight'] 63 | optimizer_grouped_parameters = [ 64 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 65 | 'weight_decay': args.weight_decay}, 66 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 67 | ] 68 | 69 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 70 | if args.warmup_pct is None: 71 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total) 72 | else: 73 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=math.floor(args.warmup_pct * t_total), t_total=t_total) 74 | 75 | if args.fp16: 76 | try: 77 | from apex import amp 78 | except ImportError: 79 | raise ImportError("Please install apex from 
https://www.github.com/nvidia/apex to use fp16 training.") 80 | model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) 81 | 82 | # multi-gpu training (should be after apex fp16 initialization) 83 | if args.n_gpu > 1: 84 | model = torch.nn.DataParallel(model) 85 | 86 | # Distributed training (should be after apex fp16 initialization) 87 | if args.local_rank != -1: 88 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 89 | output_device=args.local_rank, 90 | find_unused_parameters=True) 91 | 92 | # Train! 93 | logger.info("***** Running training *****") 94 | logger.info(" Num examples = %d", len(train_dataset)) 95 | logger.info(" Num Epochs = %d", args.num_train_epochs) 96 | logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) 97 | logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", 98 | args.train_batch_size * args.gradient_accumulation_steps * ( 99 | torch.distributed.get_world_size() if args.local_rank != -1 else 1)) 100 | logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) 101 | logger.info(" Total optimization steps = %d", t_total) 102 | 103 | global_step = 0 104 | tr_loss, logging_loss = 0.0, 0.0 105 | model.zero_grad() 106 | train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) 107 | # set_seed(args) # Added here for reproductibility (even between python 2 and 3) 108 | for _ in train_iterator: 109 | epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0], 110 | mininterval=10, ncols=100) 111 | for step, batch in enumerate(epoch_iterator): 112 | model.train() 113 | batch = tuple(t.to(args.device) for t in batch) 114 | inputs = {'input_ids': batch[0], 115 | 'attention_mask': batch[1], 116 | 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet', 'bert_mc'] else None, 117 | # XLM don't use segment_ids 118 | 'start_indices': batch[3], 119 | 'end_indices': batch[4], 120 | 'relation_label': batch[5]} 121 | outputs = model(**inputs) 122 | loss, logits = outputs[:2] # model outputs are always tuple in pytorch-transformers (see doc) 123 | 124 | if args.n_gpu > 1: 125 | loss = loss.mean() # mean() to average on multi-gpu parallel training 126 | logger.info("Loss = %f", loss) 127 | if args.gradient_accumulation_steps > 1: 128 | loss = loss / args.gradient_accumulation_steps 129 | 130 | if args.fp16: 131 | with amp.scale_loss(loss, optimizer) as scaled_loss: 132 | scaled_loss.backward() 133 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 134 | else: 135 | loss.backward() 136 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) 137 | 138 | tr_loss += loss.item() 139 | if (step + 1) % args.gradient_accumulation_steps == 0: 140 | optimizer.step() 141 | scheduler.step() # Update learning rate schedule 142 | model.zero_grad() 143 | global_step += 1 144 | 145 | ''' 146 | if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: 147 | # Log metrics 148 | if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well 149 | results = evaluate(args, model, tokenizer, processor, eval_split="dev") 150 | for key, value in results.items(): 151 | tb_writer.add_scalar('eval_{}'.format(key), value, global_step) 152 | tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step) 153 | 
tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step) 154 | logging_loss = tr_loss 155 | 156 | if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: 157 | # Save model checkpoint 158 | output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) 159 | if not os.path.exists(output_dir): 160 | os.makedirs(output_dir) 161 | model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training 162 | model_to_save.save_pretrained(output_dir) 163 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 164 | logger.info("Saving model checkpoint to %s", output_dir) 165 | ''' 166 | 167 | if args.max_steps > 0 and global_step > args.max_steps: 168 | epoch_iterator.close() 169 | break 170 | if args.max_steps > 0 and global_step > args.max_steps: 171 | train_iterator.close() 172 | break 173 | 174 | # evaluate(args, model, tokenizer, processor, prefix=global_step, eval_split="dev") 175 | 176 | if args.local_rank in [-1, 0]: 177 | tb_writer.close() 178 | 179 | return global_step, tr_loss / global_step 180 | 181 | def evaluate(args, model, tokenizer, processor, prefix="", eval_split=None): 182 | eval_task_names = (args.task_name,) 183 | eval_outputs_dirs = (args.output_dir,) 184 | 185 | assert eval_split is not None 186 | 187 | results = {} 188 | if os.path.exists("/output/metrics.json"): 189 | with open("/output/metrics.json", "r") as f: 190 | existing_results = json.loads(f.read()) 191 | f.close() 192 | results.update(existing_results) 193 | 194 | for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs): 195 | eval_dataset, examples = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True, 196 | eval_split=eval_split) 197 | 198 | if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: 199 | os.makedirs(eval_output_dir) 200 | 201 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 202 | # Note that DistributedSampler samples randomly 203 | eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset) 204 | eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size) 205 | 206 | # Eval! 
207 | logger.info("***** Running evaluation {} on {} *****".format(prefix, eval_split)) 208 | logger.info(" Num examples = %d", len(eval_dataset)) 209 | logger.info(" Batch size = %d", args.eval_batch_size) 210 | eval_loss = 0.0 211 | nb_eval_steps = 0 212 | relation_preds = None 213 | out_relation_label_ids = None 214 | for batch in tqdm(eval_dataloader, desc="Evaluating", mininterval=10, ncols=100): 215 | model.eval() 216 | batch = tuple(t.to(args.device) for t in batch) 217 | 218 | with torch.no_grad(): 219 | inputs = {'input_ids': batch[0], 220 | 'attention_mask': batch[1], 221 | 'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet', 'bert_mc'] else None, 222 | # XLM don't use segment_ids 223 | 'start_indices': batch[3], 224 | 'end_indices': batch[4], 225 | 'relation_label': batch[5]} 226 | outputs = model(**inputs) 227 | tmp_eval_loss, relation_logits = outputs[:2] 228 | 229 | eval_loss += tmp_eval_loss.mean().item() 230 | nb_eval_steps += 1 231 | if relation_preds is None: 232 | relation_preds = relation_logits.detach().cpu().numpy() 233 | if not eval_split == "test": 234 | out_relation_label_ids = inputs['relation_label'].detach().cpu().numpy() 235 | else: 236 | relation_preds = np.append(relation_preds, relation_logits.detach().cpu().numpy(), axis=0) 237 | if not eval_split == "test": 238 | out_relation_label_ids = np.append(out_relation_label_ids, 239 | inputs['relation_label'].detach().cpu().numpy(), axis=0) 240 | 241 | eval_loss = eval_loss / nb_eval_steps 242 | relation_preds = np.argmax(relation_preds, axis=1) 243 | 244 | if not eval_split == "test": 245 | result = compute_metrics(eval_task, relation_preds, out_relation_label_ids) 246 | result_split = {} 247 | for k, v in result.items(): 248 | result_split[k + "_{}".format(eval_split)] = v 249 | results.update(result_split) 250 | 251 | output_eval_file = os.path.join(eval_output_dir, "eval_results_{}.txt".format(eval_split)) 252 | with open(output_eval_file, "w") as writer: 253 | logger.info("***** Eval results {} on {} *****".format(prefix, eval_split)) 254 | for key in sorted(result_split.keys()): 255 | logger.info(" %s = %s", key, str(result_split[key])) 256 | writer.write("%s = %s\n" % (key, str(result_split[key]))) 257 | 258 | # Relation Predictions 259 | output_pred_file = os.path.join(eval_output_dir, "predictions_{}.lst".format(eval_split)) 260 | with open(output_pred_file, "w") as writer: 261 | logger.info("***** Write predictions {} on {} *****".format(prefix, eval_split)) 262 | for relation_pred in relation_preds: 263 | writer.write("{}\n".format(processor.get_relation_labels()[relation_pred])) 264 | 265 | return results 266 | 267 | 268 | def load_and_cache_examples(args, task, tokenizer, evaluate=False, eval_split="train"): 269 | processor = processors[task]() 270 | # Load data features from cache or dataset file 271 | if args.data_cache_dir is None: 272 | data_cache_dir = args.data_dir 273 | else: 274 | data_cache_dir = args.data_cache_dir 275 | 276 | cached_features_file = os.path.join(data_cache_dir, 'cached_{}_{}_{}_{}'.format( 277 | eval_split, 278 | list(filter(None, args.model_name_or_path.split('/'))).pop(), 279 | str(args.max_seq_length), 280 | str(task))) 281 | 282 | if os.path.exists(cached_features_file): 283 | logger.info("Loading features from cached file %s", cached_features_file) 284 | features = torch.load(cached_features_file) 285 | if eval_split == "dev": 286 | examples = processor.get_dev_examples(args.data_dir, args.do_eval_edge) 287 | else: 288 | examples = None 289 | else: 290 | 
logger.info("Creating features from dataset file at %s", args.data_dir) 291 | 292 | if eval_split == "train": 293 | examples = processor.get_train_examples(args.data_dir) 294 | elif eval_split == "dev": 295 | examples = processor.get_dev_examples(args.data_dir) 296 | elif eval_split == "test": 297 | examples = processor.get_test_examples(args.data_dir) 298 | else: 299 | raise Exception("eval_split should be among train / dev / test") 300 | 301 | features = convert_examples_to_features(examples, processor.get_relation_labels(), args.max_seq_length, 302 | tokenizer, cls_token=tokenizer.cls_token, sep_token=tokenizer.sep_token) 303 | 304 | all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) 305 | all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) 306 | all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) 307 | all_start_indices = torch.tensor([f.start_indices for f in features], dtype=torch.long) 308 | all_end_indices = torch.tensor([f.end_indices for f in features], dtype=torch.long) 309 | all_relation_label = torch.tensor([f.relation_label for f in features], dtype=torch.long) 310 | 311 | dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_indices, all_end_indices, all_relation_label) 312 | return dataset, examples 313 | 314 | 315 | def main(): 316 | parser = argparse.ArgumentParser() 317 | 318 | ## Required parameters 319 | parser.add_argument("--data_dir", default=None, type=str, required=True, 320 | help="The input data dir. Should contain the .tsv files (or other data files) for the task.") 321 | parser.add_argument("--model_type", default=None, type=str, required=True, 322 | help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) 323 | parser.add_argument("--model_name_or_path", default=None, type=str, required=True, 324 | help="Path to pre-trained model or shortcut name selected in the list: RoBERTaConfig") 325 | parser.add_argument("--task_name", default=None, type=str, required=True, 326 | help="The name of the task to train selected in the list: " + ", ".join(processors.keys())) 327 | parser.add_argument("--output_dir", default=None, type=str, required=True, 328 | help="The output directory where the model predictions and checkpoints will be written.") 329 | 330 | parser.add_argument("--data_cache_dir", default=None, type=str, 331 | help="Cache dir if it needs to be diff from data_dir") 332 | 333 | ## Other parameters 334 | parser.add_argument("--config_name", default="", type=str, 335 | help="Pretrained config name or path if not the same as model_name") 336 | parser.add_argument("--tokenizer_name", default="", type=str, 337 | help="Pretrained tokenizer name or path if not the same as model_name") 338 | parser.add_argument("--cache_dir", default="", type=str, 339 | help="Where do you want to store the pre-trained models downloaded from s3") 340 | parser.add_argument("--max_seq_length", default=300, type=int, 341 | help="The maximum total input sequence length after tokenization. 
Sequences longer "
342 | "than this will be truncated, sequences shorter will be padded.")
343 | parser.add_argument("--max_nodes", default=11, type=int,
344 | help="Maximum number of nodes")
345 | parser.add_argument("--do_train", action='store_true',
346 | help="Whether to run training.")
347 | parser.add_argument("--do_eval", action='store_true',
348 | help="Whether to run eval on the dev set.")
349 | parser.add_argument("--do_eval_edge", action='store_true',
350 | help="Whether to run eval for edges.")
351 | parser.add_argument("--do_prediction", action='store_true',
352 | help="Whether to run prediction on the test set. (Training will not be executed.)")
353 | parser.add_argument("--evaluate_during_training", action='store_true',
354 | help="Run evaluation during training at each logging step.")
355 | parser.add_argument("--do_lower_case", action='store_true',
356 | help="Set this flag if you are using an uncased model.")
357 | parser.add_argument('--run_on_test', action='store_true')
358 | 
359 | parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
360 | help="Batch size per GPU/CPU for training.")
361 | parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
362 | help="Batch size per GPU/CPU for evaluation.")
363 | parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
364 | help="Number of update steps to accumulate before performing a backward/update pass.")
365 | parser.add_argument("--learning_rate", default=1e-5, type=float,
366 | help="The initial learning rate for Adam.")
367 | parser.add_argument("--weight_decay", default=0.1, type=float,
368 | help="Weight decay if we apply some.")
369 | parser.add_argument("--adam_epsilon", default=1e-6, type=float,
370 | help="Epsilon for Adam optimizer.")
371 | parser.add_argument("--max_grad_norm", default=1.0, type=float,
372 | help="Max gradient norm.")
373 | parser.add_argument("--num_train_epochs", default=3.0, type=float,
374 | help="Total number of training epochs to perform.")
375 | parser.add_argument("--max_steps", default=-1, type=int,
376 | help="If > 0: set total number of training steps to perform. 
Overrides num_train_epochs.")
377 | parser.add_argument("--warmup_steps", default=0, type=int,
378 | help="Linear warmup over warmup_steps.")
379 | parser.add_argument("--warmup_pct", default=None, type=float,
380 | help="Linear warmup over warmup_pct*total_steps.")
381 | 
382 | parser.add_argument('--logging_steps', type=int, default=50,
383 | help="Log every X update steps.")
384 | parser.add_argument('--save_steps', type=int, default=50,
385 | help="Save a checkpoint every X update steps.")
386 | parser.add_argument("--eval_all_checkpoints", action='store_true',
387 | help="Evaluate all checkpoints starting with the same prefix as model_name and ending with a step number")
388 | parser.add_argument("--no_cuda", action='store_true',
389 | help="Avoid using CUDA when available")
390 | parser.add_argument('--overwrite_output_dir', action='store_true',
391 | help="Overwrite the content of the output directory")
392 | parser.add_argument('--overwrite_cache', action='store_true',
393 | help="Overwrite the cached training and evaluation sets")
394 | parser.add_argument('--seed', type=int, default=42,
395 | help="random seed for initialization")
396 | 
397 | parser.add_argument('--fp16', action='store_true',
398 | help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
399 | parser.add_argument('--fp16_opt_level', type=str, default='O1',
400 | help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
401 | "See details at https://nvidia.github.io/apex/amp.html")
402 | parser.add_argument("--local_rank", type=int, default=-1,
403 | help="For distributed training: local_rank")
404 | parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
405 | parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
406 | args = parser.parse_args()
407 | 
408 | if os.path.exists(args.output_dir) and os.listdir(
409 | args.output_dir) and args.do_train and not args.overwrite_output_dir:
410 | raise ValueError(
411 | "Output directory ({}) already exists and is not empty. 
Use --overwrite_output_dir to overcome.".format( 412 | args.output_dir)) 413 | 414 | # Setup distant debugging if needed 415 | if args.server_ip and args.server_port: 416 | # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script 417 | import ptvsd 418 | print("Waiting for debugger attach") 419 | ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) 420 | ptvsd.wait_for_attach() 421 | 422 | # Setup CUDA, GPU & distributed training 423 | if args.local_rank == -1 or args.no_cuda: 424 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 425 | args.n_gpu = torch.cuda.device_count() 426 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 427 | torch.cuda.set_device(args.local_rank) 428 | device = torch.device("cuda", args.local_rank) 429 | torch.distributed.init_process_group(backend='nccl') 430 | args.n_gpu = 1 431 | args.device = device 432 | 433 | # Setup logging 434 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 435 | datefmt='%m/%d/%Y %H:%M:%S', 436 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) 437 | logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", 438 | args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) 439 | 440 | # Set seed 441 | set_seed(args) 442 | 443 | # Prepare GLUE task 444 | args.task_name = args.task_name.lower() 445 | if args.task_name not in processors: 446 | raise ValueError("Task not found: %s" % (args.task_name)) 447 | processor = processors[args.task_name]() 448 | args.output_mode = output_modes[args.task_name] 449 | label_list = processor.get_relation_labels() 450 | num_labels_relation = len(label_list) 451 | 452 | # Load pretrained model and tokenizer 453 | if args.local_rank not in [-1, 0]: 454 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 455 | 456 | args.model_type = args.model_type.lower() 457 | config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] 458 | config = config_class.from_pretrained( 459 | args.config_name if args.config_name else args.model_name_or_path, 460 | num_labels=num_labels_relation, 461 | finetuning_task=args.task_name 462 | ) 463 | tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 464 | do_lower_case=args.do_lower_case) 465 | model = model_class.from_pretrained(args.model_name_or_path, from_tf=bool('.ckpt' in args.model_name_or_path), 466 | config=config) 467 | 468 | if args.local_rank == 0: 469 | torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab 470 | 471 | model.to(args.device) 472 | 473 | logger.info("Training/evaluation parameters %s", args) 474 | 475 | # Prediction (on test set) 476 | if args.do_prediction: 477 | results = {} 478 | logger.info("Prediction on the test set (note: Training will not be executed.) 
") 479 | result = evaluate(args, model, tokenizer, processor, prefix="", eval_split="test") 480 | result = dict((k, v) for k, v in result.items()) 481 | results.update(result) 482 | logger.info("***** Experiment finished *****") 483 | return results 484 | 485 | # Training 486 | if args.do_train: 487 | train_dataset, _ = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False) 488 | global_step, tr_loss = train(args, train_dataset, model, tokenizer) 489 | logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) 490 | 491 | # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() 492 | if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): 493 | # Create output directory if needed 494 | if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: 495 | os.makedirs(args.output_dir) 496 | 497 | logger.info("Saving model checkpoint to %s", args.output_dir) 498 | # Save a trained model, configuration and tokenizer using `save_pretrained()`. 499 | # They can then be reloaded using `from_pretrained()` 500 | model_to_save = model.module if hasattr(model, 501 | 'module') else model # Take care of distributed/parallel training 502 | model_to_save.save_pretrained(args.output_dir) 503 | tokenizer.save_pretrained(args.output_dir) 504 | 505 | # Good practice: save your training arguments together with the trained model 506 | torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) 507 | 508 | # Load a trained model and vocabulary that you have fine-tuned 509 | model = model_class.from_pretrained(args.output_dir) 510 | tokenizer = tokenizer_class.from_pretrained(args.output_dir) 511 | model.to(args.device) 512 | 513 | # Evaluation 514 | results = {} 515 | checkpoints = [args.output_dir] 516 | if args.do_eval and args.local_rank in [-1, 0]: 517 | if args.eval_all_checkpoints: 518 | checkpoints = list( 519 | os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) 520 | logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging 521 | logger.info("Evaluate the following checkpoints: %s", checkpoints) 522 | for checkpoint in checkpoints: 523 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 524 | model = model_class.from_pretrained(checkpoint) 525 | model.to(args.device) 526 | result = evaluate(args, model, tokenizer, processor, prefix=global_step, eval_split="dev") 527 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) 528 | results.update(result) 529 | 530 | # Run on test 531 | if args.run_on_test and args.local_rank in [-1, 0]: 532 | checkpoint = checkpoints[0] 533 | global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" 534 | model = model_class.from_pretrained(checkpoint) 535 | model.to(args.device) 536 | result = evaluate(args, model, tokenizer, processor, prefix=global_step, eval_split="test") 537 | result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) 538 | results.update(result) 539 | 540 | logger.info("***** Experiment finished *****") 541 | return results 542 | 543 | 544 | if __name__ == "__main__": 545 | main() 546 | -------------------------------------------------------------------------------- /structured_model/save_relation_embeddings.py: -------------------------------------------------------------------------------- 1 | from transformers import RobertaTokenizer, RobertaModel 2 | import 
torch 3 | 4 | if __name__ == '__main__': 5 | tokenizer = RobertaTokenizer.from_pretrained('roberta-large') 6 | model = RobertaModel.from_pretrained('roberta-large') 7 | relations = open("./data/relations.txt", "r").read().splitlines() 8 | relations.append("no relation") 9 | embeddings = None 10 | for relation in relations: 11 | inputs = tokenizer(relation, return_tensors="pt") 12 | embedding = model(**inputs)[1] 13 | if embeddings is None: 14 | embeddings = embedding 15 | else: 16 | embeddings = torch.cat((embeddings, embedding), dim=0) 17 | 18 | torch.save(embeddings, "./data/relations.pt") 19 | -------------------------------------------------------------------------------- /structured_model/utils_joint_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import csv 4 | import logging 5 | import os 6 | import sys 7 | from collections import OrderedDict 8 | from io import open 9 | 10 | import numpy as np 11 | from nltk import word_tokenize 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class ExplaGraphInputExample(object): 17 | def __init__(self, id, belief, argument, external, node_label_internal_belief, node_label_internal_argument, 18 | node_label_external, edge_label, stance_label): 19 | self.id = id 20 | self.belief = belief 21 | self.argument = argument 22 | self.external = external 23 | self.node_label_internal_belief = node_label_internal_belief 24 | self.node_label_internal_argument = node_label_internal_argument 25 | self.node_label_external = node_label_external 26 | self.edge_label = edge_label 27 | self.stance_label = stance_label 28 | 29 | 30 | class ExplaGraphFeatures(object): 31 | def __init__(self, id, input_ids, input_mask, segment_ids, node_start_index, node_end_index, node_label, 32 | edge_label, stance_label): 33 | self.id = id 34 | self.input_ids = input_ids 35 | self.input_mask = input_mask 36 | self.segment_ids = segment_ids 37 | self.node_start_index = node_start_index 38 | self.node_end_index = node_end_index 39 | self.node_label = node_label 40 | self.edge_label = edge_label 41 | self.stance_label = stance_label 42 | 43 | 44 | class ExplaGraphProcessor(object): 45 | def _read_tsv(self, input_file, quotechar=None): 46 | with open(input_file, "r", encoding="utf-8-sig") as f: 47 | reader = csv.reader(f, delimiter="\t", quotechar=quotechar) 48 | lines = [] 49 | for line in reader: 50 | if sys.version_info[0] == 2: 51 | line = list(unicode(cell, 'utf-8') for cell in line) 52 | lines.append(line) 53 | return lines 54 | 55 | def get_train_examples(self, data_dir): 56 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "train.tsv"))) 57 | 58 | def get_dev_examples(self, data_dir, is_edge_pred=True): 59 | # If predicting nodes, then create labels using gold nodes, because don't care 60 | # But if predicting edges, create node labels using the predicting nodes 61 | if not is_edge_pred: 62 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "dev.tsv")), is_eval=True) 63 | else: 64 | return self._create_examples_with_predicted_nodes(self._read_tsv(os.path.join(data_dir, "dev.tsv")), 65 | open(os.path.join(data_dir, "internal_nodes_dev.txt"), 66 | "r").read().splitlines(), 67 | open(os.path.join(data_dir, "external_nodes_dev.txt"), 68 | "r").read().splitlines()) 69 | 70 | def get_test_examples(self, data_dir): 71 | return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv"))) 72 | 73 | def 
get_stance_labels(self): 74 | return ["support", "counter"] 75 | 76 | def get_node_labels(self): 77 | return ["B-N", "I-N", "O"] 78 | 79 | def get_edge_labels(self): 80 | return ["antonym of", "synonym of", "at location", "not at location", "capable of", "not capable of", "causes", 81 | "not causes", "created by", "not created by", "is a", "is not a", "desires", "not desires", 82 | "has subevent", "not has subevent", "part of", "not part of", "has context", "not has context", 83 | "has property", "not has property", "made of", "not made of", "receives action", "not receives action", 84 | "used for", "not used for", "no relation"] 85 | 86 | def _get_external_nodes_eval(self, belief, argument, external_nodes, internal_nodes_count): 87 | filtered_external_nodes = [] 88 | for external_node in list(set(external_nodes.split(", "))): 89 | # We'll consider a maximum of 11 nodes (9+2 shared between belief and argument) 90 | if internal_nodes_count + len(filtered_external_nodes) == 11: 91 | break 92 | if external_node in belief or external_node in argument: 93 | continue 94 | words = word_tokenize(external_node) 95 | if len(words) > 3: 96 | continue 97 | 98 | filtered_external_nodes.append(external_node) 99 | 100 | return filtered_external_nodes 101 | 102 | def _get_external_nodes_train(self, belief, argument, graph): 103 | external_nodes = [] 104 | for edge in graph[1:-1].split(")("): 105 | edge_parts = edge.split("; ") 106 | if edge_parts[0] not in belief and edge_parts[0] not in argument and edge_parts[0] not in external_nodes: 107 | external_nodes.append(edge_parts[0]) 108 | if edge_parts[2] not in belief and edge_parts[2] not in argument and edge_parts[2] not in external_nodes: 109 | external_nodes.append(edge_parts[2]) 110 | 111 | return external_nodes 112 | 113 | def _get_internal_nodes(self, belief, argument, graph): 114 | internal_nodes = {} 115 | for edge in graph[1:-1].split(")("): 116 | edge_parts = edge.split("; ") 117 | if edge_parts[0] in belief or edge_parts[0] in argument: 118 | length = len(edge_parts[0].split(" ")) 119 | if length not in internal_nodes: 120 | internal_nodes[length] = [edge_parts[0]] 121 | elif edge_parts[0] not in internal_nodes[length]: 122 | internal_nodes[length].append(edge_parts[0]) 123 | if edge_parts[2] in belief or edge_parts[2] in argument: 124 | length = len(edge_parts[2].split(" ")) 125 | if length not in internal_nodes: 126 | internal_nodes[length] = [edge_parts[2]] 127 | elif edge_parts[2] not in internal_nodes[length]: 128 | internal_nodes[length].append(edge_parts[2]) 129 | 130 | return internal_nodes 131 | 132 | def _get_edge_label(self, node_label_internal_belief, belief, node_label_internal_argument, argument, 133 | external_nodes, graph): 134 | 135 | edge_label_map = {label: i for i, label in enumerate(self.get_edge_labels())} 136 | 137 | gold_edges = {} 138 | for edge in graph[1:-1].split(")("): 139 | parts = edge.split("; ") 140 | gold_edges[parts[0], parts[2]] = parts[1] 141 | 142 | ordered_nodes = [] 143 | for i, (word, node_label) in enumerate(zip(belief, node_label_internal_belief)): 144 | if node_label == "B-N": 145 | node = word 146 | if i + 1 < len(belief) and node_label_internal_belief[i + 1] == "I-N": 147 | node += " " + belief[i + 1] 148 | if i + 2 < len(belief) and node_label_internal_belief[i + 2] == "I-N": 149 | node += " " + belief[i + 2] 150 | 151 | ordered_nodes.append(node) 152 | 153 | for i, (word, node_label) in enumerate(zip(argument, node_label_internal_argument)): 154 | if node_label == "B-N": 155 | node = word 156 | if i 
+ 1 < len(argument) and node_label_internal_argument[i + 1] == "I-N": 157 | node += " " + argument[i + 1] 158 | if i + 2 < len(argument) and node_label_internal_argument[i + 2] == "I-N": 159 | node += " " + argument[i + 2] 160 | 161 | ordered_nodes.append(node) 162 | 163 | ordered_nodes.extend(external_nodes) 164 | 165 | edge_label = np.zeros((len(ordered_nodes), len(ordered_nodes)), dtype=int) 166 | 167 | for i in range(len(edge_label)): 168 | for j in range(len(edge_label)): 169 | if i == j: 170 | edge_label[i][j] = -100 171 | elif (ordered_nodes[i], ordered_nodes[j]) in gold_edges: 172 | edge_label[i][j] = edge_label_map[gold_edges[(ordered_nodes[i], ordered_nodes[j])]] 173 | else: 174 | edge_label[i][j] = edge_label_map["no relation"] 175 | 176 | return list(edge_label.flatten()) 177 | 178 | def _get_node_label_internal(self, internal_nodes, words): 179 | labels = ["O"] * len(words) 180 | 181 | for length in range(3, 0, -1): 182 | if length not in internal_nodes: 183 | continue 184 | nodes = internal_nodes[length] 185 | for node in nodes: 186 | node_words = node.split(" ") 187 | for (i, word) in enumerate(words): 188 | if length == 3 and i < len(words) - 2 and words[i] == node_words[0] and words[i + 1] == node_words[ 189 | 1] and words[i + 2] == node_words[2]: 190 | if labels[i] == "O" and labels[i + 1] == "O" and labels[i + 2] == "O": 191 | labels[i] = "B-N" 192 | labels[i + 1] = "I-N" 193 | labels[i + 2] = "I-N" 194 | if length == 2 and i < len(words) - 1 and words[i] == node_words[0] and words[i + 1] == node_words[ 195 | 1]: 196 | if labels[i] == "O" and labels[i + 1] == "O": 197 | labels[i] = "B-N" 198 | labels[i + 1] = "I-N" 199 | if length == 1 and words[i] == node_words[0]: 200 | if labels[i] == "O": 201 | labels[i] = "B-N" 202 | 203 | return labels 204 | 205 | def _get_node_label_external(self, external_nodes): 206 | labels = [] 207 | for external_node in external_nodes: 208 | length = len(word_tokenize(external_node)) 209 | labels.extend(["B-N"] + ["I-N"] * (length - 1)) 210 | 211 | return labels 212 | 213 | def _create_examples(self, records, is_eval=False): 214 | examples = [] 215 | 216 | max_edge_length = 0 217 | for (i, record) in enumerate(records): 218 | belief = record[0].lower() 219 | argument = record[1].lower() 220 | stance_label = record[2] 221 | graph = record[3].lower() 222 | 223 | belief_words = word_tokenize(belief) 224 | argument_words = word_tokenize(argument) 225 | 226 | internal_nodes = self._get_internal_nodes(belief, argument, graph) 227 | node_label_internal_belief = self._get_node_label_internal(internal_nodes, belief_words) 228 | node_label_internal_argument = self._get_node_label_internal(internal_nodes, argument_words) 229 | 230 | # If evaluating, external nodes are not required for tagging because they will come from generation model 231 | external_nodes = self._get_external_nodes_train(belief, argument, graph) if not is_eval else [] 232 | 233 | node_label_external = self._get_node_label_external(external_nodes) 234 | 235 | edge_label = self._get_edge_label(node_label_internal_belief, belief_words, node_label_internal_argument, 236 | argument_words, 237 | external_nodes, graph) 238 | 239 | max_edge_length = max(max_edge_length, len(edge_label)) 240 | 241 | external = [] 242 | for external_node in external_nodes: 243 | external.extend(word_tokenize(external_node)) 244 | 245 | examples.append( 246 | ExplaGraphInputExample(id=i, belief=belief_words, argument=argument_words, external=external, 247 | node_label_internal_belief=node_label_internal_belief, 
248 | node_label_internal_argument=node_label_internal_argument, 249 | node_label_external=node_label_external, edge_label=edge_label, 250 | stance_label=stance_label)) 251 | 252 | return examples 253 | 254 | def _get_unique_node_count(self, belief, argument, node_label_internal_belief, node_label_internal_argument): 255 | nodes = [] 256 | for i, (word, node_label) in enumerate(zip(belief, node_label_internal_belief)): 257 | if node_label == "B-N": 258 | node = word 259 | if i + 1 < len(belief) and node_label_internal_belief[i + 1] == "I-N": 260 | node += " " + belief[i + 1] 261 | if i + 2 < len(belief) and node_label_internal_belief[i + 2] == "I-N": 262 | node += " " + belief[i + 2] 263 | 264 | nodes.append(node) 265 | 266 | for i, (word, node_label) in enumerate(zip(argument, node_label_internal_argument)): 267 | if node_label == "B-N": 268 | node = word 269 | if i + 1 < len(argument) and node_label_internal_argument[i + 1] == "I-N": 270 | node += " " + argument[i + 1] 271 | if i + 2 < len(argument) and node_label_internal_argument[i + 2] == "I-N": 272 | node += " " + argument[i + 2] 273 | 274 | if node not in nodes: 275 | nodes.append(node) 276 | 277 | return len(nodes) 278 | 279 | def _create_examples_with_predicted_nodes(self, records, internal_nodes, external_nodes): 280 | assert len(records) == len(external_nodes) 281 | examples = [] 282 | 283 | sample_breaks = [i for i, x in enumerate(internal_nodes) if x == ""] 284 | 285 | max_node_count = 0 286 | for (i, record) in enumerate(records): 287 | belief = record[0].lower() 288 | argument = record[1].lower() 289 | stance_label = record[2] 290 | 291 | belief_words = word_tokenize(belief) 292 | argument_words = word_tokenize(argument) 293 | 294 | start = 0 if i == 0 else sample_breaks[i - 1] + 1 295 | end = sample_breaks[i] 296 | belief_lines = internal_nodes[start:(start + len(belief_words))] 297 | argument_lines = internal_nodes[(start + len(belief_words)):end] 298 | 299 | node_label_internal_belief = [belief_line.split("\t")[1] for belief_line in belief_lines] 300 | node_label_internal_argument = [argument_line.split("\t")[1] for argument_line in argument_lines] 301 | node_count = self._get_unique_node_count(belief_words, argument_words, node_label_internal_belief, 302 | node_label_internal_argument) 303 | 304 | external = [] 305 | node_label_external = [] 306 | for external_node in list(OrderedDict.fromkeys(external_nodes[i].split(", "))): 307 | # Allowing a maximum of 9 unique nodes, as per the task 308 | if node_count >= 8: 309 | break 310 | if external_node in belief or external_node in argument: 311 | continue 312 | words = word_tokenize(external_node) 313 | if len(words) > 3: 314 | continue 315 | node_count += 1 316 | external.extend(words) 317 | node_label_external.extend(["B-N"] + ["I-N"] * (len(words) - 1)) 318 | 319 | max_node_count = max(max_node_count, node_count) 320 | 321 | edge_label = np.zeros((node_count, node_count), dtype=int) 322 | 323 | for a in range(len(edge_label)): 324 | for b in range(len(edge_label)): 325 | if a == b: 326 | edge_label[a][b] = -100 327 | else: 328 | edge_label[a][b] = 0 # Don't care, some placeholder value 329 | 330 | edge_label = list(edge_label.flatten()) 331 | 332 | examples.append( 333 | ExplaGraphInputExample(id=i, belief=belief_words, argument=argument_words, external=external, 334 | node_label_internal_belief=node_label_internal_belief, 335 | node_label_internal_argument=node_label_internal_argument, 336 | node_label_external=node_label_external, edge_label=edge_label, 337 | 
337 |                                        stance_label=stance_label))
338 | 
339 |         logger.info("Max node count: %d", max_node_count)
340 |         return examples
341 | 
342 | def get_word_start_indices(examples, tokenizer, cls_token, sep_token):
343 |     all_word_start_indices = []
344 |     for (ex_index, example) in enumerate(examples):
345 |         word_start_indices = []
346 |         # Record the token position at which each word of the belief and the argument starts
347 | 
348 |         tokens = [cls_token]
349 | 
350 |         for word in example.belief:
351 |             word_tokens = tokenizer.tokenize(word)
352 |             if len(word_tokens) > 0:
353 |                 word_start_indices.append(len(tokens))
354 |                 tokens.extend(word_tokens)
355 | 
356 |         tokens = tokens + [sep_token] + [sep_token]
357 | 
358 |         for word in example.argument:
359 |             word_tokens = tokenizer.tokenize(word)
360 |             if len(word_tokens) > 0:
361 |                 word_start_indices.append(len(tokens))
362 |                 tokens.extend(word_tokens)
363 | 
364 |         all_word_start_indices.append(word_start_indices)
365 | 
366 |     return all_word_start_indices
367 | 
368 | 
369 | def convert_examples_to_features(examples,
370 |                                  stance_label_list,
371 |                                  node_label_list,
372 |                                  max_seq_length,
373 |                                  max_nodes,
374 |                                  tokenizer,
375 |                                  cls_token='[CLS]',
376 |                                  sep_token='[SEP]'):
377 |     # The encoding is based on RoBERTa (and hence segment ids don't matter)
378 |     node_label_map = {label: i for i, label in enumerate(node_label_list)}
379 |     stance_label_map = {label: i for i, label in enumerate(stance_label_list)}
380 | 
381 |     features = []
382 |     for (ex_index, example) in enumerate(examples):
383 | 
384 |         if ex_index % 10000 == 0:
385 |             logger.info("Writing example %d of %d" % (ex_index, len(examples)))
386 | 
387 |         tokens = [cls_token]
388 |         node_label_ids = [-100]
389 |         node_start_index, node_end_index = [], []
390 | 
391 |         # Encode the belief
392 |         for word, label in zip(example.belief, example.node_label_internal_belief):
393 |             word_tokens = tokenizer.tokenize(word)
394 |             if len(word_tokens) > 0:
395 |                 if label == "B-N":
396 |                     node_start_index.append(len(tokens))
397 |                 tokens.extend(word_tokens)
398 |                 if label == "B-N":
399 |                     node_end_index.append(len(tokens) - 1)
400 |                 elif label == "I-N":
401 |                     node_end_index[len(node_end_index) - 1] = len(tokens) - 1  # Update the end index
402 | 
403 |                 # Label the first sub-token of the word with its real label; the remaining
404 |                 # sub-tokens of a "B-N" word get "I-N" so the whole word stays inside the node span
405 |                 if label == "B-N":
406 |                     node_label_ids.extend(
407 |                         [node_label_map[label]] + [node_label_map["I-N"]] * (len(word_tokens) - 1))
408 |                 else:
409 |                     node_label_ids.extend([node_label_map[label]] * len(word_tokens))
410 | 
411 |         tokens = tokens + [sep_token] + [sep_token]
412 |         node_label_ids = node_label_ids + [-100, -100]
413 | 
414 |         # Encode the argument
415 |         for word, label in zip(example.argument, example.node_label_internal_argument):
416 |             word_tokens = tokenizer.tokenize(word)
417 |             if len(word_tokens) > 0:
418 |                 if label == "B-N":
419 |                     node_start_index.append(len(tokens))
420 |                 tokens.extend(word_tokens)
421 |                 if label == "B-N":
422 |                     node_end_index.append(len(tokens) - 1)
423 |                 elif label == "I-N":
424 |                     node_end_index[len(node_end_index) - 1] = len(tokens) - 1  # Update the end index
425 | 
426 |                 # Label the first sub-token of the word with its real label; the remaining
427 |                 # sub-tokens of a "B-N" word get "I-N" so the whole word stays inside the node span
428 |                 if label == "B-N":
429 |                     node_label_ids.extend([node_label_map[label]] + [node_label_map["I-N"]] * (len(word_tokens) - 1))
430 |                 else:
431 |                     node_label_ids.extend([node_label_map[label]] * len(word_tokens))
432 | 
433 |         tokens = tokens + [sep_token] + [sep_token]
434 |         node_label_ids = node_label_ids + [-100, -100]
435 | 
436 |         # Encode the external concepts
437 |         for word, label in zip(example.external, example.node_label_external):
438 |             word_tokens = tokenizer.tokenize(word)
439 |             if len(word_tokens) > 0:
440 |                 if label == "B-N":
441 |                     node_start_index.append(len(tokens))
442 |                 tokens.extend(word_tokens)
443 |                 if label == "B-N":
444 |                     node_end_index.append(len(tokens) - 1)
445 |                 elif label == "I-N":
446 |                     node_end_index[len(node_end_index) - 1] = len(tokens) - 1  # Update the end index
447 | 
448 |                 # Label the first sub-token of the word with its real label; the remaining
449 |                 # sub-tokens of a "B-N" word get "I-N" so the whole word stays inside the node span
450 |                 if label == "B-N":
451 |                     node_label_ids.extend(
452 |                         [node_label_map[label]] + [node_label_map["I-N"]] * (len(word_tokens) - 1))
453 |                 else:
454 |                     node_label_ids.extend([node_label_map[label]] * len(word_tokens))
455 | 
456 |         input_ids = tokenizer.convert_tokens_to_ids(tokens)
457 | 
458 |         # The mask has 1 for real tokens and 0 for padding tokens. Only real
459 |         # tokens are attended to.
460 |         input_mask = [1] * len(input_ids)
461 | 
462 |         # Zero-pad up to the sequence length.
463 |         padding_length = max_seq_length - len(input_ids)
464 | 
465 |         input_ids = input_ids + ([0] * padding_length)
466 |         input_mask = input_mask + ([0] * padding_length)
467 |         node_label = node_label_ids + ([-100] * padding_length)
468 |         segment_ids = [0] * len(input_ids)
469 | 
470 |         padding_length = max_seq_length - len(node_start_index)
471 |         node_start_index = node_start_index + ([0] * padding_length)
472 |         node_end_index = node_end_index + ([0] * padding_length)
473 |         edge_label = example.edge_label + [-100] * (max_nodes * max_nodes - len(example.edge_label))
474 | 
475 |         stance_label = stance_label_map[example.stance_label]
476 | 
477 |         assert len(input_ids) == max_seq_length
478 |         assert len(input_mask) == max_seq_length
479 |         assert len(segment_ids) == max_seq_length
480 |         assert len(node_start_index) == max_seq_length
481 |         assert len(node_end_index) == max_seq_length
482 |         assert len(edge_label) == max_nodes * max_nodes
483 | 
484 |         if ex_index < 5:
485 |             logger.info("*** Example ***")
486 |             logger.info("id: %s" % (example.id))
487 |             logger.info("tokens: %s" % " ".join(
488 |                 [str(x) for x in tokens]))
489 |             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
490 |             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
491 |             logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
492 |             logger.info("node_start_index: %s" % " ".join([str(x) for x in node_start_index]))
493 |             logger.info("node_end_index: %s" % " ".join([str(x) for x in node_end_index]))
494 |             logger.info("node_label: %s" % " ".join([str(x) for x in node_label]))
495 |             logger.info("edge_label: %s" % " ".join([str(x) for x in edge_label]))
496 |             logger.info("label: %s (id = %d)" % (example.stance_label, stance_label))
497 | 
498 |         features.append(
499 |             ExplaGraphFeatures(id=example.id,
500 |                                input_ids=input_ids,
501 |                                input_mask=input_mask,
502 |                                segment_ids=segment_ids,
503 |                                node_start_index=node_start_index,
504 |                                node_end_index=node_end_index,
505 |                                node_label=node_label,
506 |                                edge_label=edge_label,
507 |                                stance_label=stance_label))
508 | 
509 |     return features
510 | 
511 | 
512 | def simple_accuracy(preds, labels):
513 |     return (preds == labels).mean()
514 | 
515 | 
516 | def compute_metrics(task_name, preds, labels):
517 |     assert len(preds) == len(labels)
518 |     if task_name == "eg":
519 |         return {"acc": simple_accuracy(preds, labels)}
520 |     else:
521 |         raise KeyError(task_name)
522 | 
523 | 
524 | def write_node_predictions_to_file(writer, test_input_reader, preds_list):
525 |     example_id = 0
526 |     for line in test_input_reader:
527 |         if line.startswith("-DOCSTART-") or line == "" or line == "\n":
528 |             writer.write(line)
529 |             if not preds_list[example_id]:
530 |                 example_id += 1
531 |         elif preds_list[example_id]:
532 |             output_line = line.split()[0] + " " + preds_list[example_id].pop(0) + "\n"
533 |             writer.write(output_line)
534 |         else:
535 |             logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
536 | 
537 | 
538 | processors = {
539 |     "eg": ExplaGraphProcessor
540 | }
541 | 
542 | output_modes = {
543 |     "eg": "classification"
544 | }
545 | 
--------------------------------------------------------------------------------
/structured_model/utils_relation.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import, division, print_function
2 | 
3 | import csv
4 | import logging
5 | import os
6 | import sys
7 | from io import open
8 | 
9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | class RelationInputExample(object):
13 |     def __init__(self, id, concept1, concept2, relation):
14 |         self.id = id
15 |         self.concept1 = concept1
16 |         self.concept2 = concept2
17 |         self.relation = relation
18 | 
19 | 
20 | class RelationFeatures(object):
21 |     def __init__(self, id, input_ids, input_mask, segment_ids, start_indices, end_indices, relation_label):
22 |         self.id = id
23 |         self.input_ids = input_ids
24 |         self.input_mask = input_mask
25 |         self.segment_ids = segment_ids
26 |         self.start_indices = start_indices
27 |         self.end_indices = end_indices
28 |         self.relation_label = relation_label
29 | 
30 | 
31 | class RelationProcessor(object):
32 |     def _read_tsv(self, input_file, quotechar=None):
33 |         with open(input_file, "r", encoding="utf-8-sig") as f:
34 |             reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
35 |             lines = []
36 |             for line in reader:
37 |                 if sys.version_info[0] == 2:
38 |                     line = list(unicode(cell, 'utf-8') for cell in line)
39 |                 lines.append(line)
40 |             return lines
41 | 
42 |     def get_train_examples(self, data_dir):
43 |         return self._create_examples(self._read_tsv(os.path.join(data_dir, "conceptnet_train.tsv")))
44 | 
45 |     def get_dev_examples(self, data_dir):
46 |         return self._create_examples(self._read_tsv(os.path.join(data_dir, "conceptnet_test.tsv")))
47 | 
48 |     def get_test_examples(self, data_dir):
49 |         return self._create_examples(self._read_tsv(os.path.join(data_dir, "test.tsv")))
50 | 
51 |     def get_relation_labels(self):
52 |         return ["antonym of", "synonym of", "at location", "not at location", "capable of", "not capable of", "causes",
53 |                 "not causes", "created by", "not created by", "is a", "is not a", "desires", "not desires",
54 |                 "has subevent", "not has subevent", "part of", "not part of", "has context", "not has context",
55 |                 "has property", "not has property", "made of", "not made of", "receives action", "not receives action",
56 |                 "used for", "not used for", "no relation"]
57 | 
58 |     def _create_examples(self, records):
59 |         examples = []
60 |         for (i, record) in enumerate(records):
61 |             concept1 = record[0]
62 |             relation = record[1]
63 |             concept2 = record[2]
64 | 
65 |             examples.append(
66 |                 RelationInputExample(id=i, concept1=concept1, concept2=concept2, relation=relation))
67 | 
68 |         return examples
69 | 
70 | 
71 | def convert_examples_to_features(examples,
72 |                                  relation_label_list,
73 |                                  max_seq_length,
74 |                                  tokenizer,
75 |                                  cls_token='[CLS]',
76 |                                  sep_token='[SEP]'):
77 |     # The encoding is based on RoBERTa (and hence segment ids don't matter)
78 |     relation_label_map = {label: i for i, label in enumerate(relation_label_list)}
79 | 
80 |     features = []
81 |     for (ex_index, example) in enumerate(examples):
82 | 
83 |         if ex_index % 10000 == 0:
84 |             logger.info("Writing example %d of %d" % (ex_index, len(examples)))
85 | 
86 |         concept1 = " ".join(example.concept1.split("_"))
87 |         concept2 = " ".join(example.concept2.split("_"))
88 | 
89 |         if len(concept1) == 0 or len(concept2) == 0:
90 |             continue
91 | 
92 |         start_indices, end_indices = [], []
93 |         tokens = [cls_token]
94 |         start_indices.append(len(tokens))
95 | 
96 |         tokens += tokenizer.tokenize(concept1)
97 |         end_indices.append(len(tokens) - 1)
98 | 
99 |         tokens += [sep_token, sep_token]
100 | 
101 |         start_indices.append(len(tokens))
102 |         tokens += tokenizer.tokenize(concept2)
103 |         end_indices.append(len(tokens) - 1)
104 | 
105 |         assert start_indices[0] <= end_indices[0]
106 |         assert start_indices[1] <= end_indices[1]
107 | 
108 |         input_ids = tokenizer.convert_tokens_to_ids(tokens)
109 | 
110 |         # The mask has 1 for real tokens and 0 for padding tokens. Only real
111 |         # tokens are attended to.
112 |         input_mask = [1] * len(input_ids)
113 | 
114 |         if len(input_ids) >= max_seq_length:
115 |             continue
116 | 
117 |         # Zero-pad up to the sequence length.
118 |         padding_length = max_seq_length - len(input_ids)
119 | 
120 |         input_ids = input_ids + ([0] * padding_length)
121 |         input_mask = input_mask + ([0] * padding_length)
122 |         segment_ids = [0] * len(input_ids)
123 | 
124 |         relation_label = relation_label_map[example.relation]
125 | 
126 |         assert len(input_ids) == max_seq_length
127 |         assert len(input_mask) == max_seq_length
128 |         assert len(segment_ids) == max_seq_length
129 | 
130 |         if ex_index < 5:
131 |             logger.info("*** Example ***")
132 |             logger.info("id: %s" % (example.id))
133 |             logger.info("tokens: %s" % " ".join(
134 |                 [str(x) for x in tokens]))
135 |             logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
136 |             logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
137 |             logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
138 |             logger.info("start_indices: %s" % " ".join([str(x) for x in start_indices]))
139 |             logger.info("end_indices: %s" % " ".join([str(x) for x in end_indices]))
140 |             logger.info("relation label: %s (id = %d)" % (example.relation, relation_label))
141 | 
142 |         features.append(
143 |             RelationFeatures(id=example.id,
144 |                              input_ids=input_ids,
145 |                              input_mask=input_mask,
146 |                              segment_ids=segment_ids,
147 |                              start_indices=start_indices,
148 |                              end_indices=end_indices,
149 |                              relation_label=relation_label))
150 | 
151 |     return features
152 | 
153 | 
154 | def simple_accuracy(preds, labels):
155 |     return (preds == labels).mean()
156 | 
157 | 
158 | def compute_metrics(task_name, preds, labels):
159 |     assert len(preds) == len(labels)
160 |     if task_name == "relation":
161 |         return {"acc": simple_accuracy(preds, labels)}
162 |     else:
163 |         raise KeyError(task_name)
164 | 
165 | 
166 | processors = {
167 |     "relation": RelationProcessor
168 | }
169 | 
170 | output_modes = {
171 |     "relation": "classification"
172 | }
173 | 
--------------------------------------------------------------------------------
/tmp/README.md:
--------------------------------------------------------------------------------
1 | Folder to cache pre-trained language models.
--------------------------------------------------------------------------------