├── .flake8 ├── .gitattributes ├── .gitignore ├── .mypy.ini ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── app ├── client │ ├── .dockerignore │ ├── .env │ ├── .gitignore │ ├── Dockerfile │ ├── README.md │ ├── package-lock.json │ ├── package.json │ ├── public │ │ ├── CNAME │ │ ├── android-chrome-192x192.png │ │ ├── android-chrome-512x512.png │ │ ├── apple-touch-icon.png │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ ├── favicon.ico │ │ ├── index.html │ │ ├── lab-logo.png │ │ ├── manifest.json │ │ └── robots.txt │ └── src │ │ ├── App.css │ │ ├── App.js │ │ ├── App.test.js │ │ ├── api.js │ │ ├── index.css │ │ ├── index.js │ │ ├── relations.js │ │ ├── reportWebVitals.js │ │ └── setupTests.js ├── docker-compose.yml ├── proxy-nginx.conf └── server │ ├── .dockerignore │ ├── Dockerfile │ ├── inference.py │ ├── requirements.txt │ └── server.py ├── docker ├── Dockerfile ├── build.sh ├── dev.Dockerfile ├── entrypoint.sh ├── push.sh └── train.py ├── docs ├── Makefile ├── _static │ └── css │ │ └── custom.css ├── api.rst ├── conf.py ├── getting_started.rst ├── index.rst ├── make.bat ├── requirements.txt └── userguide.rst ├── eacl2023 ├── kogito-poster-eacl2023.pdf └── kogito-presentation-eacl2023.pdf ├── examples ├── data │ ├── atomic2020 │ │ ├── LICENSE │ │ ├── README.md │ │ ├── sample_dev.tsv │ │ ├── sample_test.tsv │ │ └── sample_train.tsv │ ├── cc_news_samples.txt │ ├── dialog_samples.txt │ └── story_samples.txt ├── demo.ipynb ├── eacl.ipynb ├── evaluate_comet_bart.py ├── evaluate_comet_gpt2.py ├── results │ ├── cc_news │ │ ├── cc_news_samples_1.json │ │ ├── cc_news_samples_10.json │ │ ├── cc_news_samples_11.json │ │ ├── cc_news_samples_12.json │ │ ├── cc_news_samples_13.json │ │ ├── cc_news_samples_14.json │ │ ├── cc_news_samples_15.json │ │ ├── cc_news_samples_16.json │ │ ├── cc_news_samples_17.json │ │ ├── cc_news_samples_18.json │ │ ├── cc_news_samples_19.json │ │ ├── cc_news_samples_2.json │ │ ├── cc_news_samples_20.json │ │ ├── cc_news_samples_21.json │ │ 
├── cc_news_samples_22.json │ │ ├── cc_news_samples_23.json │ │ ├── cc_news_samples_24.json │ │ ├── cc_news_samples_25.json │ │ ├── cc_news_samples_26.json │ │ ├── cc_news_samples_27.json │ │ ├── cc_news_samples_28.json │ │ ├── cc_news_samples_29.json │ │ ├── cc_news_samples_3.json │ │ ├── cc_news_samples_30.json │ │ ├── cc_news_samples_31.json │ │ ├── cc_news_samples_32.json │ │ ├── cc_news_samples_33.json │ │ ├── cc_news_samples_34.json │ │ ├── cc_news_samples_35.json │ │ ├── cc_news_samples_36.json │ │ ├── cc_news_samples_37.json │ │ ├── cc_news_samples_38.json │ │ ├── cc_news_samples_39.json │ │ ├── cc_news_samples_4.json │ │ ├── cc_news_samples_40.json │ │ ├── cc_news_samples_41.json │ │ ├── cc_news_samples_42.json │ │ ├── cc_news_samples_43.json │ │ ├── cc_news_samples_44.json │ │ ├── cc_news_samples_45.json │ │ ├── cc_news_samples_46.json │ │ ├── cc_news_samples_47.json │ │ ├── cc_news_samples_48.json │ │ ├── cc_news_samples_49.json │ │ ├── cc_news_samples_5.json │ │ ├── cc_news_samples_50.json │ │ ├── cc_news_samples_51.json │ │ ├── cc_news_samples_52.json │ │ ├── cc_news_samples_53.json │ │ ├── cc_news_samples_54.json │ │ ├── cc_news_samples_55.json │ │ ├── cc_news_samples_56.json │ │ ├── cc_news_samples_57.json │ │ ├── cc_news_samples_58.json │ │ ├── cc_news_samples_59.json │ │ ├── cc_news_samples_6.json │ │ ├── cc_news_samples_60.json │ │ ├── cc_news_samples_61.json │ │ ├── cc_news_samples_62.json │ │ ├── cc_news_samples_63.json │ │ ├── cc_news_samples_64.json │ │ ├── cc_news_samples_65.json │ │ ├── cc_news_samples_66.json │ │ ├── cc_news_samples_67.json │ │ ├── cc_news_samples_7.json │ │ ├── cc_news_samples_8.json │ │ └── cc_news_samples_9.json │ ├── cometbart_results_test_atomic2020.json │ ├── daily_dialog │ │ ├── dialog_samples_1.json │ │ ├── dialog_samples_10.json │ │ ├── dialog_samples_11.json │ │ ├── dialog_samples_12.json │ │ ├── dialog_samples_13.json │ │ ├── dialog_samples_14.json │ │ ├── dialog_samples_15.json │ │ ├── dialog_samples_16.json │ │ 
├── dialog_samples_17.json │ │ ├── dialog_samples_18.json │ │ ├── dialog_samples_19.json │ │ ├── dialog_samples_2.json │ │ ├── dialog_samples_20.json │ │ ├── dialog_samples_21.json │ │ ├── dialog_samples_22.json │ │ ├── dialog_samples_23.json │ │ ├── dialog_samples_24.json │ │ ├── dialog_samples_25.json │ │ ├── dialog_samples_26.json │ │ ├── dialog_samples_27.json │ │ ├── dialog_samples_28.json │ │ ├── dialog_samples_29.json │ │ ├── dialog_samples_3.json │ │ ├── dialog_samples_30.json │ │ ├── dialog_samples_31.json │ │ ├── dialog_samples_32.json │ │ ├── dialog_samples_33.json │ │ ├── dialog_samples_34.json │ │ ├── dialog_samples_35.json │ │ ├── dialog_samples_36.json │ │ ├── dialog_samples_37.json │ │ ├── dialog_samples_38.json │ │ ├── dialog_samples_39.json │ │ ├── dialog_samples_4.json │ │ ├── dialog_samples_40.json │ │ ├── dialog_samples_41.json │ │ ├── dialog_samples_42.json │ │ ├── dialog_samples_43.json │ │ ├── dialog_samples_44.json │ │ ├── dialog_samples_45.json │ │ ├── dialog_samples_46.json │ │ ├── dialog_samples_47.json │ │ ├── dialog_samples_48.json │ │ ├── dialog_samples_49.json │ │ ├── dialog_samples_5.json │ │ ├── dialog_samples_50.json │ │ ├── dialog_samples_6.json │ │ ├── dialog_samples_7.json │ │ ├── dialog_samples_8.json │ │ └── dialog_samples_9.json │ ├── gpt3_prompt_2c976a77.txt │ ├── gpt3_prompt_56fcb840.txt │ ├── gpt3_prompt_99e17fae.txt │ ├── kgraph.json │ ├── kgraph_dry_run.json │ ├── kgraph_dry_run_2.json │ ├── kgraph_full.json │ ├── kgraph_gpt3.json │ ├── kgraph_gpt3_custom_relation.json │ ├── kgraph_gpt3_dry_run.json │ ├── kgraph_manual.json │ ├── kgraph_manual_heads.json │ ├── kgraph_modelbased_relations.json │ ├── kgraph_modelbased_relations_bert.json │ ├── kgraph_modelbased_relations_dbert.json │ ├── kgraph_modelbased_relations_swem.json │ ├── kgraph_no_head_extract.json │ ├── kgraph_no_match_subset.json │ ├── kgraph_rel_subset.json │ ├── kgraph_with_context.json │ ├── kgraph_without_context.json │ ├── roc_stories │ │ ├── 
story_samples_1.json │ │ ├── story_samples_10.json │ │ ├── story_samples_11.json │ │ ├── story_samples_12.json │ │ ├── story_samples_13.json │ │ ├── story_samples_14.json │ │ ├── story_samples_15.json │ │ ├── story_samples_16.json │ │ ├── story_samples_17.json │ │ ├── story_samples_18.json │ │ ├── story_samples_19.json │ │ ├── story_samples_2.json │ │ ├── story_samples_20.json │ │ ├── story_samples_21.json │ │ ├── story_samples_22.json │ │ ├── story_samples_23.json │ │ ├── story_samples_24.json │ │ ├── story_samples_25.json │ │ ├── story_samples_26.json │ │ ├── story_samples_27.json │ │ ├── story_samples_28.json │ │ ├── story_samples_29.json │ │ ├── story_samples_3.json │ │ ├── story_samples_30.json │ │ ├── story_samples_31.json │ │ ├── story_samples_32.json │ │ ├── story_samples_33.json │ │ ├── story_samples_34.json │ │ ├── story_samples_35.json │ │ ├── story_samples_36.json │ │ ├── story_samples_37.json │ │ ├── story_samples_38.json │ │ ├── story_samples_39.json │ │ ├── story_samples_4.json │ │ ├── story_samples_40.json │ │ ├── story_samples_41.json │ │ ├── story_samples_42.json │ │ ├── story_samples_43.json │ │ ├── story_samples_44.json │ │ ├── story_samples_45.json │ │ ├── story_samples_46.json │ │ ├── story_samples_47.json │ │ ├── story_samples_48.json │ │ ├── story_samples_49.json │ │ ├── story_samples_5.json │ │ ├── story_samples_50.json │ │ ├── story_samples_6.json │ │ ├── story_samples_7.json │ │ ├── story_samples_8.json │ │ └── story_samples_9.json │ ├── test_atomic2020_res_cometgpt2.json │ ├── test_atomic2020_res_cometgpt2_sample.json │ └── test_atomic2020_res_zeroshot_sample.jsonl ├── sample_graph.jsonl ├── sample_graph.tsv ├── sample_graph2.tsv ├── sample_graph3.jsonl ├── sample_linking_graph.csv ├── snapshots │ ├── custom-relation.png │ ├── kg-concepts.png │ ├── pipeline.png │ └── quickstart.png ├── test.ipynb ├── test.py ├── test_atomic2020.json ├── test_atomic2020_sample.json ├── test_comet_bart.py ├── test_comet_gpt2.py ├── test_gpt2.json ├── 
test_zeroshot.py ├── train_comet_bart.py └── train_comet_gpt2.py ├── experiments ├── fact_linking │ ├── data │ │ ├── roc_full_pipeline_result_samples.json │ │ ├── roc_full_pipeline_results.json │ │ ├── roc_nlu_test.csv │ │ ├── roc_results.json │ │ └── roc_samples.csv │ ├── evaluation.ipynb │ └── preprocess.ipynb └── relation_modeling │ ├── analysis.ipynb │ ├── base.ipynb │ ├── data │ └── atomic_split │ │ ├── n1 │ │ ├── test_n1.csv │ │ └── train_n1.csv │ │ ├── n3 │ │ ├── test_n3.csv │ │ └── train_n3.csv │ │ └── n5 │ │ ├── test_n5.csv │ │ └── train_n5.csv │ ├── heuristic.ipynb │ ├── ood.ipynb │ ├── prediction.ipynb │ ├── relation_modeling_utils.py │ ├── spacy │ ├── base_config.cfg │ ├── config.cfg │ └── relation_model_spacy.ipynb │ ├── swem │ ├── relation_modeling_swem.ipynb │ ├── relation_modeling_swem.py │ ├── relation_modeling_swem_finetune.ipynb │ ├── relation_modeling_swem_multi.ipynb │ ├── relation_modeling_swem_multi_nn.ipynb │ ├── relation_modeling_swem_multi_nn_no_pad.ipynb │ ├── relation_modeling_swem_multi_no_pad.ipynb │ ├── relation_modeling_swem_multi_no_pad_max.ipynb │ └── relation_modeling_swem_no_pad.ipynb │ └── transformer │ ├── relation_modeling_bert.ipynb │ ├── relation_modeling_bert.py │ ├── relation_modeling_distillbert.ipynb │ ├── relation_modeling_distillbert.py │ └── relation_modeling_transformer.ipynb ├── kogito ├── __init__.py ├── core │ ├── __init__.py │ ├── callbacks.py │ ├── dataset.py │ ├── head.py │ ├── knowledge.py │ ├── linker.py │ ├── model.py │ ├── processors │ │ ├── __init__.py │ │ ├── data │ │ │ └── vocab_glove_100d.npy │ │ ├── head.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ ├── bert.py │ │ │ ├── distilbert.py │ │ │ ├── swem.py │ │ │ └── utils.py │ │ └── relation.py │ ├── relation.py │ └── utils.py ├── evaluation │ ├── LICENSE │ ├── __init__.py │ ├── bert_score │ │ ├── __init__.py │ │ ├── bert_score.py │ │ ├── score.py │ │ └── utils.py │ ├── bleu │ │ ├── .gitignore │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── bleu.py │ │ └── 
bleu_scorer.py │ ├── cider │ │ ├── __init__.py │ │ ├── cider.py │ │ └── cider_scorer.py │ ├── eval.py │ ├── meteor │ │ ├── __init__.py │ │ └── meteor.py │ └── rouge │ │ ├── __init__.py │ │ └── rouge.py ├── inference.py ├── linkers │ └── deberta.py └── models │ ├── __init__.py │ ├── bart │ ├── __init__.py │ ├── comet.py │ ├── config.py │ └── utils.py │ ├── gpt2 │ ├── __init__.py │ ├── comet.py │ ├── utils.py │ └── zeroshot.py │ └── gpt3 │ ├── __init__.py │ └── zeroshot.py ├── poetry.lock └── pyproject.toml /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, W503 3 | max-line-length = 120 4 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | experiments/**/*.ipynb linguist-vendored -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | # MacOS 131 | .DS_Store 132 | 133 | # kogito 134 | experiments/relation_modeling/output 135 | experiments/relation_modeling/models 136 | experiments/relation_modeling/hmodels 137 | experiments/relation_modeling/data 138 | examples/models 139 | wandb 140 | .vector_cache/ 141 | lightning_logs/ 142 | kogito-relation-matcher/ 143 | *.bin 144 | __MACOSX -------------------------------------------------------------------------------- /.mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.7 3 | 4 | [mypy-setuptools] 5 | ignore_missing_imports = True 6 | 7 | [mypy-pytorch_lightning] 8 | ignore_missing_imports = True 9 | 10 | [mypy-pytorch_lightning.callbacks] 11 | ignore_missing_imports = True 12 | 13 | [mypy-pytorch_lightning.utilities] 14 | ignore_missing_imports = True 15 | 16 | [mypy-spacy] 17 | ignore_missing_imports = True 18 | 19 | [mypy-inflect] 20 | ignore_missing_imports = True 21 | 22 | [mypy-rouge_score] 23 | ignore_missing_imports = True 24 | 25 | [mypy-tqdm] 26 | ignore_missing_imports = True 27 | 28 | [mypy-pandas] 29 | ignore_missing_imports = True -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | os: ubuntu-20.04 5 | tools: 6 | python: "3.8" 7 | 
8 | python: 9 | install: 10 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # kogito 2 | A Python NLP Commonsense Knowledge Inference Toolkit 3 | 4 | System Description available here: https://arxiv.org/abs/2211.08451 5 | 6 | ## Installation 7 | 8 | ### Installation with pip 9 | **kogito** can be installed using pip. 10 | 11 | ```sh 12 | pip install kogito 13 | ``` 14 | 15 | It requires a minimum ``python`` version of ``3.8``. 16 | 17 | ## Setup 18 | 19 | ### Inference 20 | **kogito** uses [spacy](https://spacy.io) under the hood for various text processing purposes, so, a [spacy](https://spacy.io) language package has to be installed before running the inference module. 21 | 22 | ```sh 23 | python -m spacy download en_core_web_sm 24 | ``` 25 | By default, ``CommonsenseInference`` module uses ``en_core_web_sm`` to initialize ``spacy`` pipeline, but a different language pipeline can be specified as well. 26 | 27 | ### Evaluation 28 | If you also would like evaluate knowledge models using `METEOR` score, then you need to download the following ``nltk`` libraries: 29 | ```python 30 | import nltk 31 | 32 | nltk.download("punkt") 33 | nltk.download("wordnet") 34 | nltk.download("omw-1.4") 35 | ``` 36 | 37 | ## Quickstart 38 | **kogito** provides an easy interface to interact with knowledge inference or commonsense reasoning models such as [COMET](https://arxiv.org/abs/2010.05953) to generate inferences from a text input. 39 | Here is a sample usage of the library where you can initialize an inference module, a custom commonsense reasoning model, and generate a knowledge graph from text on the fly. 
40 | 41 | ```python 42 | from kogito.models.bart.comet import COMETBART 43 | from kogito.inference import CommonsenseInference 44 | 45 | # Load pre-trained model from HuggingFace 46 | model = COMETBART.from_pretrained("mismayil/comet-bart-ai2") 47 | 48 | # Initialize inference module with a spacy language pipeline 49 | csi = CommonsenseInference(language="en_core_web_sm") 50 | 51 | # Run inference 52 | text = "PersonX becomes a great basketball player" 53 | kgraph = csi.infer(text, model) 54 | 55 | # Save output knowledge graph to JSON file 56 | kgraph.to_jsonl("kgraph.json") 57 | ``` 58 | 59 | Here is an excerpt from the result of the above code sample: 60 | 61 | ```json 62 | {"head": "PersonX becomes a great basketball player", "relation": "Causes", "tails": [" PersonX practices every day.", " PersonX plays basketball every day", " PersonX practices every day"]} 63 | {"head": "basketball", "relation": "ObjectUse", "tails": [" play with friends", " play basketball with", " play basketball"]} 64 | {"head": "player", "relation": "CapableOf", "tails": [" play game", " win game", " play football"]} 65 | {"head": "great basketball player", "relation": "HasProperty", "tails": [" good at basketball", " good at sports", " very good"]} 66 | {"head": "become player", "relation": "isAfter", "tails": [" play game", " become coach", " play with"]} 67 | ``` 68 | This is just one way to generate commonsense inferences and **kogito** offers much more. For complete documentation, check out the [kogito docs](https://kogito.readthedocs.io). 69 | 70 | ## Development 71 | 72 | ### Setup 73 | **kogito** uses [Poetry](https://python-poetry.org/) to manage its dependencies. 
74 | 75 | Install poetry from the official repository first: 76 | ```sh 77 | curl -sSL https://install.python-poetry.org | python3 - 78 | ``` 79 | 80 | Then run the following command to install package dependencies: 81 | ```sh 82 | poetry install 83 | ``` 84 | 85 | ## Data 86 | If you need the ATOMIC2020 data to train your knowledge models, you can download it from AI2: 87 | 88 | For ATOMIC: 89 | ```sh 90 | wget https://storage.googleapis.com/ai2-mosaic/public/atomic/v1.0/atomic_data.tgz 91 | ``` 92 | 93 | For ATOMIC 2020: 94 | ```sh 95 | wget https://ai2-atomic.s3-us-west-2.amazonaws.com/data/atomic2020_data-feb2021.zip 96 | ``` 97 | 98 | ## Paper 99 | If you want to learn more about the library design, models and data used for this toolkit, check out our [paper](https://arxiv.org/abs/2211.08451). The paper can be cited as: 100 | 101 | ``` 102 | @article{Ismayilzada2022kogito, 103 | title={kogito: A Commonsense Knowledge Inference Toolkit}, 104 | author={Mete Ismayilzada and Antoine Bosselut}, 105 | journal={ArXiv}, 106 | volume={abs/2211.08451}, 107 | year={2022} 108 | } 109 | ``` 110 | 111 | If you work with knowledge models, consider citing the following papers: 112 | 113 | ``` 114 | @article{Hwang2020COMETATOMIC, 115 | author = {Jena D. 
Hwang and Chandra Bhagavatula and Ronan Le Bras and Jeff Da and Keisuke Sakaguchi and Antoine Bosselut and Yejin Choi}, 116 | booktitle = {Proceedings of the 35th AAAI Conference on Artificial Intelligence (AAAI)}, 117 | title = {COMET-ATOMIC 2020: On Symbolic and Neural Commonsense Knowledge Graphs}, 118 | year = {2021} 119 | } 120 | 121 | @inproceedings{Bosselut2019COMETCT, 122 | author = {Antoine Bosselut and Hannah Rashkin and Maarten Sap and Chaitanya Malaviya and Asli Çelikyilmaz and Yejin Choi}, 123 | booktitle = {Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (ACL)}, 124 | title = {COMET: Commonsense Transformers for Automatic Knowledge Graph Construction}, 125 | year = {2019} 126 | } 127 | ``` 128 | 129 | ## Acknowledgements 130 | Significant portion of the model training and evaluation code has been adapted from the original [codebase](https://github.com/allenai/comet-atomic-2020) for the paper [(Comet-) Atomic 2020: On Symbolic and Neural Commonsense Knowledge Graphs.](https://www.semanticscholar.org/paper/COMET-ATOMIC-2020%3A-On-Symbolic-and-Neural-Knowledge-Hwang-Bhagavatula/e39503e01ebb108c6773948a24ca798cd444eb62) 131 | -------------------------------------------------------------------------------- /app/client/.dockerignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | build -------------------------------------------------------------------------------- /app/client/.env: -------------------------------------------------------------------------------- 1 | REACT_APP_SERVER_URL=https://proxy.kogito.live -------------------------------------------------------------------------------- /app/client/.gitignore: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. 
2 | 3 | # dependencies 4 | /node_modules 5 | /.pnp 6 | .pnp.js 7 | 8 | # testing 9 | /coverage 10 | 11 | # production 12 | /build 13 | 14 | # misc 15 | .DS_Store 16 | .env.local 17 | .env.development.local 18 | .env.test.local 19 | .env.production.local 20 | 21 | npm-debug.log* 22 | yarn-debug.log* 23 | yarn-error.log* 24 | -------------------------------------------------------------------------------- /app/client/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM node:18.5 2 | 3 | ENV HOME=/root 4 | ENV APP_DIR=${HOME}/kogito-client 5 | ENV REACT_APP_SERVER_URL=http://localhost:8080 6 | 7 | # Set default shell to /bin/bash 8 | SHELL ["/bin/bash", "-cu"] 9 | 10 | # Install dependencies 11 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends vim wget unzip 12 | 13 | RUN mkdir ${APP_DIR} 14 | WORKDIR ${APP_DIR} 15 | 16 | # Copy client files 17 | COPY . . 18 | 19 | # Setup app dependencies 20 | RUN npm install -g serve 21 | RUN npm install 22 | RUN npm run build 23 | 24 | EXPOSE 3000 25 | 26 | CMD ["serve", "-s", "build"] -------------------------------------------------------------------------------- /app/client/README.md: -------------------------------------------------------------------------------- 1 | # Getting Started with Create React App 2 | 3 | This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app). 4 | 5 | ## Available Scripts 6 | 7 | In the project directory, you can run: 8 | 9 | ### `npm start` 10 | 11 | Runs the app in the development mode.\ 12 | Open [http://localhost:3000](http://localhost:3000) to view it in your browser. 13 | 14 | The page will reload when you make changes.\ 15 | You may also see any lint errors in the console. 
16 | 17 | ### `npm test` 18 | 19 | Launches the test runner in the interactive watch mode.\ 20 | See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information. 21 | 22 | ### `npm run build` 23 | 24 | Builds the app for production to the `build` folder.\ 25 | It correctly bundles React in production mode and optimizes the build for the best performance. 26 | 27 | The build is minified and the filenames include the hashes.\ 28 | Your app is ready to be deployed! 29 | 30 | See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information. 31 | 32 | ### `npm run eject` 33 | 34 | **Note: this is a one-way operation. Once you `eject`, you can't go back!** 35 | 36 | If you aren't satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project. 37 | 38 | Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you're on your own. 39 | 40 | You don't have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn't feel obligated to use this feature. However we understand that this tool wouldn't be useful if you couldn't customize it when you are ready for it. 41 | 42 | ## Learn More 43 | 44 | You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started). 45 | 46 | To learn React, check out the [React documentation](https://reactjs.org/). 
47 | 48 | ### Code Splitting 49 | 50 | This section has moved here: [https://facebook.github.io/create-react-app/docs/code-splitting](https://facebook.github.io/create-react-app/docs/code-splitting) 51 | 52 | ### Analyzing the Bundle Size 53 | 54 | This section has moved here: [https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size](https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size) 55 | 56 | ### Making a Progressive Web App 57 | 58 | This section has moved here: [https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app](https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app) 59 | 60 | ### Advanced Configuration 61 | 62 | This section has moved here: [https://facebook.github.io/create-react-app/docs/advanced-configuration](https://facebook.github.io/create-react-app/docs/advanced-configuration) 63 | 64 | ### Deployment 65 | 66 | This section has moved here: [https://facebook.github.io/create-react-app/docs/deployment](https://facebook.github.io/create-react-app/docs/deployment) 67 | 68 | ### `npm run build` fails to minify 69 | 70 | This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify) 71 | -------------------------------------------------------------------------------- /app/client/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "kogito", 3 | "version": "0.1.0", 4 | "homepage": "https://kogito.live", 5 | "private": true, 6 | "dependencies": { 7 | "@testing-library/jest-dom": "^5.16.4", 8 | "@testing-library/react": "^13.3.0", 9 | "@testing-library/user-event": "^13.5.0", 10 | "axios": "^0.27.2", 11 | "file-saver": "^2.0.5", 12 | "lodash": "^4.17.21", 13 | "react": "^18.2.0", 14 | "react-copy-to-clipboard": "^5.1.0", 15 | "react-dom": "^18.2.0", 16 | 
"react-scripts": "5.0.1", 17 | "semantic-ui-react": "^2.1.3", 18 | "semantic-ui-react-numberinput": "^1.5.1", 19 | "serve": "^14.0.1", 20 | "web-vitals": "^2.1.4" 21 | }, 22 | "scripts": { 23 | "start": "react-scripts start", 24 | "build": "react-scripts build", 25 | "test": "react-scripts test", 26 | "eject": "react-scripts eject", 27 | "predeploy": "npm run build", 28 | "deploy": "gh-pages -d build" 29 | }, 30 | "eslintConfig": { 31 | "extends": [ 32 | "react-app", 33 | "react-app/jest" 34 | ] 35 | }, 36 | "browserslist": { 37 | "production": [ 38 | ">0.2%", 39 | "not dead", 40 | "not op_mini all" 41 | ], 42 | "development": [ 43 | "last 1 chrome version", 44 | "last 1 firefox version", 45 | "last 1 safari version" 46 | ] 47 | }, 48 | "devDependencies": { 49 | "gh-pages": "^4.0.0" 50 | } 51 | } 52 | -------------------------------------------------------------------------------- /app/client/public/CNAME: -------------------------------------------------------------------------------- 1 | kogito.live -------------------------------------------------------------------------------- /app/client/public/android-chrome-192x192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/android-chrome-192x192.png -------------------------------------------------------------------------------- /app/client/public/android-chrome-512x512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/android-chrome-512x512.png -------------------------------------------------------------------------------- /app/client/public/apple-touch-icon.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/apple-touch-icon.png -------------------------------------------------------------------------------- /app/client/public/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/favicon-16x16.png -------------------------------------------------------------------------------- /app/client/public/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/favicon-32x32.png -------------------------------------------------------------------------------- /app/client/public/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/favicon.ico -------------------------------------------------------------------------------- /app/client/public/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 23 | 24 | 33 | 34 | 35 | 36 | Kogito 37 | 38 | 39 | 40 |
41 | 51 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /app/client/public/lab-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/lab-logo.png -------------------------------------------------------------------------------- /app/client/public/manifest.json: -------------------------------------------------------------------------------- 1 | { 2 | "short_name": "Kogito", 3 | "name": "Kogito: Knowledge Inference Tool", 4 | "icons": [ 5 | { 6 | "src": "favicon.ico", 7 | "sizes": "64x64 32x32 24x24 16x16", 8 | "type": "image/x-icon" 9 | }, 10 | { 11 | "src": "android-chrome-192x192.png", 12 | "type": "image/png", 13 | "sizes": "192x192" 14 | }, 15 | { 16 | "src": "android-chrome-512x512.png", 17 | "type": "image/png", 18 | "sizes": "512x512" 19 | } 20 | ], 21 | "start_url": ".", 22 | "display": "standalone", 23 | "theme_color": "#000000", 24 | "background_color": "#ffffff" 25 | } 26 | -------------------------------------------------------------------------------- /app/client/public/robots.txt: -------------------------------------------------------------------------------- 1 | # https://www.robotstxt.org/robotstxt.html 2 | User-agent: * 3 | Disallow: 4 | -------------------------------------------------------------------------------- /app/client/src/App.css: -------------------------------------------------------------------------------- 1 | .logo { 2 | padding-top: 10px; 3 | font-size: 64px; 4 | font-family: 'Economica', sans-serif; 5 | margin-bottom: 10px; 6 | } 7 | 8 | .logo-k { 9 | background-color: #dce755; 10 | font-weight: bold; 11 | } 12 | 13 | .description { 14 | /* font-size: 20px; */ 15 | } 16 | 17 | .cntr-label { 18 | padding-bottom: 5px; 19 | } 20 | 21 | .cntr { 22 | margin-top: 24px; 23 | margin-bottom: 36px; 24 | } 25 | 26 | .cntr-head { 27 | margin-top: 
import axios from 'axios'

// Base URL of the kogito inference server; overridable at build time
// through the REACT_APP_SERVER_URL environment variable.
const API_URI = process.env.REACT_APP_SERVER_URL || "http://localhost:8080"

// Thin HTTP client exposing the server endpoints used by the app.
const api = {
  inference: {
    // POST the inference request payload; returns the axios promise.
    generate: (payload) => axios.post(API_URI + '/inference', payload)
  }
}

export default api
// Lazily load the web-vitals library and report each core metric
// (CLS, FID, FCP, LCP, TTFB) through the supplied callback.
const reportWebVitals = (onPerfEntry) => {
  if (!(onPerfEntry && onPerfEntry instanceof Function)) {
    return;
  }
  import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => {
    [getCLS, getFID, getFCP, getLCP, getTTFB].forEach((report) => report(onPerfEntry));
  });
};

export default reportWebVitals;
proxy_http_version 1.1; 16 | proxy_set_header Upgrade $http_upgrade; 17 | proxy_set_header Connection 'upgrade'; 18 | proxy_set_header Host $host; 19 | proxy_cache_bypass $http_upgrade; 20 | proxy_set_header X-Real-IP $remote_addr; 21 | proxy_set_header X-Forwarded-Proto https; 22 | proxy_set_header X-Forwarded-For $remote_addr; 23 | proxy_set_header X-Forwarded-Host $remote_addr; 24 | } 25 | } 26 | server { 27 | if ($host = proxy.kogito.live) { 28 | return 301 https://$host$request_uri; 29 | } # managed by Certbot 30 | 31 | 32 | listen 80 ; 33 | listen [::]:80 ; 34 | server_name proxy.kogito.live; 35 | return 404; # managed by Certbot 36 | 37 | 38 | } -------------------------------------------------------------------------------- /app/server/.dockerignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | lightning_logs -------------------------------------------------------------------------------- /app/server/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.8 2 | 3 | ENV HOME=/root 4 | ENV APP_DIR=${HOME}/kogito/app/server 5 | ENV FLASK_APP=server 6 | 7 | # Set default shell to /bin/bash 8 | SHELL ["/bin/bash", "-cu"] 9 | 10 | # Install dependencies 11 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends vim wget unzip 12 | 13 | WORKDIR ${HOME} 14 | ARG GITHUB_PERSONAL_TOKEN 15 | 16 | # Clone kogito 17 | RUN git clone https://${GITHUB_PERSONAL_TOKEN}@github.com/epfl-nlp/kogito.git 18 | 19 | WORKDIR ${APP_DIR} 20 | 21 | # Setup app dependencies 22 | RUN pip3 install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu114 23 | RUN python3 -c "import nltk;nltk.download('punkt');nltk.download('wordnet');nltk.download('omw-1.4')" 24 | RUN python3 -m spacy download en_core_web_sm 25 | 26 | EXPOSE 8080 27 | 28 | # gunicorn -w 1 -b 0.0.0.0:8080 --log-file server.log 
def infer(data):
    """Run commonsense inference for a single API request.

    Args:
        data (dict): Parsed JSON request body. Recognized keys:
            text, model, heads, relations, extractHeads, matchRelations,
            dryRun, headProcs, relProcs, context, threshold.

    Returns:
        dict: {"text": [...lemmas of the input text...],
               "graph": [...knowledge tuples as JSON dicts...]}
    """
    text = data.get("text")
    model = MODEL_MAP.get(data.get("model"))
    heads = data.get("heads")
    relations = data.get("relations")
    extract_heads = data.get("extractHeads", True)
    match_relations = data.get("matchRelations", True)
    dry_run = data.get("dryRun", False)
    head_procs = data.get("headProcs", [])
    rel_procs = data.get("relProcs", [])
    # Guard against explicit nulls in the JSON payload: dict.get's default
    # only applies when the key is absent, not when its value is None, so
    # `data.get("context", "").strip()` would crash on {"context": null}.
    context = (data.get("context") or "").strip()
    threshold = float(data.get("threshold") or 0.5)

    csi = CommonsenseInference(language="en_core_web_sm")
    csi_head_procs = csi.processors["head"]
    csi_rel_procs = csi.processors["relation"]

    # Reconcile the module's processors with the requested ones:
    # remove defaults that were not requested, then add requested
    # processors that are not already registered.
    for proc in set(csi_head_procs).difference(head_procs):
        csi.remove_processor(proc)

    for proc in set(csi_rel_procs).difference(rel_procs):
        csi.remove_processor(proc)

    for proc in set(head_procs).difference(csi_head_procs):
        csi.add_processor(PROCESSOR_MAP[proc])

    for proc in set(rel_procs).difference(csi_rel_procs):
        csi.add_processor(PROCESSOR_MAP[proc])

    # Convert relation names to KnowledgeRelation objects; build a new
    # list instead of mutating the caller-supplied one in place.
    if relations:
        relations = [KnowledgeRelation.from_text(relation) for relation in relations]

    linker = LINKER_MAP["deberta"]

    output_graph = csi.infer(
        text=text,
        model=model,
        heads=heads,
        relations=relations,
        extract_heads=extract_heads,
        match_relations=match_relations,
        dry_run=dry_run,
        context=context,
        threshold=threshold,
        linker=linker,
    )

    result = {"text": [], "graph": []}

    if output_graph:
        result["graph"] = [kg.to_json() for kg in output_graph]

    if text:
        doc = nlp(text)
        result["text"] = [token.lemma_.lower() for token in doc]

    return result
@app.route("/inference", methods=["POST"])
def inference():
    """Run knowledge inference on the posted JSON payload.

    Returns:
        The inference result as a JSON response, or a (message, 500)
        tuple when inference fails.
    """
    try:
        return jsonify(infer(request.json))
    except Exception as e:
        # BUG FIX: traceback.print_exc() takes no exception argument — its
        # first positional parameter is `limit`. Passing the exception here
        # raised a secondary TypeError inside the handler and hid the
        # original traceback from the server log.
        traceback.print_exc()
        return str(e), 500


def main():
    """Start the Flask development server.

    PORT selects the listening port (default 8080). FLASK_DEBUG enables
    debug mode; note that any non-empty value (including "0") is truthy.
    """
    port = int(os.environ.get("PORT", 8080))
    app.run(debug=os.environ.get("FLASK_DEBUG", False), host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()
Otherwise user is kicked off after login 19 | RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd 20 | ENV NOTVISIBLE="in users profile" 21 | RUN echo "export VISIBLE=now" >> /etc/profile 22 | EXPOSE 22 23 | 24 | WORKDIR ${HOME} 25 | 26 | # Cluster setup 27 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh -O anaconda.sh 28 | RUN bash anaconda.sh -b -p ${CONDA_PREFIX} 29 | RUN ${CONDA} config --set auto_activate_base false 30 | RUN ${CONDA} init bash 31 | RUN echo "export LANG=en_US.UTF-8" >> ~/.bashrc 32 | RUN ${CONDA} create --name kogito -y python=3.10 33 | RUN ${CONDA} install -n kogito ipykernel --force-reinstall 34 | 35 | WORKDIR ${KOGITO_DIR} 36 | 37 | ARG ENV_TAINT=0 38 | 39 | # Setup project dependencies 40 | RUN ${CONDA} run -n kogito pip install kogito --extra-index-url https://download.pytorch.org/whl/cu116 41 | RUN ${CONDA} run -n kogito python -c "import nltk;nltk.download('punkt');nltk.download('wordnet');nltk.download('omw-1.4')" 42 | RUN ${CONDA} run -n kogito python -m spacy download en_core_web_sm 43 | 44 | ARG VERSION_TAINT=0 45 | 46 | # Setup data 47 | COPY ./data . 48 | COPY ./train.py . 49 | COPY ./train.sh . 50 | 51 | CMD ["/usr/sbin/sshd", "-D"] 52 | # CMD ["./train.sh"] -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | export GITHUB_PERSONAL_TOKEN=${1} 4 | docker build -f Dockerfile -t ic-registry.epfl.ch/nlp/kogito --build-arg GITHUB_PERSONAL_TOKEN . 
-------------------------------------------------------------------------------- /docker/dev.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.5.1-cudnn8-runtime-ubuntu20.04@sha256:a6831f0d6328ea7301fa196ae2a376d2e67caae384af4ffd93fb196b527c0a0f 2 | 3 | ENV HOME=/root 4 | ENV CONDA_PREFIX=${HOME}/.conda 5 | ENV CONDA=${CONDA_PREFIX}/condabin/conda 6 | ENV KOGITO_DIR=${HOME}/kogito 7 | ENV POETRY=${HOME}/.local/bin/poetry 8 | 9 | # Set default shell to /bin/bash 10 | SHELL ["/bin/bash", "-cu"] 11 | 12 | # Install dependencies 13 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends openssh-server vim wget unzip tmux 14 | 15 | WORKDIR ${HOME} 16 | 17 | # Cluster setup 18 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh -O anaconda.sh 19 | RUN bash anaconda.sh -b -p ${CONDA_PREFIX} 20 | RUN ${CONDA} config --set auto_activate_base false 21 | RUN ${CONDA} init bash 22 | RUN git config --global user.name "Mete Ismayil" 23 | RUN git config --global user.email "mismayilza@gmail.com" 24 | RUN git config pull.rebase false 25 | RUN echo "export LANG=en_US.UTF-8" >> ~/.bashrc 26 | 27 | # Setup kogito env 28 | RUN ${CONDA} create --name kogito -y python=3.8 29 | RUN ${CONDA} run -n kogito curl -sSL https://install.python-poetry.org | python3 - 30 | RUN ${CONDA} install -n kogito -y pytorch cudatoolkit=11.5 -c pytorch 31 | 32 | ARG GITHUB_PERSONAL_TOKEN 33 | 34 | # Clone kogito 35 | RUN git clone https://${GITHUB_PERSONAL_TOKEN}@github.com/epfl-nlp/kogito.git 36 | 37 | WORKDIR ${KOGITO_DIR} 38 | 39 | # Setup kogito dependencies 40 | RUN ${CONDA} run -n kogito ${POETRY} install 41 | RUN ${CONDA} run -n kogito python -c "import nltk;nltk.download('punkt');nltk.download('wordnet');nltk.download('omw-1.4')" 42 | RUN ${CONDA} run -n kogito python -m spacy download en_core_web_sm 43 | 44 | # Install training data 45 | ENV 
import os

from kogito.core.knowledge import KnowledgeGraph
from kogito.models.gpt2.comet import COMETGPT2


def _load_split(base_dir, split):
    """Load one ATOMIC-2020 TSV split as a KnowledgeGraph."""
    return KnowledgeGraph.from_csv(
        f"{base_dir}/atomic2020_data-feb2021/{split}.tsv", header=None, sep="\t"
    )


if __name__ == "__main__":
    # Data root is provided by the container environment (see Dockerfile).
    atomic_dir = os.environ.get("KOGITO_DATA_DIR")
    comet = COMETGPT2("gpt2-xl")
    comet.train(
        train_graph=_load_split(atomic_dir, "train"),
        val_graph=_load_split(atomic_dir, "dev"),
        batch_size=16,
        output_dir="/scratch/mete/models/comet-gpt2",
        log_wandb=True,
        lr=5e-5,
        epochs=1,
    )
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | @import url('https://fonts.googleapis.com/css2?family=Inconsolata:wght@200;300;400;500;600;700;800;900&display=swap'); 2 | 3 | .highlight-python .kn { 4 | font-weight: 900; 5 | } 6 | 7 | .highlight-python .nn { 8 | font-weight: 500; 9 | } 10 | 11 | .highlight-python .nc { 12 | font-weight: 500; 13 | } 14 | 15 | #logo-container h1::first-letter { 16 | background-color: #dce755; 17 | font-weight: 400; 18 | } -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | API Reference 3 | ============= 4 | 5 | Inference 6 | ========= 7 | 8 | .. automodule:: kogito.inference 9 | :members: 10 | :special-members: __init__ 11 | 12 | Head 13 | ==== 14 | 15 | .. automodule:: kogito.core.head 16 | :members: 17 | :special-members: __init__ 18 | 19 | Relation 20 | ======== 21 | 22 | .. automodule:: kogito.core.relation 23 | :members: 24 | :special-members: __init__ 25 | 26 | Knowledge 27 | ========= 28 | 29 | .. automodule:: kogito.core.knowledge 30 | :members: 31 | :special-members: __init__ 32 | 33 | Models 34 | ====== 35 | 36 | .. automodule:: kogito.core.model 37 | :members: 38 | :special-members: __init__ 39 | 40 | .. automodule:: kogito.models.bart.comet 41 | :members: 42 | :special-members: __init__ 43 | 44 | .. automodule:: kogito.models.gpt2.comet 45 | :members: 46 | :special-members: __init__ 47 | 48 | .. automodule:: kogito.models.gpt2.zeroshot 49 | :members: 50 | :special-members: __init__ 51 | 52 | .. automodule:: kogito.models.gpt3.zeroshot 53 | :members: 54 | :special-members: __init__ 55 | 56 | Processors 57 | ========== 58 | 59 | .. 
automodule:: kogito.core.processors.head 60 | :members: 61 | :special-members: __init__ 62 | 63 | .. automodule:: kogito.core.processors.relation 64 | :members: 65 | :special-members: __init__ 66 | 67 | Linkers 68 | ======= 69 | 70 | .. automodule:: kogito.core.linker 71 | :members: 72 | :special-members: __init__ 73 | 74 | .. automodule:: kogito.linkers.deberta 75 | :members: 76 | :special-members: __init__ -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("..")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "kogito" 22 | copyright = "2022, Mete Ismayil" 23 | author = "Mete Ismayil" 24 | 25 | 26 | # -- General configuration --------------------------------------------------- 27 | 28 | # Add any Sphinx extension module names here, as strings. They can be 29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 30 | # ones. 31 | extensions = [ 32 | "insegel", 33 | "sphinx.ext.autodoc", 34 | "sphinx.ext.coverage", 35 | "sphinx.ext.napoleon", 36 | ] 37 | 38 | # Add any paths that contain templates here, relative to this directory. 
39 | templates_path = ["_templates"] 40 | 41 | # List of patterns, relative to source directory, that match files and 42 | # directories to ignore when looking for source files. 43 | # This pattern also affects html_static_path and html_extra_path. 44 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 45 | 46 | 47 | # -- Options for HTML output ------------------------------------------------- 48 | 49 | # The theme to use for HTML and HTML Help pages. See the documentation for 50 | # a list of builtin themes. 51 | # 52 | html_theme = "insegel" 53 | 54 | # Add any paths that contain custom static files (such as style sheets) here, 55 | # relative to this directory. They are copied after the builtin static files, 56 | # so a file named "default.css" will overwrite the builtin "default.css". 57 | html_static_path = ["_static"] 58 | 59 | html_css_files = [ 60 | "css/custom.css", 61 | ] 62 | 63 | autodoc_member_order = "bysource" 64 | -------------------------------------------------------------------------------- /docs/getting_started.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Getting Started 3 | =============== 4 | 5 | Installation 6 | ============ 7 | 8 | Installation with pip 9 | ********************* 10 | **kogito** can be installed using pip. 11 | 12 | .. code-block:: shell 13 | 14 | pip install kogito 15 | 16 | It requires a minimum ``python`` version of ``3.8``. 17 | 18 | Setup 19 | ===== 20 | 21 | Inference 22 | ********* 23 | **kogito** uses `spacy `__ under the hood for various text processing purposes, so, a `spacy `__ language package has to be installed before running the inference module. 24 | 25 | .. code-block:: shell 26 | 27 | python -m spacy download en_core_web_sm 28 | 29 | By default, ``CommonsenseInference`` module uses ``en_core_web_sm`` to initialize `spacy `__ pipeline, but a different language pipeline can be specified as well. 
30 | 31 | Evaluation 32 | ********** 33 | If you also would like evaluate knowledge models using `METEOR `_ score, then you need to download the following `nltk `_ libraries: 34 | 35 | .. code-block:: python 36 | 37 | import nltk 38 | 39 | nltk.download("punkt") 40 | nltk.download("wordnet") 41 | nltk.download("omw-1.4") 42 | 43 | 44 | Quickstart 45 | =========== 46 | **kogito** provides an easy interface to interact with knowledge inference or commonsense reasoning models such as `COMET `__ to generate inferences from a text input. 47 | Here is a sample usage of the library where you can initialize an inference module, a custom commonsense reasoning model, and generate a knowledge graph from text on the fly. 48 | 49 | .. code-block:: python 50 | 51 | from kogito.models.bart.comet import COMETBART 52 | from kogito.inference import CommonsenseInference 53 | 54 | # Load pre-trained model from HuggingFace 55 | model = COMETBART.from_pretrained("mismayil/comet-bart-ai2") 56 | 57 | # Initialize inference module with a spacy language pipeline 58 | csi = CommonsenseInference(language="en_core_web_sm") 59 | 60 | # Run inference 61 | text = "PersonX becomes a great basketball player" 62 | kgraph = csi.infer(text, model) 63 | 64 | # Save output knowledge graph to JSON file 65 | kgraph.to_jsonl("kgraph.json") 66 | 67 | 68 | Here is an excerpt from the result of the above code sample: 69 | 70 | .. 
code-block:: json 71 | 72 | {"head": "PersonX becomes a great basketball player", "relation": "Causes", "tails": [" PersonX practices every day.", " PersonX plays basketball every day", " PersonX practices every day"]} 73 | {"head": "basketball", "relation": "ObjectUse", "tails": [" play with friends", " play basketball with", " play basketball"]} 74 | {"head": "player", "relation": "CapableOf", "tails": [" play game", " win game", " play football"]} 75 | {"head": "great basketball player", "relation": "HasProperty", "tails": [" good at basketball", " good at sports", " very good"]} 76 | {"head": "become player", "relation": "isAfter", "tails": [" play game", " become coach", " play with"]} 77 | 78 | This is just one way to generate commonsense inferences and **kogito** offers much more. For information on more use-cases and a complete API reference, you can check out the `User Guide `_ and `API Reference `_ pages. -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. kogito documentation master file, created by 2 | sphinx-quickstart on Wed Apr 13 11:36:35 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | A Python Knowledge Inference Tool 7 | ====================================== 8 | 9 | .. toctree:: 10 | :maxdepth: 2 11 | 12 | getting_started 13 | userguide 14 | api 15 | 16 | .. Indices and tables 17 | .. ================== 18 | 19 | .. * :ref:`genindex` 20 | .. * :ref:`modindex` 21 | .. 
* :ref:`search` 22 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | insegel 2 | kogito -------------------------------------------------------------------------------- /eacl2023/kogito-poster-eacl2023.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/eacl2023/kogito-poster-eacl2023.pdf -------------------------------------------------------------------------------- /eacl2023/kogito-presentation-eacl2023.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/eacl2023/kogito-presentation-eacl2023.pdf 
-------------------------------------------------------------------------------- /examples/data/atomic2020/README.md: -------------------------------------------------------------------------------- 1 | # ATOMIC 2020 Knowledge Graph 2 | 3 | All data `*.tsv` files are formatted as follows: 4 | - each line represents a distinct commonsense tuple 5 | - column 1: head node/concept 6 | - column 2: edge relation (e.g., xWant,xAttr,AtLocation) 7 | - column 3: tail node/concept 8 | 9 | `train.tsv`, `dev.tsv`, `test.tsv` correspond to train/dev/test splits. 10 | 11 | 12 | ## Paper 13 | Please cite the following work when using this data: 14 | 15 | > Jena D. Hwang, Chandra Bhagavatula, Ronan Le Bras, Jeff Da, Keisuke Sakaguchi, Antoine Bosselut, Yejin Choi (2021). 16 | > (Comet-) Atomic 2020: On Symbolic and Neural Commonsense Knowledge Graphs. 17 | > AAAI 2021 -------------------------------------------------------------------------------- /examples/data/atomic2020/sample_dev.tsv: -------------------------------------------------------------------------------- 1 | PersonX 'd better go oEffect none 2 | PersonX 'd better go oEffect none 3 | PersonX 'd better go oReact none 4 | PersonX 'd better go oReact none 5 | PersonX 'd better go oWant none 6 | PersonX 'd better go oWant none 7 | PersonX 'd better go oWant none 8 | PersonX 'd better go xAttr avoidant 9 | PersonX 'd better go xAttr weak 10 | PersonX 'd better go xAttr hurried 11 | PersonX 'd better go xAttr late 12 | PersonX 'd better go xAttr Tardy 13 | PersonX 'd better go xAttr busy 14 | PersonX 'd better go xEffect She ran to the bathroom 15 | PersonX 'd better go xEffect She finally made it 16 | PersonX 'd better go xEffect leaves 17 | PersonX 'd better go xEffect runs away 18 | PersonX 'd better go xIntent to go somewhere else more important. 
19 | PersonX 'd better go xIntent none 20 | PersonX 'd better go xNeed none 21 | PersonX 'd better go xNeed none 22 | PersonX 'd better go xNeed none 23 | PersonX 'd better go xReact the person feels happy since he arrived at his destination. 24 | PersonX 'd better go xReact rushed, in a hurry 25 | PersonX 'd better go xWant to escape from him 26 | PersonX 'd better go xWant to resign his job 27 | PersonX 'd better go xWant to leave on time 28 | PersonX 'd better go xWant to arrive home 29 | PersonX 'd better go xWant to relax and unwind 30 | PersonX 'd better go xWant to walk away 31 | PersonX 'd better go xWant not speak to anyone 32 | PersonX accepts PersonX's diploma oEffect none -------------------------------------------------------------------------------- /examples/data/atomic2020/sample_test.tsv: -------------------------------------------------------------------------------- 1 | PersonX abuses PersonX's power oEffect are told what to do 2 | PersonX abuses PersonX's power oEffect given unfair consequences or punishment 3 | PersonX abuses PersonX's power oEffect reach out for help 4 | PersonX abuses PersonX's power oEffect none 5 | PersonX abuses PersonX's power oReact humiliated 6 | PersonX abuses PersonX's power oReact sad 7 | PersonX abuses PersonX's power oReact angry 8 | PersonX abuses PersonX's power oReact cheated 9 | PersonX abuses PersonX's power oWant report PersonX to HR 10 | PersonX abuses PersonX's power oWant get PersonX fired 11 | PersonX abuses PersonX's power oWant to elect a new person 12 | PersonX abuses PersonX's power oWant to depose PersonX 13 | PersonX abuses PersonX's power oWant to impeach PrrsonX 14 | PersonX abuses PersonX's power oWant to punish person X 15 | PersonX abuses PersonX's power oWant to get rid of person X's atrocities 16 | PersonX abuses PersonX's power oWant to elect another leader 17 | PersonX abuses PersonX's power oWant to banish PersonX 18 | PersonX abuses PersonX's power oWant to remove PersonX 19 | PersonX 
abuses PersonX's power oWant to stop PersonX 20 | PersonX abuses PersonX's power oWant to depose PersonX 21 | PersonX abuses PersonX's power oWant to hurt PersonX 22 | PersonX abuses PersonX's power xAttr out of line 23 | PersonX abuses PersonX's power xAttr irresponsible 24 | PersonX abuses PersonX's power xAttr mean 25 | PersonX abuses PersonX's power xAttr confident 26 | PersonX abuses PersonX's power xAttr abusive 27 | PersonX abuses PersonX's power xAttr unreliable 28 | PersonX abuses PersonX's power xAttr abusive 29 | PersonX abuses PersonX's power xAttr careless 30 | PersonX abuses PersonX's power xEffect becomes authoratarian 31 | PersonX abuses PersonX's power xEffect is ostracized 32 | PersonX abuses PersonX's power xEffect is relieved of position -------------------------------------------------------------------------------- /examples/data/atomic2020/sample_train.tsv: -------------------------------------------------------------------------------- 1 | PersonX abandons ___ altogether oEffect none 2 | PersonX abandons ___ altogether oEffect none 3 | PersonX abandons ___ altogether oReact dejected 4 | PersonX abandons ___ altogether oWant none 5 | PersonX abandons ___ altogether oWant none 6 | PersonX abandons ___ altogether oWant to find a new job for him 7 | PersonX abandons ___ altogether oWant to support him 8 | PersonX abandons ___ altogether xAttr impatient 9 | PersonX abandons ___ altogether xAttr decisive 10 | PersonX abandons ___ altogether xAttr undependable 11 | PersonX abandons ___ altogether xAttr fickle 12 | PersonX abandons ___ altogether xAttr destructed 13 | PersonX abandons ___ altogether xAttr sad 14 | PersonX abandons ___ altogether xEffect gets a reputation as a quitter 15 | PersonX abandons ___ altogether xEffect hangs head in shame 16 | PersonX abandons ___ altogether xEffect Begins the process of change 17 | PersonX abandons ___ altogether xEffect Turns over a new leaf 18 | PersonX abandons ___ altogether xIntent put a stop 19 | 
PersonX abandons ___ altogether xNeed Plows the field. 20 | PersonX abandons ___ altogether xNeed Gets exhausted from it. 21 | PersonX abandons ___ altogether xNeed none 22 | PersonX abandons ___ altogether xNeed to give a resignation letter 23 | PersonX abandons ___ altogether xNeed to get permission from his parents 24 | PersonX abandons ___ altogether xReact authoritative 25 | PersonX abandons ___ altogether xWant Sell his land. 26 | PersonX abandons ___ altogether xWant Was just city. 27 | PersonX abandons ___ altogether xWant to start something new 28 | PersonX abandons ___ altogether xWant to start fresh 29 | PersonX abandons ___ altogether xWant to find a new job 30 | PersonX abandons ___ altogether xWant to search for a new job 31 | PersonX abandons the ___ altogether oEffect none 32 | PersonX abandons the ___ altogether oEffect none 33 | PersonX abandons the ___ altogether oEffect none 34 | PersonX abandons the ___ altogether oReact defeat 35 | PersonX abandons the ___ altogether oWant none 36 | PersonX abandons the ___ altogether oWant to do something else as well 37 | PersonX abandons the ___ altogether oWant they find something better 38 | PersonX abandons the ___ altogether oWant none 39 | PersonX abandons the ___ altogether xAttr flaky 40 | PersonX abandons the ___ altogether xAttr irresponsible 41 | PersonX abandons the ___ altogether xAttr desperate 42 | PersonX abandons the ___ altogether xAttr convinced 43 | PersonX abandons the ___ altogether xAttr decisive 44 | PersonX abandons the ___ altogether xAttr frustrated 45 | PersonX abandons the ___ altogether xEffect eats all the cakes 46 | PersonX abandons the ___ altogether xEffect abandons his diets too 47 | PersonX abandons the ___ altogether xEffect repercussions for leaving all responsibilities 48 | PersonX abandons the ___ altogether xEffect they go home 49 | PersonX abandons the ___ altogether xEffect they try to form a different plan 50 | PersonX abandons the ___ altogether xEffect they 
search for a different alternative 51 | PersonX abandons the ___ altogether xIntent to appear not interested 52 | PersonX abandons the ___ altogether xNeed none 53 | PersonX abandons the ___ altogether xNeed to get frustrated 54 | PersonX abandons the ___ altogether xNeed to determine it's not worth it 55 | PersonX abandons the ___ altogether xNeed none 56 | PersonX abandons the ___ altogether xReact pressurized 57 | PersonX abandons the ___ altogether xWant to go out 58 | PersonX abandons the ___ altogether xWant to find other place 59 | PersonX abandons the ___ altogether xWant find something else to do 60 | PersonX abandons the ___ altogether xWant to do the project the best he can 61 | PersonX abandons the ___ altogether xWant sigh in relief 62 | PersonX abandons the ___ altogether xWant find another project 63 | PersonX abolishes ___ altogether oEffect none 64 | PersonX abolishes ___ altogether oEffect none -------------------------------------------------------------------------------- /examples/evaluate_comet_bart.py: -------------------------------------------------------------------------------- 1 | from kogito.core.knowledge import KnowledgeGraph 2 | from kogito.models.bart.comet import COMETBART 3 | 4 | input_graph = KnowledgeGraph.from_jsonl("test_atomic2020_sample.json") 5 | 6 | model = COMETBART.from_pretrained("mismayil/comet-bart-ai2") 7 | scores = model.evaluate( 8 | input_graph, batch_size=256, num_return_sequences=1 9 | ) 10 | print(scores) 11 | -------------------------------------------------------------------------------- /examples/evaluate_comet_gpt2.py: -------------------------------------------------------------------------------- 1 | from kogito.core.knowledge import KnowledgeGraph 2 | from kogito.models.gpt2.comet import COMETGPT2 3 | 4 | input_graph = KnowledgeGraph.from_jsonl("test_atomic2020.json") 5 | 6 | model = COMETGPT2.from_pretrained("mismayil/comet-bart-ai2") 7 | scores = model.evaluate(input_graph) 8 | 9 | print(scores) 10 | 
-------------------------------------------------------------------------------- /examples/results/cc_news/cc_news_samples_13.json: -------------------------------------------------------------------------------- 1 | {"head": "Washington", "relation": "ObjectUse", "tails": [" Washington, D.C.", " Washington, D.C", " Washington,D.C."]} 2 | {"head": "Washington", "relation": "AtLocation", "tails": [" city", " town", " cities"]} 3 | {"head": "Washington", "relation": "CapableOf", "tails": [" go to war", " go on vacation", " go to the beach"]} 4 | {"head": "Washington", "relation": "HasProperty", "tails": [" Washington, D.C.", " Washington, D.C", " Washington, D.C.,"]} 5 | {"head": "Washington", "relation": "MadeUpOf", "tails": [" Washington, D.C.", " Washington, D.C", " Washington, D.C.,"]} 6 | {"head": "Washington", "relation": "NotDesires", "tails": [" Washington, D.C.", " Washington, D.C", " Washington D.C."]} 7 | {"head": "Washington", "relation": "Desires", "tails": [" Washington, D.C.", " Washington, D.C", " Washington, D.C.,"]} 8 | {"head": "From Everett, Washington", "relation": "xAttr", "tails": [" adventurous", " sailing", " sailing ship"]} 9 | {"head": "From Everett, Washington", "relation": "xWant", "tails": [" to go to the beach", " to see the ocean", " to go home"]} 10 | {"head": "From Everett, Washington", "relation": "xNeed", "tails": [" to drive to the city", " to drive to the location", " to drive to the town"]} 11 | {"head": "From Everett, Washington", "relation": "xEffect", "tails": [" get to their destination", " get to the beach", " get to their destination."]} 12 | {"head": "From Everett, Washington", "relation": "HinderedBy", "tails": [" from there", " from a different city", " from a different state"]} 13 | {"head": "From Everett, Washington", "relation": "oWant", "tails": [" to go to the beach", " to see the sights", " to go to the movies"]} 14 | {"head": "From Everett, Washington", "relation": "xReact", "tails": [" happy", " happy to be 
there", " happy to be there."]} 15 | {"head": "From Everett, Washington", "relation": "oEffect", "tails": [" none", " travel", " travel to"]} 16 | {"head": "From Everett, Washington", "relation": "xIntent", "tails": [" to see the ocean", " to go to the beach", " to visit family"]} 17 | {"head": "From Everett, Washington", "relation": "oReact", "tails": [" none", " like they have a place to live", " like they have a place to go"]} 18 | {"head": "From Everett, Washington", "relation": "isBefore", "tails": [" from there to there", " from there", " from there to their home"]} 19 | {"head": "From Everett, Washington", "relation": "isAfter", "tails": [" from Seattle to Everett", " from there to there", " from there"]} 20 | {"head": "From Everett, Washington", "relation": "HasSubEvent", "tails": [" get to their destination", " get to the beach", " get off plane"]} 21 | {"head": "From Everett, Washington", "relation": "Causes", "tails": [" from there", " from a different state", " from a different place"]} 22 | {"head": "From Everett, Washington", "relation": "xReason", "tails": [" from there to there", " from there", " from the ocean"]} 23 | {"head": "Everett", "relation": "ObjectUse", "tails": [" get a new job", " have a good time", " get a job"]} 24 | {"head": "Everett", "relation": "AtLocation", "tails": [" motel room", " school", " town"]} 25 | {"head": "Everett", "relation": "CapableOf", "tails": [" go to sleep", " go to bed", " go to the store"]} 26 | {"head": "Everett", "relation": "HasProperty", "tails": [" none", " one of the two brothers", " one of the three brothers"]} 27 | {"head": "Everett", "relation": "MadeUpOf", "tails": [" ", " ", " "]} 28 | {"head": "Everett", "relation": "NotDesires", "tails": [" get in car accident", " none", " have to go to work"]} 29 | {"head": "Everett", "relation": "Desires", "tails": [" none", " get a new job", " get a job"]} -------------------------------------------------------------------------------- 
/examples/results/cc_news/cc_news_samples_41.json: -------------------------------------------------------------------------------- 1 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xAttr", "tails": [" responsible", " energy-hungry", " energy-generating"]} 2 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xWant", "tails": [" to make money", " to sell the company", " to sell the business"]} 3 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xNeed", "tails": [" to invest in the company", " to invest in a company", " to invest in the business"]} 4 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xEffect", "tails": [" Gcl New Energy has a lot of debt.", " Gcl New Energy has a lot of money.", " Gcl New Energy has no money."]} 5 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "HinderedBy", "tails": [" Gcl New Energy is not a company.", " Gcl New Energy is not profitable.", " Gcl New Energy has no money."]} 6 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "oWant", "tails": [" none", " to make money", " to sell the company"]} 7 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xReact", "tails": [" good about themselves", " happy", " satisfied"]} 8 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "oEffect", "tails": [" none", " Gcl New Energy has a lot of debt.", " Gcl New Energy has no money."]} 9 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xIntent", "tails": [" to make money", " to make money.", " to make a profit"]} 10 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "oReact", "tails": [" none", " like they have something to do", " like they have a good company"]} 11 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "isBefore", "tails": [" Gcl New Energy", " Gcl New Energy Holdings", " Gcl New Energy is a company"]} 12 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "isAfter", "tails": [" Gcl New Energy", " Gcl New Energy Holdings", " Gcl New Energy Holdings Ltd"]} 13 | 
{"head": "Gcl New Energy Holdings Ltd :", "relation": "HasSubEvent", "tails": [" Gcl New Energy gets a loan from the bank.", " Gcl New Energy gets a loan from a bank.", " Gcl New Energy gets a loan from the bank"]} 14 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "Causes", "tails": [" Gcl New Energy has no money.", " Gcl New Energy Holdings is not a company.", " Gcl New Energy Holdings is not a real company."]} 15 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xReason", "tails": [" Gcl New Energy is a private company.", " Gcl New Energy", " Gcl New Energy Holdings"]} 16 | {"head": "Gcl New Energy Holdings Ltd", "relation": "ObjectUse", "tails": [" to make money", " power plants", " to make a profit"]} 17 | {"head": "Gcl New Energy Holdings Ltd", "relation": "AtLocation", "tails": [" energy company", " electricity", " company"]} 18 | {"head": "Gcl New Energy Holdings Ltd", "relation": "CapableOf", "tails": [" get a loan", " power plants", " buy electricity"]} 19 | {"head": "Gcl New Energy Holdings Ltd", "relation": "HasProperty", "tails": [" Gcl New Energy Holdings", " Gcl New Energy", " Gcl New Energy is a company."]} 20 | {"head": "Gcl New Energy Holdings Ltd", "relation": "MadeUpOf", "tails": [" power plant", " electricity", " energy"]} 21 | {"head": "Gcl New Energy Holdings Ltd", "relation": "NotDesires", "tails": [" Gcl New Energy is a public company.", " Gcl New Energy is a private company.", " Gcl New Energy"]} 22 | {"head": "Gcl New Energy Holdings Ltd", "relation": "Desires", "tails": [" to make money", " none", " to make a profit"]} -------------------------------------------------------------------------------- /examples/results/cc_news/cc_news_samples_65.json: -------------------------------------------------------------------------------- 1 | {"head": "LONDON LEAGUE", "relation": "xAttr", "tails": [" competitive", " athletic", " sporty"]} 2 | {"head": "LONDON LEAGUE", "relation": "xWant", "tails": [" to win the game", " to win the 
championship", " to play a game"]} 3 | {"head": "LONDON LEAGUE", "relation": "xNeed", "tails": [" find a team to play in", " find a place to play", " find a team to play for"]} 4 | {"head": "LONDON LEAGUE", "relation": "xEffect", "tails": [" wins the league", " wins the game", " wins the championship"]} 5 | {"head": "LONDON LEAGUE", "relation": "HinderedBy", "tails": [" LONDON LEAGUE", " LONDON LEAGUE ", " LONDON LEAGUE "]} 6 | {"head": "LONDON LEAGUE", "relation": "oWant", "tails": [" to win the game", " to play in the league", " to play a game"]} 7 | {"head": "LONDON LEAGUE", "relation": "xReact", "tails": [" happy", " proud", " good"]} 8 | {"head": "LONDON LEAGUE", "relation": "oEffect", "tails": [" none", " play in", " play football"]} 9 | {"head": "LONDON LEAGUE", "relation": "xIntent", "tails": [" to be competitive", " to play football", " to play in"]} 10 | {"head": "LONDON LEAGUE", "relation": "oReact", "tails": [" none", " happy", " happy."]} 11 | {"head": "LONDON LEAGUE", "relation": "isBefore", "tails": [" LONDON LEAGUE", " LONDON FOOTBALL LEAGUE", " LONDON LONDON LEAGUE"]} 12 | {"head": "LONDON LEAGUE", "relation": "isAfter", "tails": [" LONDON LEAGUE", " LONDON FOOTBALL LEAGUE", " LONDON LONDON LEAGUE"]} 13 | {"head": "LONDON LEAGUE", "relation": "HasSubEvent", "tails": [" win the game", " play in", " win the match"]} 14 | {"head": "LONDON LEAGUE", "relation": "Causes", "tails": [" LONDON LEAGUE", " LONDON LEAGUE.", " LONDON LEAGUE "]} 15 | {"head": "LONDON LEAGUE", "relation": "xReason", "tails": [" LONDON LEAGUE", " LONDON LONDON LEAGUE", " LONDON LEAGUE."]} -------------------------------------------------------------------------------- /examples/results/cc_news/cc_news_samples_67.json: -------------------------------------------------------------------------------- 1 | {"head": "to find", "relation": "HinderedBy", "tails": [" to find ", " to find", " to find to find "]} 2 | {"head": "to find", "relation": "isBefore", "tails": [" to find", " to find 
something", " to find "]} 3 | {"head": "to find", "relation": "isAfter", "tails": [" to find something", " to find something to eat", " to find"]} 4 | {"head": "to find", "relation": "HasSubEvent", "tails": [" to find", " to find ", " to look for it"]} 5 | {"head": "to find", "relation": "Causes", "tails": [" to find", " to find ", " to find something"]} 6 | {"head": "to find", "relation": "xReason", "tails": [" to find", " to find something", " to find "]} 7 | {"head": "to decompose", "relation": "HinderedBy", "tails": [" to decompose", " to decompose ", " to decompose to decompose "]} 8 | {"head": "to decompose", "relation": "isBefore", "tails": [" to decompose", " to decompose in the ground", " to decompose in a lab"]} 9 | {"head": "to decompose", "relation": "isAfter", "tails": [" to decompose", " to decompose ", " decompose"]} 10 | {"head": "to decompose", "relation": "HasSubEvent", "tails": [" to decompose", " to decompose ", " to decompose in the ground"]} 11 | {"head": "to decompose", "relation": "Causes", "tails": [" to decompose", " to decompose ", " decomposing"]} 12 | {"head": "to decompose", "relation": "xReason", "tails": [" to decompose", " to decompose ", " to decompose in a lab"]} 13 | {"head": "body", "relation": "ObjectUse", "tails": [" have sex with", " body parts", " have sex"]} 14 | {"head": "body", "relation": "AtLocation", "tails": [" body", " medicine chest", " body parts"]} 15 | {"head": "body", "relation": "CapableOf", "tails": [" lie down on bed", " lie on bed", " lie down in bed"]} 16 | {"head": "body", "relation": "HasProperty", "tails": [" found in body", " body", " found in human body"]} 17 | {"head": "body", "relation": "MadeUpOf", "tails": [" body", " body part", " body parts"]} 18 | {"head": "body", "relation": "NotDesires", "tails": [" dead", " death", " body parts"]} 19 | {"head": "body", "relation": "Desires", "tails": [" body", " body parts", " feel good"]} 20 | {"head": "badly decomposed body", "relation": "ObjectUse", 
"tails": [" hide in the woods", " kill someone", " hide in a closet"]} 21 | {"head": "badly decomposed body", "relation": "AtLocation", "tails": [" dead body", " body bag", " corpse"]} 22 | {"head": "badly decomposed body", "relation": "CapableOf", "tails": [" dead body", " kill yourself", " bury in grave"]} 23 | {"head": "badly decomposed body", "relation": "HasProperty", "tails": [" embalmed", " dead body", " dead"]} 24 | {"head": "badly decomposed body", "relation": "MadeUpOf", "tails": [" embalmed", " dead body", " embalmed body"]} 25 | {"head": "badly decomposed body", "relation": "NotDesires", "tails": [" embalmed", " dead body", " bad smell"]} 26 | {"head": "badly decomposed body", "relation": "Desires", "tails": [" dead body", " death", " bad smell"]} 27 | {"head": "The badly decomposed body was found near...", "relation": "xAttr", "tails": [" dead", " morbid", " sadistic"]} 28 | {"head": "The badly decomposed body was found near...", "relation": "xWant", "tails": [" to bury the body", " to bury the body.", " to bury the dead body"]} 29 | {"head": "The badly decomposed body was found near...", "relation": "xNeed", "tails": [" find a body", " find a dead body", " none"]} 30 | {"head": "The badly decomposed body was found near...", "relation": "xEffect", "tails": [" is buried in a grave", " is buried", " the body is found"]} 31 | {"head": "The badly decomposed body was found near...", "relation": "HinderedBy", "tails": [" The body was found in the woods.", " The body was found in the middle of the road.", " The body was found in the middle of nowhere."]} 32 | {"head": "The badly decomposed body was found near...", "relation": "oWant", "tails": [" to bury the body", " none", " to bury the body."]} 33 | {"head": "The badly decomposed body was found near...", "relation": "xReact", "tails": [" sad", " dead", " scared"]} 34 | {"head": "The badly decomposed body was found near...", "relation": "oEffect", "tails": [" none", " the police are called to the scene", " 
the police find the body"]} 35 | {"head": "The badly decomposed body was found near...", "relation": "xIntent", "tails": [" none", " the body to be found", " the body to be buried"]} 36 | {"head": "The badly decomposed body was found near...", "relation": "oReact", "tails": [" none", " sad", " dead"]} 37 | {"head": "The badly decomposed body was found near...", "relation": "isBefore", "tails": [" the body to be buried", " the body to be found", " the body is buried"]} 38 | {"head": "The badly decomposed body was found near...", "relation": "isAfter", "tails": [" the body to be buried", " the body to be found", " the body was found"]} 39 | {"head": "The badly decomposed body was found near...", "relation": "HasSubEvent", "tails": [" the body to be found", " the body to be buried", " find a body"]} 40 | {"head": "The badly decomposed body was found near...", "relation": "Causes", "tails": [" the body to be found", " the body to be buried", " the body was found"]} 41 | {"head": "The badly decomposed body was found near...", "relation": "xReason", "tails": [" the body to be buried", " the body to be found", " the body was found"]} -------------------------------------------------------------------------------- /examples/results/cc_news/cc_news_samples_7.json: -------------------------------------------------------------------------------- 1 | {"head": "Mylan NV:", "relation": "xAttr", "tails": [" hardworking", " hard-working", " hard working"]} 2 | {"head": "Mylan NV:", "relation": "xWant", "tails": [" to make money", " to sell the medicine", " to sell the product"]} 3 | {"head": "Mylan NV:", "relation": "xNeed", "tails": [" to get a prescription", " to buy the drug", " to have a prescription"]} 4 | {"head": "Mylan NV:", "relation": "xEffect", "tails": [" gets a prescription", " get a prescription", " get a prescription filled"]} 5 | {"head": "Mylan NV:", "relation": "HinderedBy", "tails": [" The pharmacy is out of stock.", " The drug is too expensive.", " The pharmacy 
is closed."]} 6 | {"head": "Mylan NV:", "relation": "oWant", "tails": [" none", " to get rid of it", " to get rid of the drug"]} 7 | {"head": "Mylan NV:", "relation": "xReact", "tails": [" happy", " satisfied", " good"]} 8 | {"head": "Mylan NV:", "relation": "oEffect", "tails": [" none", " they get a prescription for it", " they get a prescription for the medicine"]} 9 | {"head": "Mylan NV:", "relation": "xIntent", "tails": [" none", " to get rid of a disease", " to get rid of the pain"]} 10 | {"head": "Mylan NV:", "relation": "oReact", "tails": [" none", " happy", " happy."]} 11 | {"head": "Mylan NV:", "relation": "isBefore", "tails": [" drugstore", " drug company", " drug"]} 12 | {"head": "Mylan NV:", "relation": "isAfter", "tails": [" drugstore", " medicine", " drug store"]} 13 | {"head": "Mylan NV:", "relation": "HasSubEvent", "tails": [" get a prescription", " get a prescription filled", " get a prescription for it"]} 14 | {"head": "Mylan NV:", "relation": "Causes", "tails": [" the medicine is not working", " the drug is not working", " the medicine to be safe"]} 15 | {"head": "Mylan NV:", "relation": "xReason", "tails": [" have to pay for it", " have to pay for the medication", " have to pay for the medicine"]} 16 | {"head": "Mylan NV", "relation": "ObjectUse", "tails": [" make a pill", " make pills", " sell to customers"]} 17 | {"head": "Mylan NV", "relation": "AtLocation", "tails": [" medicine chest", " drugstore", " drug store"]} 18 | {"head": "Mylan NV", "relation": "CapableOf", "tails": [" sell to customers", " sell drugs", " sell to others"]} 19 | {"head": "Mylan NV", "relation": "HasProperty", "tails": [" sold in the United States", " sold in the US", " good for heart health"]} 20 | {"head": "Mylan NV", "relation": "MadeUpOf", "tails": [" drug", " medicine", " medicine chest"]} 21 | {"head": "Mylan NV", "relation": "NotDesires", "tails": [" have to pay for it", " have to pay for drugs", " have to pay for the product"]} 22 | {"head": "Mylan NV", 
"relation": "Desires", "tails": [" to be a drug company", " to be able to take medicine", " to be able to sell drugs"]} -------------------------------------------------------------------------------- /examples/results/gpt3_prompt_2c976a77.txt: -------------------------------------------------------------------------------- 1 | What needs to be true for this event to take place? 2 | 3 | PersonX accepts PersonY appointment. Before, PersonX needs to to clear spot in schedule 4 | 5 | PersonX accepts PersonY thanks. Before, PersonX needs to to says thanks to Y 6 | 7 | PersonX abandons ___ altogether. Before, PersonX needs to Plows the field. 8 | 9 | PersonX accompanies PersonY far. Before, PersonX needs to to walk with PersonY 10 | 11 | PersonX always drank. Before, PersonX needs to to get a drink 12 | 13 | PersonX about to get married. Before, PersonX needs to meet someone 14 | 15 | PersonX accidentally fell. Before, PersonX needs to to walk too fast 16 | 17 | PersonX accidentally bumped. Before, PersonX needs to run or drive fast 18 | 19 | PersonX accepts PersonY's proposal. Before, PersonX needs to to date 20 | 21 | PersonX accomplishes PersonX's goal. Before, PersonX needs to to make a goal 22 | 23 | PersonX aces PersonX's exam. Before, PersonX needs to 24 | 25 | What needs to be true for this event to take place? 26 | 27 | PersonX accepts PersonY appointment. Before, PersonX needs to to clear spot in schedule 28 | 29 | PersonX accepts PersonY thanks. Before, PersonX needs to to says thanks to Y 30 | 31 | PersonX abandons ___ altogether. Before, PersonX needs to Plows the field. 32 | 33 | PersonX accompanies PersonY far. Before, PersonX needs to to walk with PersonY 34 | 35 | PersonX always drank. Before, PersonX needs to to get a drink 36 | 37 | PersonX about to get married. Before, PersonX needs to meet someone 38 | 39 | PersonX accidentally fell. Before, PersonX needs to to walk too fast 40 | 41 | PersonX accidentally bumped. 
Before, PersonX needs to run or drive fast 42 | 43 | PersonX accepts PersonY's proposal. Before, PersonX needs to to date 44 | 45 | PersonX accomplishes PersonX's goal. Before, PersonX needs to to make a goal 46 | 47 | PersonX accuses PersonY of cheating. Before, PersonX needs to -------------------------------------------------------------------------------- /examples/results/gpt3_prompt_56fcb840.txt: -------------------------------------------------------------------------------- 1 | What needs to be true for this event to take place? 2 | 3 | PersonX accompanies PersonY far. Before, PersonX needs to to walk with PersonY 4 | 5 | PersonX about to get married. Before, PersonX needs to meet someone 6 | 7 | PersonX always drank. Before, PersonX needs to to get a drink 8 | 9 | PersonX accidentally fell. Before, PersonX needs to to walk too fast 10 | 11 | PersonX accepts PersonY's proposal. Before, PersonX needs to to date 12 | 13 | PersonX accepts PersonY appointment. Before, PersonX needs to to clear spot in schedule 14 | 15 | PersonX accepts PersonY thanks. Before, PersonX needs to to says thanks to Y 16 | 17 | PersonX abandons ___ altogether. Before, PersonX needs to Plows the field. 18 | 19 | PersonX accidentally bumped. Before, PersonX needs to run or drive fast 20 | 21 | PersonX accomplishes PersonX's goal. Before, PersonX needs to to make a goal 22 | 23 | PersonX sees PersonY's point. Before, PersonX needs to 24 | 25 | What needs to be true for this event to take place? 26 | 27 | PersonX accompanies PersonY far. Before, PersonX needs to to walk with PersonY 28 | 29 | PersonX about to get married. Before, PersonX needs to meet someone 30 | 31 | PersonX always drank. Before, PersonX needs to to get a drink 32 | 33 | PersonX accidentally fell. Before, PersonX needs to to walk too fast 34 | 35 | PersonX accepts PersonY's proposal. Before, PersonX needs to to date 36 | 37 | PersonX accepts PersonY appointment. 
Before, PersonX needs to to clear spot in schedule 38 | 39 | PersonX accepts PersonY thanks. Before, PersonX needs to to says thanks to Y 40 | 41 | PersonX abandons ___ altogether. Before, PersonX needs to Plows the field. 42 | 43 | PersonX accidentally bumped. Before, PersonX needs to run or drive fast 44 | 45 | PersonX accomplishes PersonX's goal. Before, PersonX needs to to make a goal 46 | 47 | PersonX makes a huge mistake. Before, PersonX needs to -------------------------------------------------------------------------------- /examples/results/gpt3_prompt_99e17fae.txt: -------------------------------------------------------------------------------- 1 | How does this situation affect each character's wishes? 2 | 3 | Situation 1: PersonX calls PersonY 4 | Wishes: As a result, PersonX wishes to have a long chat 5 | 6 | Situation 2: PersonX bleeds a lot 7 | Wishes: As a result, PersonX wishes to see a doctor 8 | 9 | Situation 3: PersonX works as a cashier 10 | Wishes: As a result, PersonX wishes to be a store manager 11 | 12 | Situation 4: PersonX stays up all night studying 13 | Wishes: As a result, PersonX wishes to sleep all day 14 | 15 | Situation 5: PersonX makes his own costume 16 | Wishes: As a result, PersonX wishes to go to a costume party 17 | 18 | Situation 6: PersonX tells PersonY a secret 19 | Wishes: As a result, PersonX wishes to get PersonY's advice 20 | 21 | Situation 7: PersonX gets dirty 22 | Wishes: As a result, PersonX wishes to clean up 23 | 24 | Situation 8: PersonX ends a friendship 25 | Wishes: As a result, PersonX wishes to meet new people 26 | 27 | Situation 9: PersonX gets PersonY's autograph 28 | Wishes: As a result, PersonX wishes to have a relationship with PersonY 29 | 30 | Situation 10: PersonX mows the lawn 31 | Wishes: As a result, PersonX wishes to get a new lawnmower 32 | 33 | Situation 11: PersonX sees PersonY's point 34 | Wishes: As a result, PersonX wishes 35 | 36 | How does this situation affect each character's wishes? 
37 | 38 | Situation 1: PersonX calls PersonY 39 | Wishes: As a result, PersonX wishes to have a long chat 40 | 41 | Situation 2: PersonX bleeds a lot 42 | Wishes: As a result, PersonX wishes to see a doctor 43 | 44 | Situation 3: PersonX works as a cashier 45 | Wishes: As a result, PersonX wishes to be a store manager 46 | 47 | Situation 4: PersonX stays up all night studying 48 | Wishes: As a result, PersonX wishes to sleep all day 49 | 50 | Situation 5: PersonX makes his own costume 51 | Wishes: As a result, PersonX wishes to go to a costume party 52 | 53 | Situation 6: PersonX tells PersonY a secret 54 | Wishes: As a result, PersonX wishes to get PersonY's advice 55 | 56 | Situation 7: PersonX gets dirty 57 | Wishes: As a result, PersonX wishes to clean up 58 | 59 | Situation 8: PersonX ends a friendship 60 | Wishes: As a result, PersonX wishes to meet new people 61 | 62 | Situation 9: PersonX gets PersonY's autograph 63 | Wishes: As a result, PersonX wishes to have a relationship with PersonY 64 | 65 | Situation 10: PersonX mows the lawn 66 | Wishes: As a result, PersonX wishes to get a new lawnmower 67 | 68 | Situation 11: PersonX makes a huge mistake 69 | Wishes: As a result, PersonX wishes -------------------------------------------------------------------------------- /examples/results/kgraph_dry_run.json: -------------------------------------------------------------------------------- 1 | {"head": "Gabby always brought cookies to school.", "relation": "xWant", "tails": []} 2 | {"head": "Gabby always brought cookies to school.", "relation": "HinderedBy", "tails": []} 3 | {"head": "school", "relation": "HasProperty", "tails": []} 4 | {"head": "cookies", "relation": "CapableOf", "tails": []} 5 | {"head": "bring cookies", "relation": "isAfter", "tails": []} 6 | {"head": "Gabby", "relation": "xNeed", "tails": []} 7 | {"head": "Gabby", "relation": "AtLocation", "tails": []} 8 | {"head": "school", "relation": "Causes", "tails": []} 9 | {"head": "Gabby", 
"relation": "MadeUpOf", "tails": []} 10 | {"head": "to bring", "relation": "Causes", "tails": []} 11 | {"head": "Gabby always brought cookies to school.", "relation": "xAttr", "tails": []} 12 | {"head": "to bring", "relation": "isFilledBy", "tails": []} 13 | {"head": "school", "relation": "ObjectUse", "tails": []} 14 | {"head": "cookies", "relation": "MadeUpOf", "tails": []} 15 | {"head": "to bring", "relation": "xReason", "tails": []} 16 | {"head": "cookies", "relation": "AtLocation", "tails": []} 17 | {"head": "cookies", "relation": "xNeed", "tails": []} 18 | {"head": "school", "relation": "NotDesires", "tails": []} 19 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []} 20 | {"head": "Gabby always brought cookies to school.", "relation": "isFilledBy", "tails": []} 21 | {"head": "school", "relation": "CapableOf", "tails": []} 22 | {"head": "to bring", "relation": "isBefore", "tails": []} 23 | {"head": "Gabby always brought cookies to school.", "relation": "xReact", "tails": []} 24 | {"head": "Gabby always brought cookies to school.", "relation": "xReason", "tails": []} 25 | {"head": "bring cookies", "relation": "HinderedBy", "tails": []} 26 | {"head": "to bring", "relation": "HasSubEvent", "tails": []} 27 | {"head": "Gabby always brought cookies to school.", "relation": "isBefore", "tails": []} 28 | {"head": "school", "relation": "AtLocation", "tails": []} 29 | {"head": "school", "relation": "xNeed", "tails": []} 30 | {"head": "Gabby always brought cookies to school.", "relation": "oReact", "tails": []} 31 | {"head": "to bring", "relation": "xNeed", "tails": []} 32 | {"head": "school", "relation": "MadeUpOf", "tails": []} 33 | {"head": "to bring", "relation": "isAfter", "tails": []} 34 | {"head": "Gabby always brought cookies to school.", "relation": "HasSubEvent", "tails": []} 35 | {"head": "Gabby", "relation": "Desires", "tails": []} 36 | {"head": "Gabby always brought cookies to school.", "relation": "xIntent", "tails": []} 
37 | {"head": "Gabby always brought cookies to school.", "relation": "xNeed", "tails": []} 38 | {"head": "Gabby", "relation": "HasProperty", "tails": []} 39 | {"head": "Gabby always brought cookies to school.", "relation": "xEffect", "tails": []} 40 | {"head": "bring cookies", "relation": "Causes", "tails": []} 41 | {"head": "bring cookies", "relation": "isFilledBy", "tails": []} 42 | {"head": "Gabby always brought cookies to school.", "relation": "isAfter", "tails": []} 43 | {"head": "bring cookies", "relation": "xReason", "tails": []} 44 | {"head": "Gabby", "relation": "Causes", "tails": []} 45 | {"head": "cookies", "relation": "Desires", "tails": []} 46 | {"head": "bring cookies", "relation": "isBefore", "tails": []} 47 | {"head": "cookies", "relation": "HasProperty", "tails": []} 48 | {"head": "Gabby", "relation": "ObjectUse", "tails": []} 49 | {"head": "cookies", "relation": "Causes", "tails": []} 50 | {"head": "Gabby", "relation": "NotDesires", "tails": []} 51 | {"head": "Gabby always brought cookies to school.", "relation": "oEffect", "tails": []} 52 | {"head": "bring cookies", "relation": "xNeed", "tails": []} 53 | {"head": "cookies", "relation": "ObjectUse", "tails": []} 54 | {"head": "to bring", "relation": "HinderedBy", "tails": []} 55 | {"head": "bring cookies", "relation": "HasSubEvent", "tails": []} 56 | {"head": "Gabby", "relation": "CapableOf", "tails": []} 57 | {"head": "cookies", "relation": "NotDesires", "tails": []} 58 | {"head": "school", "relation": "Desires", "tails": []} 59 | {"head": "Gabby always brought cookies to school.", "relation": "oWant", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_dry_run_2.json: -------------------------------------------------------------------------------- 1 | {"head": "he didnt listen to me", "relation": "xReason", "tails": []} 2 | {"head": "he didnt listen to me", "relation": "HinderedBy", "tails": []} 3 | {"head": "to want", 
"relation": "xNeed", "tails": []} 4 | {"head": "he didnt listen to me", "relation": "xNeed", "tails": []} 5 | {"head": "to listen", "relation": "HinderedBy", "tails": []} 6 | {"head": "I wanted to feed him.", "relation": "oEffect", "tails": []} 7 | {"head": "I wanted to feed him.", "relation": "xIntent", "tails": []} 8 | {"head": "to listen", "relation": "xReason", "tails": []} 9 | {"head": "I wanted to feed him.", "relation": "isAfter", "tails": []} 10 | {"head": "I", "relation": "HasProperty", "tails": []} 11 | {"head": "to listen", "relation": "xNeed", "tails": []} 12 | {"head": "I wanted to feed him.", "relation": "oReact", "tails": []} 13 | {"head": "he didnt listen to me", "relation": "oEffect", "tails": []} 14 | {"head": "I", "relation": "MadeUpOf", "tails": []} 15 | {"head": "I", "relation": "ObjectUse", "tails": []} 16 | {"head": "I", "relation": "NotDesires", "tails": []} 17 | {"head": "to want", "relation": "isAfter", "tails": []} 18 | {"head": "he didnt listen to me", "relation": "xIntent", "tails": []} 19 | {"head": "I wanted to feed him.", "relation": "HasSubEvent", "tails": []} 20 | {"head": "I wanted to feed him.", "relation": "xEffect", "tails": []} 21 | {"head": "I wanted to feed him.", "relation": "isFilledBy", "tails": []} 22 | {"head": "he didnt listen to me", "relation": "isAfter", "tails": []} 23 | {"head": "he didnt listen to me", "relation": "oReact", "tails": []} 24 | {"head": "I", "relation": "CapableOf", "tails": []} 25 | {"head": "I", "relation": "Desires", "tails": []} 26 | {"head": "to listen", "relation": "isAfter", "tails": []} 27 | {"head": "to want", "relation": "HasSubEvent", "tails": []} 28 | {"head": "he didnt listen to me", "relation": "xEffect", "tails": []} 29 | {"head": "to want", "relation": "isFilledBy", "tails": []} 30 | {"head": "I wanted to feed him.", "relation": "isBefore", "tails": []} 31 | {"head": "he didnt listen to me", "relation": "HasSubEvent", "tails": []} 32 | {"head": "he didnt listen to me", "relation": 
"isFilledBy", "tails": []} 33 | {"head": "I", "relation": "AtLocation", "tails": []} 34 | {"head": "to want", "relation": "isBefore", "tails": []} 35 | {"head": "I wanted to feed him.", "relation": "xReact", "tails": []} 36 | {"head": "to listen", "relation": "HasSubEvent", "tails": []} 37 | {"head": "to listen", "relation": "isFilledBy", "tails": []} 38 | {"head": "I", "relation": "Causes", "tails": []} 39 | {"head": "he didnt listen to me", "relation": "isBefore", "tails": []} 40 | {"head": "I wanted to feed him.", "relation": "oWant", "tails": []} 41 | {"head": "I wanted to feed him.", "relation": "Causes", "tails": []} 42 | {"head": "to listen", "relation": "isBefore", "tails": []} 43 | {"head": "I wanted to feed him.", "relation": "xAttr", "tails": []} 44 | {"head": "I wanted to feed him.", "relation": "xWant", "tails": []} 45 | {"head": "he didnt listen to me", "relation": "xReact", "tails": []} 46 | {"head": "to want", "relation": "Causes", "tails": []} 47 | {"head": "he didnt listen to me", "relation": "oWant", "tails": []} 48 | {"head": "he didnt listen to me", "relation": "Causes", "tails": []} 49 | {"head": "I wanted to feed him.", "relation": "xReason", "tails": []} 50 | {"head": "I wanted to feed him.", "relation": "HinderedBy", "tails": []} 51 | {"head": "I", "relation": "xNeed", "tails": []} 52 | {"head": "he didnt listen to me", "relation": "xAttr", "tails": []} 53 | {"head": "to listen", "relation": "Causes", "tails": []} 54 | {"head": "he didnt listen to me", "relation": "xWant", "tails": []} 55 | {"head": "to want", "relation": "xReason", "tails": []} 56 | {"head": "to want", "relation": "HinderedBy", "tails": []} 57 | {"head": "I wanted to feed him.", "relation": "xNeed", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_gpt3.json: -------------------------------------------------------------------------------- 1 | {"head": "PersonX aces PersonX's exam", "relation": "xNeed", 
"tails": [" study"]} 2 | {"head": "PersonX accuses PersonY of cheating", "relation": "xNeed", "tails": [" to find out for sure"]} -------------------------------------------------------------------------------- /examples/results/kgraph_gpt3_custom_relation.json: -------------------------------------------------------------------------------- 1 | {"head": "PersonX sees PersonY's point", "relation": "xNeed", "tails": [" to listen to PersonY"]} 2 | {"head": "PersonX sees PersonY's point", "relation": "xWishes", "tails": [" to apologize to PersonY"]} 3 | {"head": "PersonX makes a huge mistake", "relation": "xNeed", "tails": [""]} 4 | {"head": "PersonX makes a huge mistake", "relation": "xWishes", "tails": [" to fix the mistake"]} -------------------------------------------------------------------------------- /examples/results/kgraph_gpt3_dry_run.json: -------------------------------------------------------------------------------- 1 | {"head": "PersonX accuses PersonY of cheating", "relation": "xNeed", "tails": ["\n\n1. 
Check if there is any rule in the game that PersonY"]} 2 | {"head": "PersonX aces PersonX's exam", "relation": "xNeed", "tails": ["\n\nstudy."]} -------------------------------------------------------------------------------- /examples/results/kgraph_manual.json: -------------------------------------------------------------------------------- 1 | {"head": "Gabby always brought cookies to school.", "relation": "Desires", "tails": []} 2 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_manual_heads.json: -------------------------------------------------------------------------------- 1 | {"head": "post office", "relation": "xReason", "tails": []} 2 | {"head": "to get out of the room", "relation": "xWant", "tails": []} 3 | {"head": "post office", "relation": "isBefore", "tails": []} 4 | {"head": "to get out of the room", "relation": "HinderedBy", "tails": []} 5 | {"head": "to get out of the room", "relation": "oWant", "tails": []} 6 | {"head": "to get out of the room", "relation": "xAttr", "tails": []} 7 | {"head": "post office", "relation": "oReact", "tails": []} 8 | {"head": "post office", "relation": "xIntent", "tails": []} 9 | {"head": "post office", "relation": "HasSubEvent", "tails": []} 10 | {"head": "post office", "relation": "xNeed", "tails": []} 11 | {"head": "to get out of the room", "relation": "Causes", "tails": []} 12 | {"head": "post office", "relation": "xEffect", "tails": []} 13 | {"head": "to get out of the room", "relation": "isFilledBy", "tails": []} 14 | {"head": "to get out of the room", "relation": "xReact", "tails": []} 15 | {"head": "post office", "relation": "isAfter", "tails": []} 16 | {"head": "to get out of the room", "relation": "xReason", "tails": []} 17 | {"head": "to get out of the room", "relation": "isBefore", "tails": []} 18 | {"head": "post office", "relation": "Causes", "tails": []} 19 | 
{"head": "to get out of the room", "relation": "oReact", "tails": []} 20 | {"head": "post office", "relation": "oEffect", "tails": []} 21 | {"head": "post office", "relation": "oWant", "tails": []} 22 | {"head": "to get out of the room", "relation": "HasSubEvent", "tails": []} 23 | {"head": "to get out of the room", "relation": "xIntent", "tails": []} 24 | {"head": "to get out of the room", "relation": "xNeed", "tails": []} 25 | {"head": "to get out of the room", "relation": "xEffect", "tails": []} 26 | {"head": "post office", "relation": "xWant", "tails": []} 27 | {"head": "post office", "relation": "HinderedBy", "tails": []} 28 | {"head": "to get out of the room", "relation": "isAfter", "tails": []} 29 | {"head": "post office", "relation": "xAttr", "tails": []} 30 | {"head": "to get out of the room", "relation": "oEffect", "tails": []} 31 | {"head": "post office", "relation": "isFilledBy", "tails": []} 32 | {"head": "post office", "relation": "xReact", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_modelbased_relations.json: -------------------------------------------------------------------------------- 1 | {"head": "banana", "relation": "ObjectUse", "tails": []} 2 | {"head": "banana", "relation": "CapableOf", "tails": []} 3 | {"head": "banana", "relation": "MadeUpOf", "tails": []} 4 | {"head": "banana", "relation": "HasProperty", "tails": []} 5 | {"head": "banana", "relation": "Desires", "tails": []} 6 | {"head": "banana", "relation": "NotDesires", "tails": []} 7 | {"head": "banana", "relation": "AtLocation", "tails": []} 8 | {"head": "love another", "relation": "Causes", "tails": []} 9 | {"head": "love another", "relation": "HinderedBy", "tails": []} 10 | {"head": "love another", "relation": "xReason", "tails": []} 11 | {"head": "love another", "relation": "isAfter", "tails": []} 12 | {"head": "love another", "relation": "isBefore", "tails": []} 13 | {"head": "love another", "relation": 
"HasSubEvent", "tails": []} 14 | {"head": "love another", "relation": "isFilledBy", "tails": []} 15 | {"head": "love another", "relation": "xIntent", "tails": []} 16 | {"head": "love another", "relation": "xReact", "tails": []} 17 | {"head": "love another", "relation": "oReact", "tails": []} 18 | {"head": "love another", "relation": "xAttr", "tails": []} 19 | {"head": "love another", "relation": "xEffect", "tails": []} 20 | {"head": "love another", "relation": "xNeed", "tails": []} 21 | {"head": "love another", "relation": "xWant", "tails": []} 22 | {"head": "love another", "relation": "oEffect", "tails": []} 23 | {"head": "love another", "relation": "oWant", "tails": []} 24 | {"head": "Student gets a card", "relation": "Causes", "tails": []} 25 | {"head": "Student gets a card", "relation": "HinderedBy", "tails": []} 26 | {"head": "Student gets a card", "relation": "xReason", "tails": []} 27 | {"head": "Student gets a card", "relation": "isAfter", "tails": []} 28 | {"head": "Student gets a card", "relation": "isBefore", "tails": []} 29 | {"head": "Student gets a card", "relation": "HasSubEvent", "tails": []} 30 | {"head": "Student gets a card", "relation": "isFilledBy", "tails": []} 31 | {"head": "Student gets a card", "relation": "xIntent", "tails": []} 32 | {"head": "Student gets a card", "relation": "xReact", "tails": []} 33 | {"head": "Student gets a card", "relation": "oReact", "tails": []} 34 | {"head": "Student gets a card", "relation": "xAttr", "tails": []} 35 | {"head": "Student gets a card", "relation": "xEffect", "tails": []} 36 | {"head": "Student gets a card", "relation": "xNeed", "tails": []} 37 | {"head": "Student gets a card", "relation": "xWant", "tails": []} 38 | {"head": "Student gets a card", "relation": "oEffect", "tails": []} 39 | {"head": "Student gets a card", "relation": "oWant", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_modelbased_relations_bert.json: 
-------------------------------------------------------------------------------- 1 | {"head": "Student gets a card", "relation": "Causes", "tails": []} 2 | {"head": "banana", "relation": "CapableOf", "tails": []} 3 | {"head": "Student gets a card", "relation": "isFilledBy", "tails": []} 4 | {"head": "love another", "relation": "isAfter", "tails": []} 5 | {"head": "Student gets a card", "relation": "xReact", "tails": []} 6 | {"head": "Student gets a card", "relation": "xReason", "tails": []} 7 | {"head": "Student gets a card", "relation": "oEffect", "tails": []} 8 | {"head": "Student gets a card", "relation": "isBefore", "tails": []} 9 | {"head": "love another", "relation": "Causes", "tails": []} 10 | {"head": "love another", "relation": "isFilledBy", "tails": []} 11 | {"head": "Student gets a card", "relation": "oWant", "tails": []} 12 | {"head": "love another", "relation": "HasSubEvent", "tails": []} 13 | {"head": "banana", "relation": "AtLocation", "tails": []} 14 | {"head": "banana", "relation": "Desires", "tails": []} 15 | {"head": "love another", "relation": "xReason", "tails": []} 16 | {"head": "Student gets a card", "relation": "oReact", "tails": []} 17 | {"head": "banana", "relation": "MadeUpOf", "tails": []} 18 | {"head": "love another", "relation": "isBefore", "tails": []} 19 | {"head": "Student gets a card", "relation": "xWant", "tails": []} 20 | {"head": "banana", "relation": "HasProperty", "tails": []} 21 | {"head": "Student gets a card", "relation": "HinderedBy", "tails": []} 22 | {"head": "Student gets a card", "relation": "HasSubEvent", "tails": []} 23 | {"head": "Student gets a card", "relation": "xIntent", "tails": []} 24 | {"head": "Student gets a card", "relation": "xAttr", "tails": []} 25 | {"head": "Student gets a card", "relation": "xNeed", "tails": []} 26 | {"head": "banana", "relation": "ObjectUse", "tails": []} 27 | {"head": "Student gets a card", "relation": "xEffect", "tails": []} 28 | {"head": "love another", "relation": "HinderedBy", 
"tails": []} 29 | {"head": "banana", "relation": "NotDesires", "tails": []} 30 | {"head": "Student gets a card", "relation": "isAfter", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_modelbased_relations_dbert.json: -------------------------------------------------------------------------------- 1 | {"head": "love another", "relation": "xEffect", "tails": []} 2 | {"head": "banana", "relation": "CapableOf", "tails": []} 3 | {"head": "Student gets a card", "relation": "ObjectUse", "tails": []} 4 | {"head": "love another", "relation": "isAfter", "tails": []} 5 | {"head": "love another", "relation": "Causes", "tails": []} 6 | {"head": "love another", "relation": "isFilledBy", "tails": []} 7 | {"head": "Student gets a card", "relation": "NotDesires", "tails": []} 8 | {"head": "Student gets a card", "relation": "HasProperty", "tails": []} 9 | {"head": "banana", "relation": "AtLocation", "tails": []} 10 | {"head": "love another", "relation": "xReact", "tails": []} 11 | {"head": "Student gets a card", "relation": "CapableOf", "tails": []} 12 | {"head": "banana", "relation": "Desires", "tails": []} 13 | {"head": "love another", "relation": "oEffect", "tails": []} 14 | {"head": "banana", "relation": "MadeUpOf", "tails": []} 15 | {"head": "love another", "relation": "xReason", "tails": []} 16 | {"head": "love another", "relation": "isBefore", "tails": []} 17 | {"head": "love another", "relation": "oWant", "tails": []} 18 | {"head": "banana", "relation": "HasProperty", "tails": []} 19 | {"head": "love another", "relation": "oReact", "tails": []} 20 | {"head": "banana", "relation": "ObjectUse", "tails": []} 21 | {"head": "Student gets a card", "relation": "AtLocation", "tails": []} 22 | {"head": "Student gets a card", "relation": "Desires", "tails": []} 23 | {"head": "love another", "relation": "xWant", "tails": []} 24 | {"head": "love another", "relation": "HinderedBy", "tails": []} 25 | {"head": "Student 
gets a card", "relation": "MadeUpOf", "tails": []} 26 | {"head": "banana", "relation": "NotDesires", "tails": []} 27 | {"head": "love another", "relation": "HasSubEvent", "tails": []} 28 | {"head": "love another", "relation": "xAttr", "tails": []} 29 | {"head": "love another", "relation": "xNeed", "tails": []} 30 | {"head": "love another", "relation": "xIntent", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_modelbased_relations_swem.json: -------------------------------------------------------------------------------- 1 | {"head": "love another", "relation": "xEffect", "tails": []} 2 | {"head": "Student gets a card", "relation": "Causes", "tails": []} 3 | {"head": "banana", "relation": "CapableOf", "tails": []} 4 | {"head": "Student gets a card", "relation": "isFilledBy", "tails": []} 5 | {"head": "love another", "relation": "isAfter", "tails": []} 6 | {"head": "Student gets a card", "relation": "xReact", "tails": []} 7 | {"head": "Student gets a card", "relation": "xReason", "tails": []} 8 | {"head": "Student gets a card", "relation": "isBefore", "tails": []} 9 | {"head": "banana", "relation": "AtLocation", "tails": []} 10 | {"head": "love another", "relation": "oEffect", "tails": []} 11 | {"head": "banana", "relation": "MadeUpOf", "tails": []} 12 | {"head": "Student gets a card", "relation": "oReact", "tails": []} 13 | {"head": "love another", "relation": "oWant", "tails": []} 14 | {"head": "Student gets a card", "relation": "HasSubEvent", "tails": []} 15 | {"head": "Student gets a card", "relation": "xIntent", "tails": []} 16 | {"head": "Student gets a card", "relation": "xNeed", "tails": []} 17 | {"head": "love another", "relation": "xWant", "tails": []} 18 | {"head": "Student gets a card", "relation": "xEffect", "tails": []} 19 | {"head": "love another", "relation": "HinderedBy", "tails": []} 20 | {"head": "Student gets a card", "relation": "isAfter", "tails": []} 21 | {"head": "love 
another", "relation": "xAttr", "tails": []} 22 | {"head": "Student gets a card", "relation": "oEffect", "tails": []} 23 | {"head": "love another", "relation": "Causes", "tails": []} 24 | {"head": "love another", "relation": "isFilledBy", "tails": []} 25 | {"head": "Student gets a card", "relation": "oWant", "tails": []} 26 | {"head": "love another", "relation": "xReact", "tails": []} 27 | {"head": "banana", "relation": "Desires", "tails": []} 28 | {"head": "love another", "relation": "xReason", "tails": []} 29 | {"head": "love another", "relation": "isBefore", "tails": []} 30 | {"head": "Student gets a card", "relation": "xWant", "tails": []} 31 | {"head": "banana", "relation": "HasProperty", "tails": []} 32 | {"head": "Student gets a card", "relation": "HinderedBy", "tails": []} 33 | {"head": "Student gets a card", "relation": "xAttr", "tails": []} 34 | {"head": "love another", "relation": "oReact", "tails": []} 35 | {"head": "banana", "relation": "ObjectUse", "tails": []} 36 | {"head": "banana", "relation": "NotDesires", "tails": []} 37 | {"head": "love another", "relation": "HasSubEvent", "tails": []} 38 | {"head": "love another", "relation": "xNeed", "tails": []} 39 | {"head": "love another", "relation": "xIntent", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_no_head_extract.json: -------------------------------------------------------------------------------- 1 | {"head": "Gabby always brought cookies to school.", "relation": "xWant", "tails": []} 2 | {"head": "Gabby always brought cookies to school.", "relation": "HinderedBy", "tails": []} 3 | {"head": "Gabby always brought cookies to school.", "relation": "HasSubEvent", "tails": []} 4 | {"head": "Gabby always brought cookies to school.", "relation": "xIntent", "tails": []} 5 | {"head": "Gabby always brought cookies to school.", "relation": "xNeed", "tails": []} 6 | {"head": "Gabby always brought cookies to school.", "relation": 
"xEffect", "tails": []} 7 | {"head": "Gabby always brought cookies to school.", "relation": "xAttr", "tails": []} 8 | {"head": "Gabby always brought cookies to school.", "relation": "isAfter", "tails": []} 9 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []} 10 | {"head": "Gabby always brought cookies to school.", "relation": "isFilledBy", "tails": []} 11 | {"head": "Gabby always brought cookies to school.", "relation": "xReact", "tails": []} 12 | {"head": "Gabby always brought cookies to school.", "relation": "xReason", "tails": []} 13 | {"head": "Gabby always brought cookies to school.", "relation": "oEffect", "tails": []} 14 | {"head": "Gabby always brought cookies to school.", "relation": "isBefore", "tails": []} 15 | {"head": "Gabby always brought cookies to school.", "relation": "oWant", "tails": []} 16 | {"head": "Gabby always brought cookies to school.", "relation": "oReact", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_no_match_subset.json: -------------------------------------------------------------------------------- 1 | {"head": "to bring", "relation": "Desires", "tails": []} 2 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []} 3 | {"head": "cookies", "relation": "Causes", "tails": []} 4 | {"head": "Gabby", "relation": "Causes", "tails": []} 5 | {"head": "school", "relation": "Causes", "tails": []} 6 | {"head": "Gabby", "relation": "Desires", "tails": []} 7 | {"head": "bring cookies", "relation": "Causes", "tails": []} 8 | {"head": "to bring", "relation": "Causes", "tails": []} 9 | {"head": "Gabby always brought cookies to school.", "relation": "Desires", "tails": []} 10 | {"head": "school", "relation": "Desires", "tails": []} 11 | {"head": "cookies", "relation": "Desires", "tails": []} 12 | {"head": "bring cookies", "relation": "Desires", "tails": []} 
-------------------------------------------------------------------------------- /examples/results/kgraph_rel_subset.json: -------------------------------------------------------------------------------- 1 | {"head": "Gabby", "relation": "xNeed", "tails": []} 2 | {"head": "school", "relation": "Causes", "tails": []} 3 | {"head": "Gabby always brought cookies to school.", "relation": "xNeed", "tails": []} 4 | {"head": "bring cookies", "relation": "Causes", "tails": []} 5 | {"head": "to bring", "relation": "Causes", "tails": []} 6 | {"head": "school", "relation": "ObjectUse", "tails": []} 7 | {"head": "Gabby", "relation": "Causes", "tails": []} 8 | {"head": "cookies", "relation": "xNeed", "tails": []} 9 | {"head": "Gabby", "relation": "ObjectUse", "tails": []} 10 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []} 11 | {"head": "cookies", "relation": "Causes", "tails": []} 12 | {"head": "cookies", "relation": "ObjectUse", "tails": []} 13 | {"head": "school", "relation": "xNeed", "tails": []} 14 | {"head": "bring cookies", "relation": "xNeed", "tails": []} 15 | {"head": "to bring", "relation": "xNeed", "tails": []} -------------------------------------------------------------------------------- /examples/results/kgraph_with_context.json: -------------------------------------------------------------------------------- 1 | {"head": "PersonX wraps gifts", "relation": "oReact", "tails": [" happy", " grateful"]} 2 | {"head": "gifts", "relation": "AtLocation", "tails": [" store", " gift box"]} 3 | {"head": "gifts", "relation": "MadeUpOf", "tails": [" give as a gift", " gift", " give as gift"]} 4 | {"head": "PersonX wraps gifts", "relation": "HasSubEvent", "tails": [" PersonX receives a gift", " PersonX receives a gift."]} 5 | {"head": "PersonX wraps gifts", "relation": "xIntent", "tails": [" to be nice", " to give a gift", " to be generous"]} 6 | {"head": "PersonX wraps gifts", "relation": "xNeed", "tails": [" to buy wrapping paper", " 
to buy the wrapping paper", " to buy wrapping paper."]} 7 | {"head": "PersonX wraps gifts", "relation": "xEffect", "tails": [" PersonX receives a gift in return.", " PersonX receives a gift.", " PersonX receives a gift in return"]} 8 | {"head": "PersonX wraps gifts", "relation": "isAfter", "tails": [" PersonX buys a gift", " PersonX buys a gift for PersonY"]} 9 | {"head": "wraps gifts", "relation": "Desires", "tails": [" wrap presents", " wrap gift", " give to person"]} 10 | {"head": "wraps gifts", "relation": "HasProperty", "tails": [" wrapping paper", " wrapped", " wrapped paper"]} 11 | {"head": "wraps", "relation": "Desires", "tails": [" wrap gift", " wrap presents", " wrap presents in"]} 12 | {"head": "PersonX wraps gifts", "relation": "oEffect", "tails": [" PersonY opens the gift.", " PersonY opens the gift"]} 13 | {"head": "wraps gifts", "relation": "ObjectUse", "tails": [" wrap presents", " wrap presents in", " give to person Y"]} 14 | {"head": "wraps", "relation": "HasProperty", "tails": [" wrapping paper", " wrapped", " wrapped in paper"]} 15 | {"head": "PersonX wraps gifts", "relation": "oWant", "tails": [" to thank PersonX", " to open the gift"]} 16 | {"head": "wraps", "relation": "ObjectUse", "tails": [" wrap the gift in", " wrap a gift in", " wrap a present in"]} 17 | {"head": "gifts", "relation": "Desires", "tails": [" give as a gift", " give to someone", " give to person"]} 18 | {"head": "PersonX wraps gifts", "relation": "xWant", "tails": [" to give the gifts to the person", " to give the gift to the person", " to give the gifts to their friends"]} 19 | {"head": "wraps gifts", "relation": "CapableOf", "tails": [" wrap gift", " wrap presents", " wrap"]} 20 | {"head": "PersonX wraps gifts", "relation": "HinderedBy", "tails": [" PersonX doesn't have any wrapping paper.", " PersonX has no wrapping paper.", " PersonX doesn't know how to wrap."]} 21 | {"head": "gifts", "relation": "HasProperty", "tails": [" given as gifts", " gifts", " give to person"]} 
22 | {"head": "PersonX wraps gifts", "relation": "xAttr", "tails": [" thoughtful", " generous", " giving"]} 23 | {"head": "wraps", "relation": "CapableOf", "tails": [" wrap gift", " wrap gift in", " wrap presents in"]} 24 | {"head": "gifts", "relation": "ObjectUse", "tails": [" give as a gift", " give as a birthday gift", " give to someone"]} 25 | {"head": "wraps gifts", "relation": "MadeUpOf", "tails": [" wrapping paper", " wrap paper", " wrap"]} 26 | {"head": "wraps gifts", "relation": "AtLocation", "tails": [" box"]} 27 | {"head": "PersonX wraps gifts", "relation": "Causes", "tails": [" PersonX wants to be nice.", " PersonX wants to be generous.", " PersonX wants to be nice"]} 28 | {"head": "PersonX wraps gifts", "relation": "isFilledBy", "tails": [" gift bag", " presents", " gifts"]} 29 | {"head": "PersonX wraps gifts", "relation": "xReact", "tails": [" happy"]} 30 | {"head": "wraps", "relation": "MadeUpOf", "tails": [" wrapping paper", " wrap"]} 31 | {"head": "gifts", "relation": "CapableOf", "tails": [" give to person", " give to people", " give to persony"]} 32 | {"head": "PersonX wraps gifts", "relation": "xReason", "tails": [" PersonX buys a gift", " PersonX buys a gift for PersonY"]} 33 | {"head": "PersonX wraps gifts", "relation": "isBefore", "tails": [" PersonX puts the gifts under the tree", " PersonX puts the gifts in the gift bag", " PersonX gives the gifts to the kids"]} -------------------------------------------------------------------------------- /examples/results/kgraph_without_context.json: -------------------------------------------------------------------------------- 1 | {"head": "PersonX wraps gifts", "relation": "oReact", "tails": [" none", " happy", " grateful"]} 2 | {"head": "gifts", "relation": "AtLocation", "tails": [" gift shop", " store", " gift box"]} 3 | {"head": "gifts", "relation": "MadeUpOf", "tails": [" give as a gift", " gift", " give as gift"]} 4 | {"head": "PersonX wraps gifts", "relation": "HasSubEvent", "tails": [" 
PersonX opens the gift", " PersonX receives a gift", " PersonX receives a gift."]} 5 | {"head": "PersonX wraps gifts", "relation": "xIntent", "tails": [" to be nice", " to give a gift", " to be generous"]} 6 | {"head": "PersonX wraps gifts", "relation": "xNeed", "tails": [" to buy wrapping paper", " to buy the wrapping paper", " to buy wrapping paper."]} 7 | {"head": "PersonX wraps gifts", "relation": "xEffect", "tails": [" PersonX receives a gift in return.", " PersonX receives a gift.", " PersonX receives a gift in return"]} 8 | {"head": "PersonX wraps gifts", "relation": "isAfter", "tails": [" PersonX goes to the store", " PersonX buys a gift", " PersonX buys a gift for PersonY"]} 9 | {"head": "wraps gifts", "relation": "Desires", "tails": [" wrap presents", " wrap gift", " give to person"]} 10 | {"head": "wraps gifts", "relation": "HasProperty", "tails": [" wrapping paper", " wrapped", " wrapped paper"]} 11 | {"head": "wraps", "relation": "Desires", "tails": [" wrap gift", " wrap presents", " wrap presents in"]} 12 | {"head": "PersonX wraps gifts", "relation": "oEffect", "tails": [" none", " PersonY opens the gift.", " PersonY opens the gift"]} 13 | {"head": "wraps gifts", "relation": "ObjectUse", "tails": [" wrap presents", " wrap presents in", " give to person Y"]} 14 | {"head": "wraps", "relation": "HasProperty", "tails": [" wrapping paper", " wrapped", " wrapped in paper"]} 15 | {"head": "PersonX wraps gifts", "relation": "oWant", "tails": [" none", " to thank PersonX", " to open the gift"]} 16 | {"head": "wraps gifts", "relation": "NotDesires", "tails": [" wrapping paper", " wrap presents", " wrap gift"]} 17 | {"head": "wraps", "relation": "ObjectUse", "tails": [" wrap the gift in", " wrap a gift in", " wrap a present in"]} 18 | {"head": "gifts", "relation": "Desires", "tails": [" give as a gift", " give to someone", " give to person"]} 19 | {"head": "PersonX wraps gifts", "relation": "xWant", "tails": [" to give the gifts to the person", " to give the 
gift to the person", " to give the gifts to their friends"]} 20 | {"head": "wraps", "relation": "NotDesires", "tails": [" wrap around body", " wrap up a gift", " wrap up a present"]} 21 | {"head": "wraps gifts", "relation": "CapableOf", "tails": [" wrap gift", " wrap presents", " wrap"]} 22 | {"head": "PersonX wraps gifts", "relation": "HinderedBy", "tails": [" PersonX doesn't have any wrapping paper.", " PersonX has no wrapping paper.", " PersonX doesn't know how to wrap."]} 23 | {"head": "gifts", "relation": "HasProperty", "tails": [" given as gifts", " gifts", " give to person"]} 24 | {"head": "PersonX wraps gifts", "relation": "xAttr", "tails": [" thoughtful", " generous", " giving"]} 25 | {"head": "wraps", "relation": "CapableOf", "tails": [" wrap gift", " wrap gift in", " wrap presents in"]} 26 | {"head": "gifts", "relation": "ObjectUse", "tails": [" give as a gift", " give as a birthday gift", " give to someone"]} 27 | {"head": "wraps gifts", "relation": "MadeUpOf", "tails": [" wrapping paper", " wrap paper", " wrap"]} 28 | {"head": "wraps gifts", "relation": "AtLocation", "tails": [" gift shop", " mail box", " box"]} 29 | {"head": "gifts", "relation": "NotDesires", "tails": [" give as a gift", " give to person", " give to someone"]} 30 | {"head": "PersonX wraps gifts", "relation": "Causes", "tails": [" PersonX wants to be nice.", " PersonX wants to be generous.", " PersonX wants to be nice"]} 31 | {"head": "PersonX wraps gifts", "relation": "isFilledBy", "tails": [" gift bag", " presents", " gifts"]} 32 | {"head": "wraps", "relation": "AtLocation", "tails": [" store", " box", " mail box"]} 33 | {"head": "PersonX wraps gifts", "relation": "xReact", "tails": [" happy", " good about themselves", " good about themselves."]} 34 | {"head": "wraps", "relation": "MadeUpOf", "tails": [" wrapping paper", " wrap", " wrap cloth"]} 35 | {"head": "gifts", "relation": "CapableOf", "tails": [" give to person", " give to people", " give to persony"]} 36 | {"head": "PersonX 
wraps gifts", "relation": "xReason", "tails": [" PersonX buys a lot of wrapping paper", " PersonX buys a gift", " PersonX buys a gift for PersonY"]} 37 | {"head": "PersonX wraps gifts", "relation": "isBefore", "tails": [" PersonX puts the gifts under the tree", " PersonX puts the gifts in the gift bag", " PersonX gives the gifts to the kids"]} -------------------------------------------------------------------------------- /examples/results/test_atomic2020_res_cometgpt2_sample.json: -------------------------------------------------------------------------------- 1 | {"head": "PersonX takes things for granted", "relation": "xNeed", "tails": ["to have a lot of work to do"]} 2 | {"head": "PersonX calls PersonY ambulance", "relation": "xNeed", "tails": ["to pick up the phone"]} 3 | {"head": "PersonX pleases ___ to make", "relation": "xWant", "tails": ["to be successful"]} 4 | {"head": "PersonX shoves PersonY back", "relation": "xEffect", "tails": ["none"]} 5 | {"head": "PersonX dates for years", "relation": "xWant", "tails": ["to get married"]} 6 | {"head": "PersonX covers every aspect", "relation": "isAfter", "tails": ["PersonX is a lawyer"]} 7 | {"head": "PersonX wants to go", "relation": "isAfter", "tails": ["PersonX is in the hospital"]} 8 | {"head": "PersonX hits by lightning", "relation": "xEffect", "tails": ["is in pain"]} 9 | {"head": "PersonX finally meet PersonY", "relation": "xNeed", "tails": ["to go to the meeting"]} 10 | {"head": "chain", "relation": "ObjectUse", "tails": ["tie the dog up"]} -------------------------------------------------------------------------------- /examples/results/test_atomic2020_res_zeroshot_sample.jsonl: -------------------------------------------------------------------------------- 1 | {"head": "PersonX takes things for granted", "relation": "xNeed", "tails": ["be able to do something that is not possible with other people. 
Now, PersonX needs to be able to do something that is not possible with other people", "be able to do something that is not possible with other people. Now, PersonX needs to be able to do something that is not possible with other people", "be able to do something that is not possible with other people. Now, PersonX needs to be able to do something that is not possible with other people"]} 2 | {"head": "PersonX calls PersonY ambulance", "relation": "xNeed", "tails": ["get a call from PersonY.", "get a call from PersonY.", "get a call from PersonY."]} 3 | {"head": "PersonX pleases ___ to make", "relation": "xWant", "tails": ["____ the person who is not a member of the group.", "____ the person who is not a member of the group.", "____ the person who is not a member of the group."]} 4 | {"head": "PersonX shoves PersonY back", "relation": "xEffect", "tails": ["increased by 1.5x.", "increased by 1.5x.", "increased by 1.5x."]} 5 | {"head": "PersonX dates for years", "relation": "xWant", "tails": ["have a relationship with you.", "have a relationship with you.", "have a relationship with you."]} 6 | {"head": "PersonX covers every aspect", "relation": "isAfter", "tails": ["I've covered the basics of how to use the app, how to use the app's built-in camera, how to use the app's built-", "I've covered the basics of how to use the app, how to use the app's built-in camera, how to use the app's built-", "I've covered the basics of how to use the app, how to use the app's built-in camera, how to use the app's built-"]} 7 | {"head": "PersonX wants to go", "relation": "isAfter", "tails": ["he's been working on his own project.", "he's been working on his own project.", "he's been working on his own project."]} 8 | {"head": "PersonX hits by lightning", "relation": "xEffect", "tails": ["increased by 1% per level.", "increased by 1% per level.", "increased by 1% per level."]} 9 | {"head": "PersonX finally meet PersonY", "relation": "xNeed", "tails": ["get a job and PersonY needs to 
get a job.", "get a job and PersonY needs to get a job.", "get a job and PersonY needs to get a job."]} 10 | {"head": "chain", "relation": "ObjectUse", "tails": ["the following:", "the following:", "the same purpose."]} -------------------------------------------------------------------------------- /examples/sample_graph.jsonl: -------------------------------------------------------------------------------- 1 | {"source": "PersonX buys lunch", "rel": "xNeed", "tails": ["bring a wallet"]} 2 | {"source": "Throwing a party", "rel": "Causes", "tails": ["have fun"]} -------------------------------------------------------------------------------- /examples/sample_graph.tsv: -------------------------------------------------------------------------------- 1 | PersonX always drank xNeed to get a drink 2 | PersonX abandons ___ altogether xNeed Plows the field. 3 | PersonX about to get married xNeed meet someone 4 | PersonX accepts PersonY appointment xNeed to clear spot in schedule 5 | PersonX accepts PersonY thanks xNeed to says thanks to Y 6 | PersonX accepts PersonY's proposal xNeed to date 7 | PersonX accidentally bumped xNeed run or drive fast 8 | PersonX accidentally fell xNeed to walk too fast 9 | PersonX accompanies PersonY far xNeed to walk with PersonY 10 | PersonX accomplishes PersonX's goal xNeed to make a goal -------------------------------------------------------------------------------- /examples/sample_graph2.tsv: -------------------------------------------------------------------------------- 1 | PersonX always drank xNeed to get a drink 2 | PersonX abandons ___ altogether xNeed Plows the field. 
3 | PersonX about to get married xNeed meet someone 4 | PersonX accepts PersonY appointment xNeed to clear spot in schedule 5 | PersonX accepts PersonY thanks xNeed to says thanks to Y 6 | PersonX accepts PersonY's proposal xNeed to date 7 | PersonX accidentally bumped xNeed run or drive fast 8 | PersonX accidentally fell xNeed to walk too fast 9 | PersonX accompanies PersonY far xNeed to walk with PersonY 10 | PersonX accomplishes PersonX's goal xNeed to make a goal 11 | PersonX is at a party xWishes to drink beer and dance 12 | PersonX bleeds a lot xWishes to see a doctor 13 | PersonX works as a cashier xWishes to be a store manager 14 | PersonX gets dirty xWishes to clean up 15 | PersonX stays up all night studying xWishes to sleep all day 16 | PersonX gets PersonY's autograph xWishes to have a relationship with PersonY 17 | PersonX ends a friendship xWishes to meet new people 18 | PersonX makes his own costume xWishes to go to a costume party 19 | PersonX calls PersonY xWishes to have a long chat 20 | PersonX tells PersonY a secret xWishes to get PersonY's advice 21 | PersonX mows the lawn xWishes to get a new lawnmower 22 | -------------------------------------------------------------------------------- /examples/sample_graph3.jsonl: -------------------------------------------------------------------------------- 1 | {"head": "PersonX accepts PersonY appointment", "relation": "xNeed", "tails": ["to clear spot in schedule"]} 2 | {"head": "PersonX accepts PersonY thanks", "relation": "xNeed", "tails": ["to says thanks to Y"]} 3 | {"head": "PersonX abandons ___ altogether", "relation": "xNeed", "tails": ["Plows the field."]} 4 | {"head": "PersonX accompanies PersonY far", "relation": "xNeed", "tails": ["to walk with PersonY"]} 5 | {"head": "PersonX always drank", "relation": "xNeed", "tails": ["to get a drink"]} 6 | {"head": "PersonX about to get married", "relation": "xNeed", "tails": ["meet someone"]} 7 | {"head": "PersonX accidentally fell", "relation": 
"xNeed", "tails": ["to walk too fast"]} 8 | {"head": "PersonX accidentally bumped", "relation": "xNeed", "tails": ["run or drive fast"]} 9 | {"head": "PersonX accepts PersonY's proposal", "relation": "xNeed", "tails": ["to date"]} 10 | {"head": "PersonX accomplishes PersonX's goal", "relation": "xNeed", "tails": ["to make a goal"]} -------------------------------------------------------------------------------- /examples/sample_linking_graph.csv: -------------------------------------------------------------------------------- 1 | PersonX drives ___ to work | xWant | his friend to get to work on time 2 | PersonX drives ___ fast | xWant | to get out of the car 3 | drive | HasSubEvent | get into car -------------------------------------------------------------------------------- /examples/snapshots/custom-relation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/examples/snapshots/custom-relation.png -------------------------------------------------------------------------------- /examples/snapshots/kg-concepts.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/examples/snapshots/kg-concepts.png -------------------------------------------------------------------------------- /examples/snapshots/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/examples/snapshots/pipeline.png -------------------------------------------------------------------------------- /examples/snapshots/quickstart.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/examples/snapshots/quickstart.png -------------------------------------------------------------------------------- /examples/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | import pathlib 4 | 5 | from kogito.inference import CommonsenseInference 6 | from kogito.models.bart.comet import COMETBART 7 | 8 | 9 | def main(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--datapath", type=str) 12 | parser.add_argument("--output-dir", type=str) 13 | 14 | args = parser.parse_args() 15 | 16 | with open(args.datapath) as f: 17 | data = f.readlines() 18 | 19 | csi = CommonsenseInference() 20 | model = COMETBART.from_pretrained() 21 | 22 | for i, line in tqdm( 23 | enumerate(data), total=len(data), desc="Generating", position=0, leave=True 24 | ): 25 | kgraph = csi.infer(line, model) 26 | kgraph.to_jsonl( 27 | f"{args.output_dir}/{pathlib.Path(args.datapath).stem}_{i+1}.json" 28 | ) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /examples/test_atomic2020_sample.json: -------------------------------------------------------------------------------- 1 | {"relation": "xNeed", "head": "PersonX takes things for granted", "tails": ["to have been lazy at work", "to have wasted resources", "none", "to have used their money on unnecessary things"]} 2 | {"relation": "xNeed", "head": "PersonX calls PersonY ambulance", "tails": ["get the phone", "recognize the person needs an ambulance", "to see Person Y is hurt.", "to get a phone.", "none"]} 3 | {"relation": "xWant", "head": "PersonX pleases ___ to make", "tails": ["to say thanks for making it", "to eat it", "to produce something", "to be in their good graces"]} 4 | {"relation": "xEffect", "head": "PersonX shoves PersonY back", "tails": ["is 
physical", "none", "is violent"]} 5 | {"relation": "xWant", "head": "PersonX dates for years", "tails": ["to keep in contact for 2nd dates", "to propose", "to go to the movies", "to go to dinner", "to get married"]} 6 | {"relation": "isAfter", "head": "PersonX covers every aspect", "tails": ["PersonX researches an article about current events"]} 7 | {"relation": "isAfter", "head": "PersonX wants to go", "tails": ["PersonX sees a flyer for a festival"]} 8 | {"relation": "xEffect", "head": "PersonX hits by lightning", "tails": ["Has heart attack", "goes to the hospital", "Has hair burned", "gets hurt"]} 9 | {"relation": "xNeed", "head": "PersonX finally meet PersonY", "tails": ["ask friend to be introduced to PersonY", "Get dressed", "Find person", "to search for person Y", "travels to meet PersonY", "to be away from person Y for sometime"]} 10 | {"relation": "ObjectUse", "head": "chain", "tails": ["hold the swing", "put around it", "lock up a vicious pet.", "make spooky noises in a haunted house", "attach the mooring", "lock the bike", "lock up the bike while they look around", "hold something on the wall", "to chain bicycle to the bicycle rack", "take with them", "decorate the house", "hold down a garage door without locks.", "take off", "keep pick-pockets from stealing their wallet", "pull the car", "hold their spectacles on their neck", "hold the key", "catch pants on", "connect the truck and the car", "lock it up", "hook it up too", "tug on a truck"]} -------------------------------------------------------------------------------- /examples/test_comet_bart.py: -------------------------------------------------------------------------------- 1 | from kogito.core.knowledge import KnowledgeGraph 2 | from kogito.models.bart.comet import COMETBART 3 | 4 | input_graph = KnowledgeGraph.from_jsonl("test_atomic2020_sample.json") 5 | 6 | model = COMETBART.from_pretrained("mismayil/comet-bart-ai2") 7 | output_graph = model.generate( 8 | input_graph, batch_size=256, 
num_return_sequences=1 9 | ) 10 | output_graph.to_jsonl("results.json") 11 | -------------------------------------------------------------------------------- /examples/test_comet_gpt2.py: -------------------------------------------------------------------------------- 1 | from kogito.core.knowledge import KnowledgeGraph 2 | from kogito.models.gpt2.comet import COMETGPT2 3 | 4 | model = COMETGPT2.from_pretrained("mismayil/comet-gpt2-ai2") 5 | input_graph = KnowledgeGraph.from_jsonl("./test_atomic2020_sample.json") 6 | output_graph = model.generate(input_graph) 7 | output_graph.to_jsonl("results/test_atomic2020_res_cometgpt2_sample.json") 8 | -------------------------------------------------------------------------------- /examples/test_gpt2.json: -------------------------------------------------------------------------------- 1 | {"head": "The researcher tweets about a new library", "relation": "xAttr", "tails": ["a fan of the library. PersonY is a critic. PersonZ is a neutral observer.", "a fan of the library. PersonY is a critic. PersonZ is a neutral observer.", "a fan of the library. PersonY is a critic. PersonZ is a neutral observer."]} 2 | {"head": "The researcher tweets about a new library", "relation": "xWant", "tails": ["find out more about the library. PersonX will then tweet about the library. PersonX will then tweet about the library again. PersonX will then tweet", "find out more about the library. PersonX will then tweet about the library. PersonX will then tweet about the library again. PersonX will then tweet", "find out more about the library. PersonX will then tweet about the library. PersonX will then tweet about the library again. PersonX will then tweet"]} 3 | {"head": "The researcher tweets about a new library", "relation": "xNeed", "tails": ["find the library and then find the person who owns it. Now, PersonX can find the library and then find the person who owns it. This is", "find the library and then find the person who owns it. 
Now, PersonX can find the library and then find the person who owns it. This is", "find the library and then find the person who owns it. Now, PersonX can find the library and then find the person who owns it. This is"]} 4 | {"head": "The researcher tweets about a new library", "relation": "xEffect", "tails": ["that they will be more likely to visit the library.", "that they will be more likely to visit the library.", "that they will be more likely to visit the library."]} 5 | {"head": "The researcher tweets about a new library", "relation": "HinderedBy", "tails": ["the library was not open to the public.", "the library was not open to the public.", "the library was not open to the public."]} 6 | {"head": "The researcher tweets about a new library", "relation": "oWant", "tails": ["use it. The researcher tweets about a new library. After, others will want to use it. The researcher tweets about a new library. After, others", "use it. The researcher tweets about a new library. After, others will want to use it. The researcher tweets about a new library. After, others", "use it. The researcher tweets about a new library. After, others will want to use it. The researcher tweets about a new library. After, others"]} 7 | {"head": "The researcher tweets about a new library", "relation": "xReact", "tails": ["a library for people who are interested in the humanities. PersonY will be a library for people who are interested in the sciences.", "a library for people who are interested in the humanities. PersonY will be a library for people who are interested in the sciences.", "a library for people who are interested in the humanities. 
PersonY will be a library for people who are interested in the sciences."]} 8 | {"head": "The researcher tweets about a new library", "relation": "oEffect", "tails": ["positive or negative depending on the social network.", "positive or negative depending on the social network.", "positive or negative depending on the social network."]} 9 | {"head": "The researcher tweets about a new library", "relation": "xIntent", "tails": ["the library. PersonY did this to the library. PersonZ did this to the library.", "the library. PersonY did this to the library. PersonZ did this to the library.", "the library. PersonY did this to the library. PersonZ did this to the library."]} 10 | {"head": "The researcher tweets about a new library", "relation": "oReact", "tails": ["that they have to follow suit.", "that they have to follow suit.", "that they have to follow suit."]} 11 | {"head": "The researcher tweets about a new library", "relation": "isBefore", "tails": ["the library's Twitter account tweets about the researcher.", "the library's Twitter account tweets about the researcher.", "the library's Twitter account tweets about the researcher."]} 12 | {"head": "The researcher tweets about a new library", "relation": "isAfter", "tails": ["he tweets about a new library. Before that, he tweets about a new library. Before that, he tweets about a new library. Before that, he", "he tweets about a new library. Before that, he tweets about a new library. Before that, he tweets about a new library. Before that, he", "he tweets about a new library. Before that, he tweets about a new library. Before that, he tweets about a new library. 
Before that, he"]} 13 | {"head": "The researcher tweets about a new library", "relation": "HasSubEvent", "tails": ["be able to see the library's name, location, and the number of books in the library.", "be able to see the library's name, location, and the number of books in the library.", "be able to see the library's name, location, and the number of books in the library."]} 14 | {"head": "The researcher tweets about a new library", "relation": "Causes", "tails": ["a flurry of activity on Twitter. \u00a0The library is then mentioned in the news, and the library's website is mentioned in the news.", "a flurry of activity on Twitter. \u00a0The library is then mentioned in the news, and the library's website is mentioned in the news.", "a flurry of activity on Twitter. \u00a0The library is then mentioned in the news, and the library's website is mentioned in the news."]} 15 | {"head": "The researcher tweets about a new library", "relation": "xReason", "tails": ["he was bored. PersonY did this because he was bored. PersonZ did this because he was bored.", "he was bored. PersonY did this because he was bored. PersonZ did this because he was bored.", "he was bored. PersonY did this because he was bored. 
PersonZ did this because he was bored."]} -------------------------------------------------------------------------------- /examples/test_zeroshot.py: -------------------------------------------------------------------------------- 1 | from kogito.core.knowledge import KnowledgeGraph 2 | from kogito.models.gpt2.zeroshot import GPT2Zeroshot 3 | 4 | input_graph = KnowledgeGraph.from_jsonl("./test_atomic2020_sample.json") 5 | 6 | model = GPT2Zeroshot() 7 | output_graph = model.generate(input_graph) 8 | output_graph.to_jsonl("results/test_atomic2020_res_zeroshot_sample.jsonl") 9 | -------------------------------------------------------------------------------- /examples/train_comet_bart.py: -------------------------------------------------------------------------------- 1 | from kogito.core.knowledge import KnowledgeGraph 2 | from kogito.models.bart.comet import COMETBART, COMETBARTConfig 3 | 4 | 5 | config = COMETBARTConfig( 6 | output_dir="models/comet-bart", 7 | task="summarization", 8 | n_val=100, 9 | num_workers=1, 10 | learning_rate=1e-5, 11 | gpus=0, 12 | sortish_sampler=True, 13 | atomic=True, 14 | train_batch_size=32, 15 | eval_batch_size=32, 16 | max_epochs=1, 17 | pretrained_model="facebook/bart-large", 18 | ) 19 | model = COMETBART(config) 20 | train_graph = KnowledgeGraph.from_csv( 21 | "data/atomic2020_data-feb2021/train.tsv", header=None, sep="\t" 22 | ) 23 | val_graph = KnowledgeGraph.from_csv( 24 | "data/atomic2020_data-feb2021/dev.tsv", header=None, sep="\t" 25 | ) 26 | test_graph = KnowledgeGraph.from_csv( 27 | "data/atomic2020_data-feb2021/test.tsv", header=None, sep="\t" 28 | ) 29 | model.train( 30 | train_graph=train_graph, 31 | val_graph=val_graph, 32 | test_graph=test_graph, 33 | logger_name="wandb", 34 | ) 35 | -------------------------------------------------------------------------------- /examples/train_comet_gpt2.py: -------------------------------------------------------------------------------- 1 | from kogito.core.knowledge import 
KnowledgeGraph 2 | from kogito.models.gpt2.comet import COMETGPT2 3 | 4 | model = COMETGPT2("gpt2-xl") 5 | train_graph = KnowledgeGraph.from_csv( 6 | "data/atomic2020/sample_train.tsv", header=None, sep="\t" 7 | ) 8 | val_graph = KnowledgeGraph.from_csv( 9 | "data/atomic2020/sample_dev.tsv", header=None, sep="\t" 10 | ) 11 | model.train( 12 | train_graph=train_graph, 13 | val_graph=val_graph, 14 | batch_size=32, 15 | output_dir="models/comet-gpt2", 16 | epochs=1, 17 | lr=5e-5, 18 | ) 19 | -------------------------------------------------------------------------------- /experiments/relation_modeling/spacy/base_config.cfg: -------------------------------------------------------------------------------- 1 | # This is an auto-generated partial config. To use it with 'spacy train' 2 | # you can run spacy init fill-config to auto-fill all default settings: 3 | # python -m spacy init fill-config ./base_config.cfg ./config.cfg 4 | [paths] 5 | train = "data/train.spacy" 6 | dev = "data/dev.spacy" 7 | vectors = "en_core_web_lg" 8 | [system] 9 | gpu_allocator = null 10 | 11 | [nlp] 12 | lang = "en" 13 | pipeline = ["tok2vec","textcat_multilabel"] 14 | batch_size = 1000 15 | 16 | [components] 17 | 18 | [components.tok2vec] 19 | factory = "tok2vec" 20 | 21 | [components.tok2vec.model] 22 | @architectures = "spacy.Tok2Vec.v2" 23 | 24 | [components.tok2vec.model.embed] 25 | @architectures = "spacy.MultiHashEmbed.v2" 26 | width = ${components.tok2vec.model.encode.width} 27 | attrs = ["ORTH", "SHAPE"] 28 | rows = [5000, 2500] 29 | include_static_vectors = true 30 | 31 | [components.tok2vec.model.encode] 32 | @architectures = "spacy.MaxoutWindowEncoder.v2" 33 | width = 256 34 | depth = 8 35 | window_size = 1 36 | maxout_pieces = 3 37 | 38 | [components.textcat_multilabel] 39 | factory = "textcat_multilabel" 40 | 41 | [components.textcat_multilabel.model] 42 | @architectures = "spacy.TextCatEnsemble.v2" 43 | nO = null 44 | 45 | [components.textcat_multilabel.model.tok2vec] 46 | 
@architectures = "spacy.Tok2VecListener.v1" 47 | width = ${components.tok2vec.model.encode.width} 48 | 49 | [components.textcat_multilabel.model.linear_model] 50 | @architectures = "spacy.TextCatBOW.v2" 51 | exclusive_classes = false 52 | ngram_size = 1 53 | no_output_layer = false 54 | 55 | [corpora] 56 | 57 | [corpora.train] 58 | @readers = "spacy.Corpus.v1" 59 | path = ${paths.train} 60 | max_length = 0 61 | 62 | [corpora.dev] 63 | @readers = "spacy.Corpus.v1" 64 | path = ${paths.dev} 65 | max_length = 0 66 | 67 | [training] 68 | dev_corpus = "corpora.dev" 69 | train_corpus = "corpora.train" 70 | 71 | [training.optimizer] 72 | @optimizers = "Adam.v1" 73 | 74 | [training.batcher] 75 | @batchers = "spacy.batch_by_words.v1" 76 | discard_oversize = false 77 | tolerance = 0.2 78 | 79 | [training.batcher.size] 80 | @schedules = "compounding.v1" 81 | start = 100 82 | stop = 1000 83 | compound = 1.001 84 | 85 | [initialize] 86 | vectors = ${paths.vectors} -------------------------------------------------------------------------------- /experiments/relation_modeling/spacy/config.cfg: -------------------------------------------------------------------------------- 1 | [paths] 2 | train = "data/train.spacy" 3 | dev = "data/dev.spacy" 4 | vectors = "en_core_web_lg" 5 | init_tok2vec = null 6 | 7 | [system] 8 | gpu_allocator = null 9 | seed = 0 10 | 11 | [nlp] 12 | lang = "en" 13 | pipeline = ["tok2vec","textcat_multilabel"] 14 | batch_size = 1000 15 | disabled = [] 16 | before_creation = null 17 | after_creation = null 18 | after_pipeline_creation = null 19 | tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"} 20 | 21 | [components] 22 | 23 | [components.textcat_multilabel] 24 | factory = "textcat_multilabel" 25 | scorer = {"@scorers":"spacy.textcat_multilabel_scorer.v1"} 26 | threshold = 0.5 27 | 28 | [components.textcat_multilabel.model] 29 | @architectures = "spacy.TextCatEnsemble.v2" 30 | nO = null 31 | 32 | [components.textcat_multilabel.model.linear_model] 33 | 
@architectures = "spacy.TextCatBOW.v2" 34 | exclusive_classes = false 35 | ngram_size = 1 36 | no_output_layer = false 37 | nO = null 38 | 39 | [components.textcat_multilabel.model.tok2vec] 40 | @architectures = "spacy.Tok2VecListener.v1" 41 | width = ${components.tok2vec.model.encode.width} 42 | upstream = "*" 43 | 44 | [components.tok2vec] 45 | factory = "tok2vec" 46 | 47 | [components.tok2vec.model] 48 | @architectures = "spacy.Tok2Vec.v2" 49 | 50 | [components.tok2vec.model.embed] 51 | @architectures = "spacy.MultiHashEmbed.v2" 52 | width = ${components.tok2vec.model.encode.width} 53 | attrs = ["ORTH","SHAPE"] 54 | rows = [5000,2500] 55 | include_static_vectors = true 56 | 57 | [components.tok2vec.model.encode] 58 | @architectures = "spacy.MaxoutWindowEncoder.v2" 59 | width = 256 60 | depth = 8 61 | window_size = 1 62 | maxout_pieces = 3 63 | 64 | [corpora] 65 | 66 | [corpora.dev] 67 | @readers = "spacy.Corpus.v1" 68 | path = ${paths.dev} 69 | max_length = 0 70 | gold_preproc = false 71 | limit = 0 72 | augmenter = null 73 | 74 | [corpora.train] 75 | @readers = "spacy.Corpus.v1" 76 | path = ${paths.train} 77 | max_length = 0 78 | gold_preproc = false 79 | limit = 0 80 | augmenter = null 81 | 82 | [training] 83 | dev_corpus = "corpora.dev" 84 | train_corpus = "corpora.train" 85 | seed = ${system.seed} 86 | gpu_allocator = ${system.gpu_allocator} 87 | dropout = 0.1 88 | accumulate_gradient = 1 89 | patience = 1600 90 | max_epochs = 0 91 | max_steps = 20000 92 | eval_frequency = 200 93 | frozen_components = [] 94 | annotating_components = [] 95 | before_to_disk = null 96 | 97 | [training.batcher] 98 | @batchers = "spacy.batch_by_words.v1" 99 | discard_oversize = false 100 | tolerance = 0.2 101 | get_length = null 102 | 103 | [training.batcher.size] 104 | @schedules = "compounding.v1" 105 | start = 100 106 | stop = 1000 107 | compound = 1.001 108 | t = 0.0 109 | 110 | [training.logger] 111 | @loggers = "spacy.ConsoleLogger.v1" 112 | progress_bar = false 113 | 114 | 
[training.optimizer] 115 | @optimizers = "Adam.v1" 116 | beta1 = 0.9 117 | beta2 = 0.999 118 | L2_is_weight_decay = true 119 | L2 = 0.01 120 | grad_clip = 1.0 121 | use_averages = false 122 | eps = 0.00000001 123 | learn_rate = 0.001 124 | 125 | [training.score_weights] 126 | cats_score = 1.0 127 | cats_score_desc = null 128 | cats_micro_p = null 129 | cats_micro_r = null 130 | cats_micro_f = null 131 | cats_macro_p = null 132 | cats_macro_r = null 133 | cats_macro_f = null 134 | cats_macro_auc = null 135 | cats_f_per_type = null 136 | cats_macro_auc_per_type = null 137 | 138 | [pretraining] 139 | 140 | [initialize] 141 | vectors = ${paths.vectors} 142 | init_tok2vec = ${paths.init_tok2vec} 143 | vocab_data = null 144 | lookups = null 145 | before_init = null 146 | after_init = null 147 | 148 | [initialize.components] 149 | 150 | [initialize.tokenizer] -------------------------------------------------------------------------------- /experiments/relation_modeling/spacy/relation_model_spacy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from relation_modeling_utils import load_data\n", 10 | "\n", 11 | "train_data = load_data(\"data/atomic2020_data-feb2021/train.tsv\")\n", 12 | "dev_data = load_data(\"data/atomic2020_data-feb2021/dev.tsv\")" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "(53891, 4823)" 24 | ] 25 | }, 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "output_type": "execute_result" 29 | } 30 | ], 31 | "source": [ 32 | "len(train_data), len(dev_data)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 4, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "import spacy\n", 42 | "from tqdm import tqdm\n", 43 | "nlp = 
spacy.load(\"en_core_web_lg\")\n", 44 | "\n", 45 | "label_map = ['physical', 'event', 'social']\n", 46 | "\n", 47 | "def make_docs(data):\n", 48 | " \"\"\"\n", 49 | " this will take a list of texts and labels \n", 50 | " and transform them in spacy documents\n", 51 | " \n", 52 | " data: list(tuple(text, label))\n", 53 | " \n", 54 | " returns: List(spacy.Doc.doc)\n", 55 | " \"\"\"\n", 56 | " \n", 57 | " docs = []\n", 58 | " \n", 59 | " for doc, label in tqdm(nlp.pipe(data, as_tuples=True), total=len(data)):\n", 60 | " \n", 61 | " for label_txt in label_map:\n", 62 | " doc.cats[label_txt] = 0\n", 63 | "\n", 64 | " doc.cats[label_map[label]] = 1\n", 65 | " \n", 66 | " # put them into a nice list\n", 67 | " docs.append(doc)\n", 68 | " \n", 69 | " return docs" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": 10, 75 | "metadata": {}, 76 | "outputs": [ 77 | { 78 | "name": "stderr", 79 | "output_type": "stream", 80 | "text": [ 81 | "100%|██████████| 53891/53891 [00:34<00:00, 1570.54it/s]\n", 82 | "100%|██████████| 4823/4823 [00:03<00:00, 1498.61it/s]\n" 83 | ] 84 | } 85 | ], 86 | "source": [ 87 | "from spacy.tokens import DocBin\n", 88 | "\n", 89 | "train_docs = make_docs(list(train_data.itertuples(index=False, name=None)))\n", 90 | "dev_docs = make_docs(list(dev_data.itertuples(index=False, name=None)))\n", 91 | "\n", 92 | "# then we save it in a binary file to disc\n", 93 | "train_doc_bin = DocBin(docs=train_docs)\n", 94 | "train_doc_bin.to_disk(\"./data/train.spacy\")\n", 95 | "\n", 96 | "dev_doc_bin = DocBin(docs=dev_docs)\n", 97 | "dev_doc_bin.to_disk(\"./data/dev.spacy\")" 98 | ] 99 | } 100 | ], 101 | "metadata": { 102 | "interpreter": { 103 | "hash": "7c3b128559c7e8fd624042ca8b6c93b33cd59aca7b58d05c9d4cd21ec1a84d35" 104 | }, 105 | "kernelspec": { 106 | "display_name": "Python 3.8.12 ('kogito')", 107 | "language": "python", 108 | "name": "python3" 109 | }, 110 | "language_info": { 111 | "codemirror_mode": { 112 | "name": "ipython", 113 | 
"version": 3 114 | }, 115 | "file_extension": ".py", 116 | "mimetype": "text/x-python", 117 | "name": "python", 118 | "nbconvert_exporter": "python", 119 | "pygments_lexer": "ipython3", 120 | "version": "3.8.12" 121 | }, 122 | "orig_nbformat": 4 123 | }, 124 | "nbformat": 4, 125 | "nbformat_minor": 2 126 | } 127 | -------------------------------------------------------------------------------- /experiments/relation_modeling/swem/relation_modeling_swem.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | import pytorch_lightning as pl 4 | from pytorch_lightning.loggers import WandbLogger 5 | import wandb 6 | 7 | from kogito.core.processors.models.swem import ( 8 | SWEMConfig, 9 | SWEMClassifier, 10 | SWEMHeadDataset, 11 | ) 12 | from relation_modeling_utils import load_fdata, load_data, get_timestamp 13 | 14 | DATASET_TYPE = "n1" 15 | NUM_EPOCHS = 20 16 | LR_RATE = 1e-4 17 | FREEZE_EMB = False 18 | BATCH_SIZE = 128 19 | POOLING = "avg" 20 | 21 | VOCAB = np.load("data/vocab_glove_100d.npy", allow_pickle=True).item() 22 | 23 | if __name__ == "__main__": 24 | train_df = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/train_{DATASET_TYPE}.csv") 25 | val_df = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True) 26 | test_df = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/test_{DATASET_TYPE}.csv") 27 | train_data = SWEMHeadDataset(train_df, vocab=VOCAB) 28 | val_data = SWEMHeadDataset(val_df, vocab=VOCAB) 29 | test_data = SWEMHeadDataset(test_df, vocab=VOCAB) 30 | 31 | train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True) 32 | val_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE) 33 | test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True) 34 | 35 | timestamp = get_timestamp() 36 | 37 | emb_txt = "frozen" if FREEZE_EMB else "finetune" 38 | 39 | wandb_logger = WandbLogger( 40 | 
project="kogito-relation-matcher", name=f"swem_{emb_txt}_{DATASET_TYPE}" 41 | ) 42 | wandb_logger.experiment.config["epochs"] = NUM_EPOCHS 43 | wandb_logger.experiment.config["batch_size"] = BATCH_SIZE 44 | config = SWEMConfig(pooling=POOLING, freeze_emb=FREEZE_EMB, learning_rate=LR_RATE) 45 | model = SWEMClassifier(config) 46 | trainer = pl.Trainer( 47 | default_root_dir="models/swem", 48 | max_epochs=NUM_EPOCHS, 49 | logger=wandb_logger, 50 | accelerator="gpu", 51 | devices=[0], 52 | ) 53 | trainer.fit( 54 | model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader 55 | ) 56 | trainer.test(model, dataloaders=test_dataloader) 57 | trainer.save_checkpoint( 58 | f"models/swem/swem_{emb_txt}_{DATASET_TYPE}_{timestamp}.ckpt", weights_only=True 59 | ) 60 | model.save_pretrained("hmodels/swem") 61 | wandb.finish() 62 | -------------------------------------------------------------------------------- /experiments/relation_modeling/transformer/relation_modeling_bert.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import pytorch_lightning as pl 3 | from relation_modeling_utils import load_fdata, get_timestamp, load_data 4 | from pytorch_lightning.loggers import WandbLogger 5 | import wandb 6 | 7 | from kogito.core.processors.models.bert import ( 8 | BERTConfig, 9 | BERTHeadDataset, 10 | BERTClassifier, 11 | ) 12 | 13 | MODEL_TYPE = "uncased" 14 | NUM_EPOCHS = 3 15 | BATCH_SIZE = 2 16 | FREEZE_EMB = True 17 | DATASET_TYPE = "n1" 18 | LR_RATE = 1e-4 19 | 20 | if __name__ == "__main__": 21 | train_df = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/train_{DATASET_TYPE}.csv") 22 | val_df = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True) 23 | test_df = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/test_{DATASET_TYPE}.csv") 24 | train_data = BERTHeadDataset(train_df, tokenizer_type=MODEL_TYPE) 25 | val_data = BERTHeadDataset(val_df, tokenizer_type=MODEL_TYPE) 26 | 
test_data = BERTHeadDataset(test_df, tokenizer_type=MODEL_TYPE) 27 | 28 | train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True) 29 | val_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE) 30 | test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True) 31 | 32 | timestamp = get_timestamp() 33 | emb_txt = "frozen" if FREEZE_EMB else "finetune" 34 | 35 | wandb_logger = WandbLogger( 36 | project="kogito-relation-matcher", 37 | name=f"bert_{emb_txt}_{MODEL_TYPE}_{DATASET_TYPE}", 38 | ) 39 | wandb_logger.experiment.config["epochs"] = NUM_EPOCHS 40 | wandb_logger.experiment.config["batch_size"] = BATCH_SIZE 41 | config = BERTConfig( 42 | learning_rate=LR_RATE, model_case=MODEL_TYPE, freeze_emb=FREEZE_EMB 43 | ) 44 | model = BERTClassifier(config) 45 | trainer = pl.Trainer( 46 | default_root_dir="models/bert", 47 | max_epochs=NUM_EPOCHS, 48 | logger=wandb_logger, 49 | accelerator="gpu", 50 | devices=[0], 51 | ) 52 | trainer.fit( 53 | model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader 54 | ) 55 | trainer.test(model, dataloaders=test_dataloader) 56 | trainer.save_checkpoint( 57 | f"models/bert/bert_model_{emb_txt}_{MODEL_TYPE}_{DATASET_TYPE}_{timestamp}.ckpt", 58 | weights_only=True, 59 | ) 60 | model.save_pretrained("hmodels/bert") 61 | wandb.finish() 62 | -------------------------------------------------------------------------------- /experiments/relation_modeling/transformer/relation_modeling_distillbert.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import pytorch_lightning as pl 3 | 4 | from relation_modeling_utils import get_timestamp, load_fdata, load_data 5 | from pytorch_lightning.loggers import WandbLogger 6 | import wandb 7 | 8 | from kogito.core.processors.models.distilbert import ( 9 | DistilBERTConfig, 10 | DistilBERTHeadDataset, 11 | DistilBERTClassifier, 12 | ) 13 | 14 | MODEL_TYPE = "uncased" 15 | 
NUM_EPOCHS = 3 16 | BATCH_SIZE = 64 17 | DATASET_TYPE = "n1" 18 | FREEZE_EMB = False 19 | LR_RATE = 1e-4 20 | 21 | 22 | if __name__ == "__main__": 23 | train_df = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/train_{DATASET_TYPE}.csv") 24 | val_df = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True) 25 | test_df = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/test_{DATASET_TYPE}.csv") 26 | train_data = DistilBERTHeadDataset(train_df, tokenizer_type=MODEL_TYPE) 27 | val_data = DistilBERTHeadDataset(val_df, tokenizer_type=MODEL_TYPE) 28 | test_data = DistilBERTHeadDataset(test_df, tokenizer_type=MODEL_TYPE) 29 | 30 | train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True) 31 | val_dataloader = DataLoader(val_data, batch_size=BATCH_SIZE) 32 | test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True) 33 | 34 | emb_txt = "frozen" if FREEZE_EMB else "finetune" 35 | 36 | timestamp = get_timestamp() 37 | wandb_logger = WandbLogger( 38 | project="kogito-relation-matcher", 39 | name=f"distilbert_{emb_txt}_{MODEL_TYPE}_{DATASET_TYPE}", 40 | ) 41 | wandb_logger.experiment.config["epochs"] = NUM_EPOCHS 42 | wandb_logger.experiment.config["batch_size"] = BATCH_SIZE 43 | config = DistilBERTConfig( 44 | learning_rate=LR_RATE, model_case=MODEL_TYPE, freeze_emb=FREEZE_EMB 45 | ) 46 | model = DistilBERTClassifier(config) 47 | trainer = pl.Trainer( 48 | default_root_dir="models/distilbert", 49 | max_epochs=NUM_EPOCHS, 50 | logger=wandb_logger, 51 | accelerator="gpu", 52 | devices=[1], 53 | ) 54 | trainer.fit( 55 | model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader 56 | ) 57 | trainer.test(model, dataloaders=test_dataloader) 58 | trainer.save_checkpoint( 59 | f"models/distilbert/distilbert_model_{emb_txt}_{MODEL_TYPE}_{DATASET_TYPE}_{timestamp}.ckpt", 60 | weights_only=True, 61 | ) 62 | model.save_pretrained("hmodels/distilbert") 63 | wandb.finish() 64 | 
-------------------------------------------------------------------------------- /kogito/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/__init__.py -------------------------------------------------------------------------------- /kogito/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/core/__init__.py -------------------------------------------------------------------------------- /kogito/core/callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytorch_lightning as pl 7 | import torch 8 | from pytorch_lightning.callbacks import ModelCheckpoint 9 | from pytorch_lightning.utilities import rank_zero_only 10 | from pytorch_lightning.utilities import rank_zero_info 11 | 12 | 13 | def count_trainable_parameters(model): 14 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 15 | params = sum([np.prod(p.size()) for p in model_parameters]) 16 | return params 17 | 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Seq2SeqLoggingCallback(pl.Callback): 23 | @rank_zero_only 24 | def _write_logs( 25 | self, 26 | trainer: pl.Trainer, 27 | pl_module: pl.LightningModule, 28 | type_path: str, 29 | save_generations=True, 30 | ) -> None: 31 | logger.info( 32 | f"***** {type_path} results at step {trainer.global_step:05d} *****" 33 | ) 34 | metrics = trainer.callback_metrics 35 | trainer.logger.log_metrics( 36 | { 37 | k: v 38 | for k, v in metrics.items() 39 | if k not in ["log", "progress_bar", "preds"] 40 | } 41 | ) 42 | # Log results 43 | od = Path(pl_module.config.output_dir) 44 | if type_path == "test": 45 | 
results_file = od / "test_results.txt" 46 | generations_file = od / "test_generations.txt" 47 | else: 48 | # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json 49 | # If people want this it will be easy enough to add back. 50 | results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt" 51 | generations_file = ( 52 | od / f"{type_path}_generations/{trainer.global_step:05d}.txt" 53 | ) 54 | results_file.parent.mkdir(exist_ok=True) 55 | generations_file.parent.mkdir(exist_ok=True) 56 | with open(results_file, "a+") as writer: 57 | for key in sorted(metrics): 58 | if key in ["log", "progress_bar", "preds"]: 59 | continue 60 | val = metrics[key] 61 | if isinstance(val, torch.Tensor): 62 | val = val.item() 63 | msg = f"{key}: {val:.6f}\n" 64 | writer.write(msg) 65 | 66 | if not save_generations: 67 | return 68 | 69 | if "preds" in metrics: 70 | content = "\n".join(metrics["preds"]) 71 | generations_file.open("w+").write(content) 72 | 73 | @rank_zero_only 74 | def on_train_start(self, trainer, pl_module): 75 | try: 76 | npars = pl_module.model.model.num_parameters() 77 | except AttributeError: 78 | npars = pl_module.model.num_parameters() 79 | 80 | n_trainable_pars = count_trainable_parameters(pl_module) 81 | # mp stands for million parameters 82 | trainer.logger.log_metrics( 83 | {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6} 84 | ) 85 | 86 | @rank_zero_only 87 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 88 | return self._write_logs(trainer, pl_module, "test") 89 | 90 | 91 | def get_checkpoint_callback(output_dir, metric): 92 | """Saves the best model by validation ROUGE2 score.""" 93 | if metric == "rouge2": 94 | exp = "{val_avg_rouge2:.4f}-{step_count}" 95 | elif metric == "bleu": 96 | exp = "{val_avg_bleu:.4f}-{step_count}" 97 | else: 98 | raise NotImplementedError( 99 | f"seq2seq callbacks only support rouge2 and bleu, got {metric}," 100 | "You 
can make your own by adding to this function." 101 | ) 102 | 103 | checkpoint_callback = ModelCheckpoint( 104 | dirpath=output_dir, 105 | filename=exp, 106 | monitor=f"val_{metric}", 107 | mode="max", 108 | save_top_k=1, 109 | ) 110 | return checkpoint_callback 111 | 112 | 113 | class LoggingCallback(pl.Callback): 114 | def on_batch_end(self, trainer, pl_module): 115 | lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())} 116 | pl_module.logger.log_metrics(lrs) 117 | 118 | def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 119 | rank_zero_info("***** Validation results *****") 120 | metrics = trainer.callback_metrics 121 | # Log results 122 | for key in sorted(metrics): 123 | if key not in ["log", "progress_bar"]: 124 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 125 | 126 | def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule): 127 | rank_zero_info("***** Test results *****") 128 | metrics = trainer.callback_metrics 129 | # Log and save results to file 130 | output_test_results_file = os.path.join( 131 | pl_module.base_config.output_dir, "test_results.txt" 132 | ) 133 | with open(output_test_results_file, "w") as writer: 134 | for key in sorted(metrics): 135 | if key not in ["log", "progress_bar"]: 136 | rank_zero_info("{} = {}\n".format(key, str(metrics[key]))) 137 | writer.write("{} = {}\n".format(key, str(metrics[key]))) 138 | -------------------------------------------------------------------------------- /kogito/core/head.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | from enum import Enum 3 | 4 | 5 | class KnowledgeHeadType(Enum): 6 | """ 7 | Type of a Knowledge Head 8 | """ 9 | 10 | TEXT = "text" 11 | SENTENCE = "sentence" 12 | NOUN_PHRASE = "noun_phrase" 13 | VERB_PHRASE = "verb_phrase" 14 | 15 | 16 | class KnowledgeHead: 17 | """ 18 | Represents a concept of Knowledge Head. 
19 | """ 20 | 21 | def __init__( 22 | self, 23 | text: str, 24 | type: KnowledgeHeadType = KnowledgeHeadType.TEXT, 25 | entity: Any = None, 26 | verbalizer: Optional[Callable] = None, 27 | ) -> None: 28 | """Initialize a Knowledge Head. 29 | 30 | Args: 31 | text (str): Head text. 32 | type (KnowledgeHeadType, optional): Type of a Knowledge head. Defaults to KnowledgeHeadType.TEXT. 33 | entity (Any, optional): External Knowledge head entity. Defaults to None. 34 | verbalizer (Optional[Callable], optional): Function to convert knowledge head to natural text. 35 | Defaults to None. 36 | """ 37 | self.text = text.strip() 38 | self.type = type 39 | self.entity = entity 40 | self.verbalizer = verbalizer 41 | 42 | def __eq__(self, other: object) -> bool: 43 | return isinstance(other, KnowledgeHead) and self.text == other.text 44 | 45 | def __ne__(self, other) -> bool: 46 | return not self.__eq__(other) 47 | 48 | def __hash__(self) -> int: 49 | return hash(self.text) 50 | 51 | def verbalize(self) -> Optional[str]: 52 | """Convert head to a meaningful text. 
53 | 54 | Returns: 55 | Optional[str]: Verbalized head 56 | """ 57 | if self.verbalizer: 58 | return self.verbalizer(self.text) 59 | 60 | def __repr__(self) -> str: 61 | return str(self.text) 62 | 63 | def copy(self) -> "KnowledgeHead": 64 | """Copy itself 65 | 66 | Returns: 67 | KnowledgeHead: Copied knowledge head 68 | """ 69 | return KnowledgeHead( 70 | text=self.text, 71 | type=self.type, 72 | entity=self.entity, 73 | verbalizer=self.verbalizer, 74 | ) 75 | 76 | 77 | def head_verbalizer(head: str): 78 | return head 79 | -------------------------------------------------------------------------------- /kogito/core/linker.py: -------------------------------------------------------------------------------- 1 | from typing import List, Union, Tuple 2 | from abc import ABC, abstractmethod, abstractclassmethod 3 | 4 | from kogito.core.knowledge import KnowledgeGraph 5 | 6 | 7 | class KnowledgeLinker(ABC): 8 | """Base Knowledge Linker""" 9 | 10 | @abstractmethod 11 | def save_pretrained(self, save_path: str) -> None: 12 | """Save linker as a pretrained model 13 | 14 | Args: 15 | save_path (str): Directory to save the linker to. 16 | 17 | Raises: 18 | NotImplementedError: This method has to be implemented by 19 | concrete subclasses. 20 | """ 21 | raise NotImplementedError 22 | 23 | @abstractclassmethod 24 | def from_pretrained(cls, model_name_or_path: str) -> "KnowledgeLinker": 25 | """Load model from a pretrained model path 26 | This method can load linkers either from HuggingFace by model name 27 | or from disk by model path. 28 | 29 | Args: 30 | model_name_or_path (str): HuggingFace model name or local model path to load from. 31 | 32 | Raises: 33 | NotImplementedError: This method has to be implemented by 34 | concrete subclasses. 35 | 36 | Returns: 37 | KnowledgeLinker: Loaded knowledge linker. 
38 | """ 39 | raise NotImplementedError 40 | 41 | @abstractmethod 42 | def link( 43 | self, input_graph: KnowledgeGraph, context: Union[List[str], str] 44 | ) -> List[List[float]]: 45 | """Link given knowledge graph with the context. 46 | This method computes a relevance probability for each knowledge in the graph 47 | with respect to the given context and returns these probabilities in a list 48 | in the same order as the knowledge tuples are in the given graph. Note that returned object 49 | is a list of list of numbers because a knowledge tuple might have multiple tails and the probability 50 | is calculated for each combination. 51 | 52 | Args: 53 | input_graph (KnowledgeGraph): Input graph to link. 54 | context (Union[List[str], str]): Context text. Can be either given as a list of 55 | sentences or as a string, in which case, it will be 56 | split into sentences using spacy engine. 57 | 58 | Returns: 59 | List[List[float]]: List of relevance probabilities for each tail 60 | """ 61 | raise NotImplementedError 62 | 63 | def filter( 64 | self, 65 | input_graph: KnowledgeGraph, 66 | context: Union[List[str], str], 67 | threshold: float = 0.5, 68 | return_probs: bool = False, 69 | ) -> Union[KnowledgeGraph, Tuple[KnowledgeGraph, List[List[float]]]]: 70 | """Filter given graph based on context relevancy. 71 | This method under the hood links the graph to the context and then filters knowledge tuples from the graph 72 | which have a relevance probability lower than the given threshold. Filtered knowledge tuples 73 | are returned as a new knowledge graph. If there are multiple tails for a given knowledge, these tails will be 74 | filtered as well. 75 | 76 | Args: 77 | input_graph (KnowledgeGraph): Input graph to filter. 78 | context (Union[List[str], str]): Context text. Can be either given as a list of 79 | sentences or as a string, in which case, it will be 80 | split into sentences using spacy engine. 
81 | threshold (float, optional): Relevance probability used for filtering. Defaults to 0.5. 82 | return_probs (bool, optional): Whether to return all the relevancy probs for the input graph. 83 | Defaults to False. 84 | Returns: 85 | Union[KnowledgeGraph, Tuple[KnowledgeGraph, List[List[float]]]]: 86 | Filtered knowledge graph based on the relevancy scores and optionally, the relevancy scores. 87 | """ 88 | probs = self.link(input_graph, context) 89 | filtered_kgs = [] 90 | 91 | for kg, tail_probs in zip(input_graph, probs): 92 | filtered_tails = [] 93 | 94 | for i, prob in enumerate(tail_probs): 95 | if prob >= threshold: 96 | filtered_tails.append(kg.tails[i]) 97 | 98 | if filtered_tails: 99 | kg.tails = filtered_tails 100 | filtered_kgs.append(kg) 101 | 102 | output_graph = KnowledgeGraph(filtered_kgs) 103 | 104 | if return_probs: 105 | return output_graph, probs 106 | 107 | return output_graph 108 | -------------------------------------------------------------------------------- /kogito/core/model.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from abc import ABC, abstractmethod, abstractclassmethod 4 | from kogito.core.knowledge import KnowledgeGraph 5 | from kogito.evaluation.eval import topk_eval, METRIC_MAP 6 | 7 | 8 | class KnowledgeModel(ABC): 9 | """ 10 | Base class to represent a Knowledge Model. 11 | """ 12 | 13 | @abstractmethod 14 | def train(self, train_graph: KnowledgeGraph, *args, **kwargs) -> "KnowledgeModel": 15 | """Train a knowledge model 16 | 17 | Args: 18 | train_graph (KnowledgeGraph): Training dataset 19 | 20 | Raises: 21 | NotImplementedError: This method has to be implemented 22 | by concrete subclasses. 
23 | 24 | Returns: 25 | KnowledgeModel: Trained knowledge model 26 | """ 27 | raise NotImplementedError 28 | 29 | @abstractmethod 30 | def generate(self, input_graph: KnowledgeGraph, *args, **kwargs) -> KnowledgeGraph: 31 | """Generate inferences from knowledge model 32 | 33 | Args: 34 | input_graph (KnowledgeGraph): Input dataset 35 | 36 | Raises: 37 | NotImplementedError: This method has to be implemented by 38 | concrete subclasses. 39 | 40 | Returns: 41 | KnowledgeGraph: Input graph with tails generated 42 | """ 43 | raise NotImplementedError 44 | 45 | @abstractmethod 46 | def save_pretrained(self, save_path: str) -> None: 47 | """Save model as a pretrained model 48 | 49 | Args: 50 | save_path (str): Directory to save the model to. 51 | 52 | Raises: 53 | NotImplementedError: This method has to be implemented by 54 | concrete subclasses. 55 | """ 56 | raise NotImplementedError 57 | 58 | @abstractclassmethod 59 | def from_pretrained(cls, model_name_or_path: str) -> "KnowledgeModel": 60 | """Load model from a pretrained model path 61 | This method can load models either from HuggingFace by model name 62 | or from disk by model path. 63 | 64 | Args: 65 | model_name_or_path (str): HuggingFace model name or local model path to load from. 66 | 67 | Raises: 68 | NotImplementedError: This method has to be implemented by 69 | concrete subclasses. 70 | 71 | Returns: 72 | KnowledgeModel: Loaded knowledge model. 73 | """ 74 | raise NotImplementedError 75 | 76 | def evaluate( 77 | self, 78 | input_graph: KnowledgeGraph, 79 | metrics: List[str] = ["bleu", "meteor", "rouge", "cider", "bert-score"], 80 | top_k: int = 1, 81 | *args, 82 | **kwargs, 83 | ) -> dict: 84 | """Evaluate model on various metrics. 85 | Input graph should contain the reference tails, so that it can be used to score the model generations 86 | on the same input graph. Any arguments provided aside from the ones accepted by this method will be 87 | passed onto the ``KnowledgeModel.generate`` method. 
88 | 89 | Args: 90 | input_graph (KnowledgeGraph): Input graph to evaluate. Should contain the ground truth tails. 91 | metrics (List[str], optional): Metrics to compute. 92 | Defaults to ["bleu", "meteor", "rouge", "cider", "bert-score"]. 93 | top_k (int, optional): Top k generations to evaluate. Defaults to 1. 94 | *args (optional): Extra arguments for `KnowledgeModel.generate` method. 95 | **kwargs (optional): Extra keyword arguments for ``KnowledgeModel.generate`` method. 96 | 97 | Returns: 98 | dict: Dictionary of scores 99 | """ 100 | return evaluate(self, input_graph, metrics, top_k=top_k, *args, **kwargs) 101 | 102 | 103 | def evaluate( 104 | model: KnowledgeModel, 105 | input_graph: KnowledgeGraph, 106 | metrics: List[str] = ["bleu", "meteor", "rouge", "cider", "bert-score"], 107 | top_k: int = 1, 108 | *args, 109 | **kwargs, 110 | ): 111 | if not set(metrics).issubset(set(METRIC_MAP.keys())): 112 | raise ValueError( 113 | f"Invalid evaluation metrics found: {set(metrics) - set(METRIC_MAP.keys())}" 114 | ) 115 | 116 | output_graph = model.generate(input_graph=input_graph, *args, **kwargs) 117 | evaluation_data = [] 118 | 119 | for input_kg, output_kg in zip(input_graph, output_graph): 120 | assert input_kg.head == output_kg.head 121 | assert input_kg.relation == output_kg.relation 122 | assert len(input_kg.tails) > 0 123 | 124 | evaluation_data.append((output_kg, input_kg.tails)) 125 | 126 | return topk_eval(evaluation_data, metrics, k=top_k) 127 | -------------------------------------------------------------------------------- /kogito/core/processors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/core/processors/__init__.py -------------------------------------------------------------------------------- /kogito/core/processors/data/vocab_glove_100d.npy: 
import string
from abc import ABC, abstractmethod
from typing import List, Optional

from spacy.tokens import Doc
from spacy.language import Language
from spacy.lang.en.stop_words import STOP_WORDS

from kogito.core.head import KnowledgeHead, KnowledgeHeadType
from kogito.core.utils import IGNORE_WORDS

# Hoisted out of the extraction loops: this union is invariant, so there is no
# reason to rebuild it for every token (previously recomputed per token).
EXCLUDED_WORDS = STOP_WORDS.union(IGNORE_WORDS)


class KnowledgeHeadExtractor(ABC):
    """Base class for head extraction"""

    def __init__(self, name: str, lang: Optional[Language] = None) -> None:
        """Initialize a head extractor

        Args:
            name (str): Unique head extractor name
            lang (Optional[Language], optional): Spacy language pipeline to use.
                Defaults to None. If None, ``extract`` must be given a
                pre-parsed ``doc``.
        """
        self.name = name
        self.lang = lang

    @abstractmethod
    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Extract heads from text

        Args:
            text (str): Text to extract from
            doc (Optional[Doc], optional): Spacy doc to use for extraction. Defaults to None.

        Raises:
            NotImplementedError: This method has to be implemented by
                concrete subclasses.

        Returns:
            List[KnowledgeHead]: List of extracted knowledge heads.
        """
        raise NotImplementedError


class SentenceHeadExtractor(KnowledgeHeadExtractor):
    """Extracts sentences as heads from text"""

    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Return one SENTENCE-typed head per non-empty sentence of ``text``."""
        if not doc:
            doc = self.lang(text)

        heads = []

        for sentence in doc.sents:
            if sentence.text.strip():
                heads.append(
                    KnowledgeHead(
                        text=sentence.text,
                        type=KnowledgeHeadType.SENTENCE,
                        entity=sentence,
                    )
                )

        return heads


class NounPhraseHeadExtractor(KnowledgeHeadExtractor):
    """Extracts noun phrases as heads from text"""

    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Return NOUN_PHRASE-typed heads for single nouns and noun chunks,
        de-duplicated on their cleaned (stop-word- and punctuation-free) text.
        """
        if not doc:
            doc = self.lang(text)

        heads = []
        head_texts = set()

        for token in doc:
            if (
                token.text.strip().lower() not in EXCLUDED_WORDS
                and token.pos_ == "NOUN"
            ):
                token_text = token.text.strip(string.punctuation + " ")
                if token_text not in head_texts and len(token_text) > 1:
                    head_texts.add(token_text)
                    # Store the same cleaned text used as the dedup key
                    # (previously the un-cleaned token text was stored while
                    # the cleaned text was used for de-duplication).
                    heads.append(
                        KnowledgeHead(
                            text=token_text,
                            type=KnowledgeHeadType.NOUN_PHRASE,
                            entity=token,
                        )
                    )

        for phrase in doc.noun_chunks:
            clean_phrase = []
            # Re-tokenize the chunk so stop words inside it can be dropped.
            phrase_doc = self.lang(phrase.text)

            for token in phrase_doc:
                if token.text.strip().lower() not in EXCLUDED_WORDS:
                    clean_phrase.append(token.text)

            clean_text = " ".join(clean_phrase).strip(string.punctuation + " ")

            if clean_text and clean_text not in head_texts and len(clean_text) > 1:
                head_texts.add(clean_text)
                heads.append(
                    KnowledgeHead(
                        text=clean_text,
                        type=KnowledgeHeadType.NOUN_PHRASE,
                        entity=phrase,
                    )
                )

        return heads


class VerbPhraseHeadExtractor(KnowledgeHeadExtractor):
    """Extracts verb phrases as heads from text"""

    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Return VERB_PHRASE-typed heads: an infinitive ("to <lemma>") per
        verb, plus "<lemma> <object>" for attr/dobj children.
        """
        if not doc:
            doc = self.lang(text)

        heads = []
        head_texts = set()

        for token in doc:
            if token.pos_ == "VERB":
                verb_text = f"to {token.lemma_}"

                if verb_text not in head_texts:
                    head_texts.add(verb_text)
                    heads.append(
                        KnowledgeHead(
                            text=verb_text,
                            type=KnowledgeHeadType.VERB_PHRASE,
                            entity=token,
                        )
                    )

                for child in token.children:
                    if child.dep_ in ("attr", "dobj"):
                        child_text = f"{token.lemma_} {child.text}"
                        if child_text not in head_texts:
                            head_texts.add(child_text)
                            heads.append(
                                KnowledgeHead(
                                    text=child_text,
                                    type=KnowledgeHeadType.VERB_PHRASE,
                                    entity=[token, child],
                                )
                            )

        return heads
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.optim import Adam
from torch import nn
import pytorch_lightning as pl
from transformers import BertTokenizer, BertModel, PretrainedConfig, PreTrainedModel

from kogito.core.processors.models.utils import Evaluator


class BERTHeadDataset(Dataset):
    """Tokenized dataset of head texts for BERT-based head classification.

    Accepts either a DataFrame with "text" and "label" columns (training) or
    a plain iterable of texts (inference; no labels).
    """

    def __init__(self, data, tokenizer_type="uncased"):
        self.tokenizer = BertTokenizer.from_pretrained(f"bert-base-{tokenizer_type}")
        # Labels exist only when training data arrives as a DataFrame.
        self.labels = (
            np.asarray(data["label"].to_list())
            if isinstance(data, pd.DataFrame)
            else None
        )
        texts = data["text"] if isinstance(data, pd.DataFrame) else data
        self.features = [
            self.tokenizer(
                text,
                padding="max_length",
                max_length=32,
                truncation=True,
                return_tensors="pt",
            )
            for text in texts
        ]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        if self.labels is not None:
            return self.features[idx], self.labels[idx]
        return self.features[idx]


class BERTConfig(PretrainedConfig):
    """Hyperparameter container for :class:`BERTClassifier`."""

    def __init__(
        self,
        num_classes=3,
        dropout=0.5,
        learning_rate=1e-4,
        freeze_emb=False,
        model_case="uncased",
        **kwargs,
    ):
        self.num_classes = num_classes
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.freeze_emb = freeze_emb
        self.model_case = model_case
        super().__init__(**kwargs)


class BERTClassifier(PreTrainedModel, Evaluator, pl.LightningModule):
    """BERT-based head classifier trained with BCEWithLogitsLoss over
    ``num_classes`` logits (sigmoid scores at prediction time)."""

    config_class = BERTConfig

    def __init__(self, config: BERTConfig):
        super().__init__(config)
        self.bert = BertModel.from_pretrained(f"bert-base-{config.model_case}")
        self.dropout = nn.Dropout(config.dropout)
        self.linear = nn.Linear(768, config.num_classes)

        if config.freeze_emb:
            # Frozen encoder: only the linear head trains, so skip dropout.
            for parameter in self.bert.parameters():
                parameter.requires_grad = False
            self.classifier = nn.Sequential(self.linear)
        else:
            self.classifier = nn.Sequential(self.dropout, self.linear)

        self.criterion = nn.BCEWithLogitsLoss()
        self.learning_rate = config.learning_rate

        self.save_hyperparameters(config.to_dict(), ignore="config")

    def forward(self, input_ids, mask):
        # Second return value of BertModel is the pooled [CLS] representation.
        _, outputs = self.bert(
            input_ids=input_ids, attention_mask=mask, return_dict=False
        )
        outputs = self.classifier(outputs)
        return outputs

    def _shared_step(self, batch, split):
        """Common train/val/test logic: forward, loss, metric logging.

        Previously duplicated verbatim across the three *_step methods.
        """
        X, y = batch
        mask = X["attention_mask"]
        input_ids = X["input_ids"].squeeze(1)
        outputs = self.forward(input_ids, mask)
        loss = self.criterion(outputs, y.float())
        preds = torch.sigmoid(outputs)
        self.log(f"{split}_loss", loss, on_epoch=True)
        self.log_metrics(preds, y, type=split)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        X = batch
        mask = X["attention_mask"]
        input_ids = X["input_ids"].squeeze(1)
        outputs = self.forward(input_ids, mask)
        preds = torch.sigmoid(outputs)
        return preds

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from torch.optim import Adam
from torch import nn
import pytorch_lightning as pl
from transformers import (
    DistilBertTokenizer,
    DistilBertModel,
    PretrainedConfig,
    PreTrainedModel,
)

from kogito.core.processors.models.utils import Evaluator


class DistilBERTHeadDataset(Dataset):
    """Tokenized dataset of head texts for DistilBERT head classification.

    Accepts either a DataFrame with "text" and "label" columns (training) or
    a plain iterable of texts (inference; no labels).
    """

    def __init__(self, data, tokenizer_type="uncased"):
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            f"distilbert-base-{tokenizer_type}"
        )
        # Labels exist only when training data arrives as a DataFrame.
        self.labels = (
            np.asarray(data["label"].to_list())
            if isinstance(data, pd.DataFrame)
            else None
        )
        texts = data["text"] if isinstance(data, pd.DataFrame) else data
        self.features = [
            self.tokenizer(
                text,
                padding="max_length",
                max_length=32,
                truncation=True,
                return_tensors="pt",
            )
            for text in texts
        ]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        if self.labels is not None:
            return self.features[idx], self.labels[idx]
        return self.features[idx]


class DistilBERTConfig(PretrainedConfig):
    """Hyperparameter container for :class:`DistilBERTClassifier`."""

    def __init__(
        self,
        num_classes=3,
        dropout=0.5,
        learning_rate=1e-4,
        freeze_emb=False,
        model_case="uncased",
        **kwargs,
    ):
        self.num_classes = num_classes
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.freeze_emb = freeze_emb
        self.model_case = model_case
        super().__init__(**kwargs)


class DistilBERTClassifier(PreTrainedModel, Evaluator, pl.LightningModule):
    """DistilBERT-based head classifier trained with BCEWithLogitsLoss over
    ``num_classes`` logits (sigmoid scores at prediction time)."""

    config_class = DistilBERTConfig

    def __init__(self, config: DistilBERTConfig):
        super().__init__(config)
        self.distilbert = DistilBertModel.from_pretrained(
            f"distilbert-base-{config.model_case}"
        )
        self.dropout = nn.Dropout(config.dropout)
        self.linear = nn.Linear(768, config.num_classes)

        if config.freeze_emb:
            # Frozen encoder: only the linear head trains, so skip dropout.
            for parameter in self.distilbert.parameters():
                parameter.requires_grad = False
            self.classifier = nn.Sequential(self.linear)
        else:
            self.classifier = nn.Sequential(self.dropout, self.linear)

        self.criterion = nn.BCEWithLogitsLoss()
        self.learning_rate = config.learning_rate

        self.save_hyperparameters(config.to_dict(), ignore="config")

    def forward(self, input_ids, mask):
        outputs = self.distilbert(
            input_ids=input_ids, attention_mask=mask, return_dict=False
        )
        # DistilBERT has no pooler; use the [CLS] token's hidden state.
        outputs = self.classifier(outputs[0][:, 0, :])
        return outputs

    def _shared_step(self, batch, split):
        """Common train/val/test logic: forward, loss, metric logging.

        Previously duplicated verbatim across the three *_step methods.
        """
        X, y = batch
        mask = X["attention_mask"]
        input_ids = X["input_ids"].squeeze(1)
        outputs = self.forward(input_ids, mask)
        loss = self.criterion(outputs, y.float())
        preds = torch.sigmoid(outputs)
        self.log(f"{split}_loss", loss, on_epoch=True)
        self.log_metrics(preds, y, type=split)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        return self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        X = batch
        mask = X["attention_mask"]
        input_ids = X["input_ids"].squeeze(1)
        outputs = self.forward(input_ids, mask)
        preds = torch.sigmoid(outputs)
        return preds

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)
import torchmetrics
import numpy as np
import spacy


def text_to_embedding(text, vocab, embedding_matrix, pooling="max", lang=None):
    """Embed ``text`` by pooling the embedding vectors of its in-vocab tokens.

    Args:
        text: Text to embed.
        vocab: Mapping from token text to row index in ``embedding_matrix``.
        embedding_matrix: Array of per-token embedding vectors.
        pooling: "max" for element-wise max pooling; anything else means mean.
        lang: Optional spacy pipeline. NOTE: loading a pipeline per call is
            expensive — pass ``lang`` in when embedding many texts.

    Returns:
        np.ndarray of float32, or None when no token of ``text`` is in
        ``vocab`` (implicit None kept for backward compatibility).
    """
    if not lang:
        lang = spacy.load("en_core_web_sm")

    doc = lang(text)
    vectors = [
        embedding_matrix[vocab[token.text]] for token in doc if token.text in vocab
    ]

    if vectors:
        if pooling == "max":
            return np.amax(np.array(vectors, dtype=np.float32), axis=0)
        return np.mean(vectors, axis=0, dtype=np.float32)


def _build_metrics(num_classes=3):
    """Build the train/val/test classification metrics dict.

    Replaces a ~60-line hand-written dict literal; keys and insertion order
    match the original exactly: per split, "accuracy" first, then
    precision/recall/f1 for weighted (no suffix), micro ("_micro"),
    macro ("_macro") and per-class ("_class", average="none") variants.
    """
    metric_types = (
        ("precision", torchmetrics.Precision),
        ("recall", torchmetrics.Recall),
        ("f1", torchmetrics.F1Score),
    )
    averages = (
        ("weighted", ""),
        ("micro", "_micro"),
        ("macro", "_macro"),
        ("none", "_class"),
    )
    metrics = {}
    for split in ("train", "val", "test"):
        metrics[f"{split}_accuracy"] = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        for average, suffix in averages:
            for name, metric_cls in metric_types:
                metrics[f"{split}_{name}{suffix}"] = metric_cls(
                    task="multiclass", num_classes=num_classes, average=average
                )
    return metrics


class Evaluator:
    """Mixin that owns a set of classification metrics and logs them via the
    host LightningModule's ``self.log``."""

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        # num_classes=3 matches the classifier configs in this package.
        self.metrics = _build_metrics(num_classes=3)

    def log_metrics(self, preds, y, type):
        """Update and log every metric whose key starts with ``type``
        ("train"/"val"/"test")."""
        for metric_name, metric in self.metrics.items():
            if metric_name.startswith(type):
                metric(preds.cpu(), y.cpu())
                value = metric.compute()
                if len(value.shape) > 0:
                    # Per-class metrics (average="none") return a vector;
                    # log one entry per class index.
                    for idx, val in enumerate(value):
                        self.log(
                            f"{metric_name}_{idx}", val, on_epoch=True, on_step=False
                        )
                else:
                    self.log(metric_name, value, on_epoch=True, on_step=False)
from kogito.evaluation.bert_score.score import score

# Code for BertScore reused from original implementation: https://github.com/Tiiiger/bert_score


class BertScore:
    """Corpus-level BERTScore: averages per-item F1 over all items, where each
    item's F1 is the mean score of its hypothesis against all references."""

    def __init__(self):
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res):
        """Score ``res`` (id -> [hypothesis]) against ``gts`` (id -> refs).

        Returns (mean F1 over items, list of per-item F1 scores).
        """
        assert gts.keys() == res.keys()

        candidates = []
        references = []
        boundaries = []  # running end index of each item's reference span

        for image_id in gts.keys():
            hypo = res[image_id]
            ref = gts[image_id]

            # Sanity check.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) >= 1

            # Repeat the single hypothesis once per reference so the two flat
            # lists stay pairwise aligned.
            candidates.extend([hypo[0]] * len(ref))
            references.extend(ref)
            boundaries.append(len(references))

        _, _, f_scores = score(candidates, references)

        aggreg_f1_scores = []
        start = 0
        for end in boundaries:
            # Mean F1 of this item's hypothesis over all of its references.
            aggreg_f1_scores.append(f_scores[start:end].mean().cpu().item())
            start = end

        return sum(aggreg_f1_scores) / len(aggreg_f1_scores), aggreg_f1_scores

    def method(self):
        return "Bert Score"
import torch
from collections import defaultdict
from transformers import BertTokenizer, BertModel

from kogito.evaluation.bert_score.utils import (
    get_idf_dict,
    bert_cos_score_idf,
    bert_types,
)


def score(
    cands,
    refs,
    bert="bert-base-multilingual-cased",
    num_layers=8,
    no_idf=False,
    batch_size=64,
):
    """
    BERTScore metric.
    Args:
        - :param: `cands` (list of str): candidate sentences
        - :param: `refs` (list of str): reference sentences (same length as `cands`)
        - :param: `bert` (str): bert specification (must be one of `bert_types`)
        - :param: `num_layers` (int): the layer of representation to use
        - :param: `no_idf` (bool): do not use idf weighting
        - :param: `batch_size` (int): bert score processing batch size
    Returns:
        (P, R, F1): CPU tensors with one precision/recall/F1 entry per
        candidate/reference pair.
    """
    assert len(cands) == len(refs)
    assert bert in bert_types

    tokenizer = BertTokenizer.from_pretrained(bert)
    model = BertModel.from_pretrained(bert)
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # drop unused layers (only the first `num_layers` encoder layers are kept)
    model.encoder.layer = torch.nn.ModuleList(
        [layer for layer in model.encoder.layer[:num_layers]]
    )

    if no_idf:
        # Uniform weighting: every token contributes equally.
        idf_dict = defaultdict(lambda: 1.0)
        # set idf for [SEP] and [CLS] to 0
        # NOTE(review): 101/102 are the [CLS]/[SEP] ids of BERT WordPiece
        # vocabularies — confirm they match the chosen `bert` checkpoint.
        idf_dict[101] = 0
        idf_dict[102] = 0
    else:
        # IDF weights computed from the reference corpus.
        idf_dict = get_idf_dict(refs, tokenizer)

    all_preds = bert_cos_score_idf(
        model, refs, cands, tokenizer, idf_dict, device=device, batch_size=batch_size
    )

    # Columns are (precision, recall, F1) per pair.
    P = all_preds[:, 0].cpu()
    R = all_preds[:, 1].cpu()
    F1 = all_preds[:, 2].cpu()

    return P, R, F1
Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /kogito/evaluation/bleu/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = "tylin" 2 | -------------------------------------------------------------------------------- /kogito/evaluation/bleu/bleu.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # 3 | # File Name : bleu.py 4 | # 5 | # Description : Wrapper for BLEU scorer. 
#!/usr/bin/env python
#
# File Name : bleu.py
#
# Description : Wrapper for BLEU scorer.
#
# Creation Date : 06-01-2015
# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
# Authors : Hao Fang and Tsung-Yi Lin

from kogito.evaluation.bleu.bleu_scorer import BleuScorer


class Bleu:
    """Corpus-level BLEU wrapper around :class:`BleuScorer`."""

    def __init__(self, n=4):
        # default compute Blue score up to 4
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res):
        """Score ``res`` (id -> [hypothesis]) against ``gts`` (id -> refs).

        Returns (BLEU-1..n scores, per-item score lists).
        """
        assert gts.keys() == res.keys()

        scorer = BleuScorer(n=self._n)
        for image_id in gts.keys():
            hypo = res[image_id]
            ref = gts[image_id]

            # Sanity check.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) >= 1

            scorer += (hypo[0], ref)

        # score, scores = scorer.compute_score(option='shortest')
        # "closest" reference-length matching is the standard corpus choice.
        score, scores = scorer.compute_score(option="closest", verbose=0)
        # score, scores = scorer.compute_score(option='average', verbose=0)

        # return (bleu, bleu_info)
        return score, scores

    def method(self):
        return "Bleu"
# Filename: cider.py
#
# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
#
# Creation Date: Sun Feb 8 14:16:54 2015
#
# Authors: Ramakrishna Vedantam and Tsung-Yi Lin

from kogito.evaluation.cider.cider_scorer import CiderScorer


class Cider:
    """
    Main Class to compute the CIDEr metric

    """

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # NOTE(review): `test` and `refs` are accepted but never used; kept
        # only for signature compatibility with existing callers.
        # set cider to sum over 1 to 4-grams
        self._n = n
        # set the standard deviation parameter for gaussian penalty
        self._sigma = sigma

    def compute_score(self, gts, res):
        """
        Main function to compute CIDEr score
        :param gts (dict): ground-truth references, id -> list of sentences
        :param res (dict): candidate results, id -> single-element list with
            the hypothesis sentence
        :return: cider (float) : computed CIDEr score for the corpus, plus
            the array of per-item scores
        """

        assert gts.keys() == res.keys()
        imgIds = gts.keys()

        cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)

        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

            cider_scorer += (hypo[0], ref)

        (score, scores) = cider_scorer.compute_score()

        return score, scores

    def method(self):
        return "CIDEr"
from typing import List, Tuple

from kogito.evaluation.bleu.bleu import Bleu
from kogito.evaluation.meteor.meteor import Meteor
from kogito.evaluation.rouge.rouge import Rouge
from kogito.evaluation.cider.cider import Cider
from kogito.evaluation.bert_score.bert_score import BertScore
from kogito.core.knowledge import Knowledge


# Metric name -> (scorer instance, label(s) under which scores are reported).
METRIC_MAP = {
    "bleu": (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
    "meteor": (Meteor(), "METEOR"),
    "rouge": (Rouge(), "ROUGE_L"),
    "cider": (Cider(), "CIDEr"),
    "bert-score": (BertScore(), "Bert Score"),
}


class Evaluator:
    """Runs a set of scorers over aligned ground-truth/result dicts."""

    def __init__(self, gts, res, metrics):
        """
        Args:
            gts: id -> list of reference sentences.
            res: id -> single-element list with the hypothesis sentence.
            metrics: metric names; must be keys of METRIC_MAP.
        """
        self.gts = gts
        self.res = res
        self.scorers = [METRIC_MAP[metric] for metric in metrics]

    def evaluate(self):
        """Return a dict mapping each metric label to its corpus-level score."""
        score_dict = {}

        for scorer, method in self.scorers:
            score, _ = scorer.compute_score(self.gts, self.res)
            if isinstance(method, list):
                # Multi-score metrics (BLEU) report one value per label.
                # NOTE(review): BLEU scores are stringified while other
                # metrics stay numeric; preserved for backward compatibility.
                for sc, m in zip(score, method):
                    score_dict[m] = str(sc)
            else:
                score_dict[method] = score

        return score_dict


def topk_eval(data: List[Tuple[Knowledge, List[str]]], metrics, k=1):
    """Evaluate the top-k generated tails of each knowledge tuple.

    Args:
        data: (generated knowledge, reference tails) pairs.
        metrics: metric names (keys of METRIC_MAP).
        k: number of top generations per tuple to score. Defaults to 1.

    Returns:
        dict: metric label -> score.
    """
    topk_gts = {}
    topk_res = {}

    # Key each (tuple, generation) pair uniquely so the scorers treat every
    # generation as an independent item.
    for i, (kg, reference) in enumerate(data):
        for j, generation in enumerate(kg.tails[:k]):
            key = f"{i}_{j}"
            topk_gts[key] = reference
            topk_res[key] = [generation]

    evaluator = Evaluator(topk_gts, topk_res, metrics)
    return evaluator.evaluate()
def my_lcs(string, sub):
    """
    Return the length of the longest common subsequence of two token lists.

    :param string: list of str : tokens from a string split using whitespace
    :param sub: list of str : shorter string, also split using whitespace
    :returns: int : length of the longest common subsequence (length only,
        not the subsequence itself)
    """
    # Orient the table so that `sub` is always the shorter sequence.
    if len(string) < len(sub):
        sub, string = string, sub

    # DP table: lengths[i][j] == LCS length of string[:i] and sub[:j].
    rows, cols = len(string) + 1, len(sub) + 1
    lengths = [[0] * cols for _ in range(rows)]

    for j in range(1, cols):
        for i in range(1, rows):
            if string[i - 1] == sub[j - 1]:
                lengths[i][j] = lengths[i - 1][j - 1] + 1
            else:
                lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])

    return lengths[-1][-1]


class Rouge:
    """
    Computes the ROUGE-L score for a set of candidate sentences
    (Lin and Hovy, 2004), MS COCO evaluation style.
    """

    def __init__(self):
        # F-measure weighting constant (favors recall), per the reference
        # implementation's discussion with Hovy.
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute ROUGE-L for one candidate against its reference set.

        :param candidate: list holding exactly one candidate sentence (str)
        :param refs: non-empty list of reference sentences (str)
        :returns: float : ROUGE-L F-score of the candidate
        """
        assert len(candidate) == 1
        assert len(refs) > 0

        # Whitespace tokenization, matching the reference implementation.
        token_c = candidate[0].split(" ")

        precisions = []
        recalls = []
        for reference in refs:
            token_r = reference.split(" ")
            lcs = my_lcs(token_r, token_c)
            precisions.append(lcs / float(len(token_c)))
            recalls.append(lcs / float(len(token_r)))

        prec_max = max(precisions)
        rec_max = max(recalls)

        if prec_max == 0 or rec_max == 0:
            return 0.0

        beta_sq = self.beta**2
        return ((1 + beta_sq) * prec_max * rec_max) / float(
            rec_max + beta_sq * prec_max
        )

    def compute_score(self, gts, res):
        """
        Compute the mean ROUGE-L over a keyed set of hypotheses and references.

        :param gts: dict : key -> list of reference sentences
        :param res: dict : key -> single-element list with the hypothesis
        :returns: (float, np.ndarray) : mean score and per-key score array
        """
        assert gts.keys() == res.keys()

        scores = []
        for key in gts.keys():
            hypo = res[key]
            ref = gts[key]

            scores.append(self.calc_score(hypo, ref))

            # Sanity check on input shapes.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

        score_arr = np.array(scores)
        return np.mean(score_arr), score_arr

    def method(self):
        return "Rouge"
class GPT2Finetuner(pl.LightningModule):
    """Lightning wrapper for fine-tuning a GPT-2 model as a causal LM.

    The input ids double as labels, i.e. the standard language-modeling
    objective; training and validation share the same loss computation.
    """

    def __init__(self, model, learning_rate=1e-5) -> None:
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, input_ids, mask):
        # Labels == inputs: causal-LM objective.
        return self.model(input_ids=input_ids, attention_mask=mask, labels=input_ids)

    def _lm_loss(self, batch):
        # Shared step: run the model with inputs as labels and return the loss
        # (first element of the model output tuple).
        ids = batch["source_ids"]
        mask = batch["source_mask"]
        return self.model(input_ids=ids, attention_mask=mask, labels=ids)[0]

    def training_step(self, batch, batch_idx):
        loss = self._lm_loss(batch)
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._lm_loss(batch)
        self.log("val_loss", loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
class GPT2Zeroshot(KnowledgeModel):
    """Zeroshot knowledge model based on GPT-2"""

    def __init__(self, gpt2_model: str = "gpt2") -> None:
        """Initialize GPT-2 model

        Args:
            gpt2_model (str, optional): HuggingFace model name for gpt2.
                Defaults to "gpt2".
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model)
        self.model = GPT2LMHeadModel.from_pretrained(gpt2_model)
        self.model.to(device)

    def train(self):
        """This model is inference-only; training always raises."""
        raise ValueError("GPT-2 Zeroshot model is not trainable")

    def save_pretrained(self, save_path):
        """Save both model weights and tokenizer files to ``save_path``."""
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str = "gpt2"):
        """Construct from a HuggingFace model name or a local checkpoint path."""
        return cls(model_name_or_path)

    def generate(
        self, input_graph: KnowledgeGraph, seed: int = 42, **kwargs
    ) -> KnowledgeGraph:
        """Generate inferences from GPT2 model

        Args:
            input_graph (KnowledgeGraph): Input dataset
            seed (int, optional): Random seed. Defaults to 42.
            kwargs: Additional arguments to pass to the model.generate() function

        Returns:
            KnowledgeGraph: Completed knowledge graph
        """
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.backends.cudnn.deterministic = True

        # Sampling defaults; caller-supplied values take precedence.
        kwargs.setdefault("top_k", 1)
        kwargs.setdefault("top_p", 0.9)
        kwargs.setdefault("num_return_sequences", 3)
        kwargs.setdefault("num_beams", 3)
        kwargs.setdefault("temperature", 0.7)
        kwargs.setdefault("repetition_penalty", 1.2)
        kwargs.setdefault("do_sample", True)

        # BUGFIX: "max_length" must be *removed* from kwargs before the call
        # below. The original code left it in, so model.generate() received
        # max_length twice (explicitly and via **kwargs), which raises
        # "TypeError: got multiple values for keyword argument 'max_length'".
        # Here it is interpreted as the budget of new tokens beyond the prompt.
        max_new_tokens = kwargs.pop("max_length", 32)

        outputs = []
        for input_kg in input_graph:
            prompt = input_kg.to_prompt()
            input_ids = self.tokenizer.encode(
                prompt, add_special_tokens=False, return_tensors="pt"
            )
            input_length = input_ids.size(1)
            generations = self.model.generate(
                input_ids=input_ids.to(device),
                max_length=input_length + max_new_tokens,
                eos_token_id=198,  # newline token in the GPT-2 vocabulary; stop at line break
                **kwargs
            )

            # Collapse any extra leading dimension to (num_sequences, seq_len).
            if len(generations.shape) > 2:
                generations.squeeze_()

            text_generations = []
            for gen in generations:
                gen = gen.tolist()
                # Decode only the continuation, not the echoed prompt.
                text = self.tokenizer.decode(
                    gen[input_length:], clean_up_tokenization_spaces=True
                )
                text_generations.append(text.strip())

            output_kg = input_kg.copy()
            output_kg.tails = text_generations
            outputs.append(output_kg)

        return KnowledgeGraph(outputs)
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/models/gpt3/__init__.py -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "kogito" 3 | version = "0.6.3" 4 | description = "A Python NLP Commonsense Knowledge Inference Toolkit" 5 | authors = ["Mete Ismayil ", "Antoine Bosselut "] 6 | license = "Apache License 2.0" 7 | readme = "README.md" 8 | homepage = "https://github.com/epfl-nlp/kogito" 9 | repository = "https://github.com/epfl-nlp/kogito" 10 | documentation = "https://github.com/epfl-nlp/kogito" 11 | keywords = [ 12 | "natural language processing", 13 | "nlp", 14 | "natural language understanding", 15 | "commonsense reasoning", 16 | "commonsense inference", 17 | "knowledge inference" 18 | ] 19 | classifiers = [] 20 | 21 | [tool.poetry.dependencies] 22 | python = ">=3.8,<3.11" 23 | sacrebleu = "^2.0.0" 24 | rouge-score = "^0.0.4" 25 | pytorch-lightning = "~1.5.10" 26 | pandas = "^1.3.5" 27 | spacy = "^3.2.3" 28 | inflect = "^5.3.0" 29 | transformers = "^4.15.0" 30 | wandb = "^0.12.9" 31 | torch = "^1.10.1" 32 | openai = "^0.18.1" 33 | bert-score = "^0.3.11" 34 | sentencepiece = "^0.1.97" 35 | grpcio = "~1.51.1" 36 | 37 | [tool.poetry.dev-dependencies] 38 | pytest = "*" 39 | flake8 = "*" 40 | black = "~22.12.0" 41 | mypy = "*" 42 | sphinx = "*" 43 | insegel = "*" 44 | 45 | [build-system] 46 | requires = ["poetry-core>=1.0.0"] 47 | build-backend = "poetry.core.masonry.api" 48 | --------------------------------------------------------------------------------