├── .flake8
├── .gitattributes
├── .gitignore
├── .mypy.ini
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── app
├── client
│ ├── .dockerignore
│ ├── .env
│ ├── .gitignore
│ ├── Dockerfile
│ ├── README.md
│ ├── package-lock.json
│ ├── package.json
│ ├── public
│ │ ├── CNAME
│ │ ├── android-chrome-192x192.png
│ │ ├── android-chrome-512x512.png
│ │ ├── apple-touch-icon.png
│ │ ├── favicon-16x16.png
│ │ ├── favicon-32x32.png
│ │ ├── favicon.ico
│ │ ├── index.html
│ │ ├── lab-logo.png
│ │ ├── manifest.json
│ │ └── robots.txt
│ └── src
│ │ ├── App.css
│ │ ├── App.js
│ │ ├── App.test.js
│ │ ├── api.js
│ │ ├── index.css
│ │ ├── index.js
│ │ ├── relations.js
│ │ ├── reportWebVitals.js
│ │ └── setupTests.js
├── docker-compose.yml
├── proxy-nginx.conf
└── server
│ ├── .dockerignore
│ ├── Dockerfile
│ ├── inference.py
│ ├── requirements.txt
│ └── server.py
├── docker
├── Dockerfile
├── build.sh
├── dev.Dockerfile
├── entrypoint.sh
├── push.sh
└── train.py
├── docs
├── Makefile
├── _static
│ └── css
│ │ └── custom.css
├── api.rst
├── conf.py
├── getting_started.rst
├── index.rst
├── make.bat
├── requirements.txt
└── userguide.rst
├── eacl2023
├── kogito-poster-eacl2023.pdf
└── kogito-presentation-eacl2023.pdf
├── examples
├── data
│ ├── atomic2020
│ │ ├── LICENSE
│ │ ├── README.md
│ │ ├── sample_dev.tsv
│ │ ├── sample_test.tsv
│ │ └── sample_train.tsv
│ ├── cc_news_samples.txt
│ ├── dialog_samples.txt
│ └── story_samples.txt
├── demo.ipynb
├── eacl.ipynb
├── evaluate_comet_bart.py
├── evaluate_comet_gpt2.py
├── results
│ ├── cc_news
│ │ ├── cc_news_samples_1.json
│ │ ├── cc_news_samples_10.json
│ │ ├── cc_news_samples_11.json
│ │ ├── cc_news_samples_12.json
│ │ ├── cc_news_samples_13.json
│ │ ├── cc_news_samples_14.json
│ │ ├── cc_news_samples_15.json
│ │ ├── cc_news_samples_16.json
│ │ ├── cc_news_samples_17.json
│ │ ├── cc_news_samples_18.json
│ │ ├── cc_news_samples_19.json
│ │ ├── cc_news_samples_2.json
│ │ ├── cc_news_samples_20.json
│ │ ├── cc_news_samples_21.json
│ │ ├── cc_news_samples_22.json
│ │ ├── cc_news_samples_23.json
│ │ ├── cc_news_samples_24.json
│ │ ├── cc_news_samples_25.json
│ │ ├── cc_news_samples_26.json
│ │ ├── cc_news_samples_27.json
│ │ ├── cc_news_samples_28.json
│ │ ├── cc_news_samples_29.json
│ │ ├── cc_news_samples_3.json
│ │ ├── cc_news_samples_30.json
│ │ ├── cc_news_samples_31.json
│ │ ├── cc_news_samples_32.json
│ │ ├── cc_news_samples_33.json
│ │ ├── cc_news_samples_34.json
│ │ ├── cc_news_samples_35.json
│ │ ├── cc_news_samples_36.json
│ │ ├── cc_news_samples_37.json
│ │ ├── cc_news_samples_38.json
│ │ ├── cc_news_samples_39.json
│ │ ├── cc_news_samples_4.json
│ │ ├── cc_news_samples_40.json
│ │ ├── cc_news_samples_41.json
│ │ ├── cc_news_samples_42.json
│ │ ├── cc_news_samples_43.json
│ │ ├── cc_news_samples_44.json
│ │ ├── cc_news_samples_45.json
│ │ ├── cc_news_samples_46.json
│ │ ├── cc_news_samples_47.json
│ │ ├── cc_news_samples_48.json
│ │ ├── cc_news_samples_49.json
│ │ ├── cc_news_samples_5.json
│ │ ├── cc_news_samples_50.json
│ │ ├── cc_news_samples_51.json
│ │ ├── cc_news_samples_52.json
│ │ ├── cc_news_samples_53.json
│ │ ├── cc_news_samples_54.json
│ │ ├── cc_news_samples_55.json
│ │ ├── cc_news_samples_56.json
│ │ ├── cc_news_samples_57.json
│ │ ├── cc_news_samples_58.json
│ │ ├── cc_news_samples_59.json
│ │ ├── cc_news_samples_6.json
│ │ ├── cc_news_samples_60.json
│ │ ├── cc_news_samples_61.json
│ │ ├── cc_news_samples_62.json
│ │ ├── cc_news_samples_63.json
│ │ ├── cc_news_samples_64.json
│ │ ├── cc_news_samples_65.json
│ │ ├── cc_news_samples_66.json
│ │ ├── cc_news_samples_67.json
│ │ ├── cc_news_samples_7.json
│ │ ├── cc_news_samples_8.json
│ │ └── cc_news_samples_9.json
│ ├── cometbart_results_test_atomic2020.json
│ ├── daily_dialog
│ │ ├── dialog_samples_1.json
│ │ ├── dialog_samples_10.json
│ │ ├── dialog_samples_11.json
│ │ ├── dialog_samples_12.json
│ │ ├── dialog_samples_13.json
│ │ ├── dialog_samples_14.json
│ │ ├── dialog_samples_15.json
│ │ ├── dialog_samples_16.json
│ │ ├── dialog_samples_17.json
│ │ ├── dialog_samples_18.json
│ │ ├── dialog_samples_19.json
│ │ ├── dialog_samples_2.json
│ │ ├── dialog_samples_20.json
│ │ ├── dialog_samples_21.json
│ │ ├── dialog_samples_22.json
│ │ ├── dialog_samples_23.json
│ │ ├── dialog_samples_24.json
│ │ ├── dialog_samples_25.json
│ │ ├── dialog_samples_26.json
│ │ ├── dialog_samples_27.json
│ │ ├── dialog_samples_28.json
│ │ ├── dialog_samples_29.json
│ │ ├── dialog_samples_3.json
│ │ ├── dialog_samples_30.json
│ │ ├── dialog_samples_31.json
│ │ ├── dialog_samples_32.json
│ │ ├── dialog_samples_33.json
│ │ ├── dialog_samples_34.json
│ │ ├── dialog_samples_35.json
│ │ ├── dialog_samples_36.json
│ │ ├── dialog_samples_37.json
│ │ ├── dialog_samples_38.json
│ │ ├── dialog_samples_39.json
│ │ ├── dialog_samples_4.json
│ │ ├── dialog_samples_40.json
│ │ ├── dialog_samples_41.json
│ │ ├── dialog_samples_42.json
│ │ ├── dialog_samples_43.json
│ │ ├── dialog_samples_44.json
│ │ ├── dialog_samples_45.json
│ │ ├── dialog_samples_46.json
│ │ ├── dialog_samples_47.json
│ │ ├── dialog_samples_48.json
│ │ ├── dialog_samples_49.json
│ │ ├── dialog_samples_5.json
│ │ ├── dialog_samples_50.json
│ │ ├── dialog_samples_6.json
│ │ ├── dialog_samples_7.json
│ │ ├── dialog_samples_8.json
│ │ └── dialog_samples_9.json
│ ├── gpt3_prompt_2c976a77.txt
│ ├── gpt3_prompt_56fcb840.txt
│ ├── gpt3_prompt_99e17fae.txt
│ ├── kgraph.json
│ ├── kgraph_dry_run.json
│ ├── kgraph_dry_run_2.json
│ ├── kgraph_full.json
│ ├── kgraph_gpt3.json
│ ├── kgraph_gpt3_custom_relation.json
│ ├── kgraph_gpt3_dry_run.json
│ ├── kgraph_manual.json
│ ├── kgraph_manual_heads.json
│ ├── kgraph_modelbased_relations.json
│ ├── kgraph_modelbased_relations_bert.json
│ ├── kgraph_modelbased_relations_dbert.json
│ ├── kgraph_modelbased_relations_swem.json
│ ├── kgraph_no_head_extract.json
│ ├── kgraph_no_match_subset.json
│ ├── kgraph_rel_subset.json
│ ├── kgraph_with_context.json
│ ├── kgraph_without_context.json
│ ├── roc_stories
│ │ ├── story_samples_1.json
│ │ ├── story_samples_10.json
│ │ ├── story_samples_11.json
│ │ ├── story_samples_12.json
│ │ ├── story_samples_13.json
│ │ ├── story_samples_14.json
│ │ ├── story_samples_15.json
│ │ ├── story_samples_16.json
│ │ ├── story_samples_17.json
│ │ ├── story_samples_18.json
│ │ ├── story_samples_19.json
│ │ ├── story_samples_2.json
│ │ ├── story_samples_20.json
│ │ ├── story_samples_21.json
│ │ ├── story_samples_22.json
│ │ ├── story_samples_23.json
│ │ ├── story_samples_24.json
│ │ ├── story_samples_25.json
│ │ ├── story_samples_26.json
│ │ ├── story_samples_27.json
│ │ ├── story_samples_28.json
│ │ ├── story_samples_29.json
│ │ ├── story_samples_3.json
│ │ ├── story_samples_30.json
│ │ ├── story_samples_31.json
│ │ ├── story_samples_32.json
│ │ ├── story_samples_33.json
│ │ ├── story_samples_34.json
│ │ ├── story_samples_35.json
│ │ ├── story_samples_36.json
│ │ ├── story_samples_37.json
│ │ ├── story_samples_38.json
│ │ ├── story_samples_39.json
│ │ ├── story_samples_4.json
│ │ ├── story_samples_40.json
│ │ ├── story_samples_41.json
│ │ ├── story_samples_42.json
│ │ ├── story_samples_43.json
│ │ ├── story_samples_44.json
│ │ ├── story_samples_45.json
│ │ ├── story_samples_46.json
│ │ ├── story_samples_47.json
│ │ ├── story_samples_48.json
│ │ ├── story_samples_49.json
│ │ ├── story_samples_5.json
│ │ ├── story_samples_50.json
│ │ ├── story_samples_6.json
│ │ ├── story_samples_7.json
│ │ ├── story_samples_8.json
│ │ └── story_samples_9.json
│ ├── test_atomic2020_res_cometgpt2.json
│ ├── test_atomic2020_res_cometgpt2_sample.json
│ └── test_atomic2020_res_zeroshot_sample.jsonl
├── sample_graph.jsonl
├── sample_graph.tsv
├── sample_graph2.tsv
├── sample_graph3.jsonl
├── sample_linking_graph.csv
├── snapshots
│ ├── custom-relation.png
│ ├── kg-concepts.png
│ ├── pipeline.png
│ └── quickstart.png
├── test.ipynb
├── test.py
├── test_atomic2020.json
├── test_atomic2020_sample.json
├── test_comet_bart.py
├── test_comet_gpt2.py
├── test_gpt2.json
├── test_zeroshot.py
├── train_comet_bart.py
└── train_comet_gpt2.py
├── experiments
├── fact_linking
│ ├── data
│ │ ├── roc_full_pipeline_result_samples.json
│ │ ├── roc_full_pipeline_results.json
│ │ ├── roc_nlu_test.csv
│ │ ├── roc_results.json
│ │ └── roc_samples.csv
│ ├── evaluation.ipynb
│ └── preprocess.ipynb
└── relation_modeling
│ ├── analysis.ipynb
│ ├── base.ipynb
│ ├── data
│ └── atomic_split
│ │ ├── n1
│ │ ├── test_n1.csv
│ │ └── train_n1.csv
│ │ ├── n3
│ │ ├── test_n3.csv
│ │ └── train_n3.csv
│ │ └── n5
│ │ ├── test_n5.csv
│ │ └── train_n5.csv
│ ├── heuristic.ipynb
│ ├── ood.ipynb
│ ├── prediction.ipynb
│ ├── relation_modeling_utils.py
│ ├── spacy
│ ├── base_config.cfg
│ ├── config.cfg
│ └── relation_model_spacy.ipynb
│ ├── swem
│ ├── relation_modeling_swem.ipynb
│ ├── relation_modeling_swem.py
│ ├── relation_modeling_swem_finetune.ipynb
│ ├── relation_modeling_swem_multi.ipynb
│ ├── relation_modeling_swem_multi_nn.ipynb
│ ├── relation_modeling_swem_multi_nn_no_pad.ipynb
│ ├── relation_modeling_swem_multi_no_pad.ipynb
│ ├── relation_modeling_swem_multi_no_pad_max.ipynb
│ └── relation_modeling_swem_no_pad.ipynb
│ └── transformer
│ ├── relation_modeling_bert.ipynb
│ ├── relation_modeling_bert.py
│ ├── relation_modeling_distillbert.ipynb
│ ├── relation_modeling_distillbert.py
│ └── relation_modeling_transformer.ipynb
├── kogito
├── __init__.py
├── core
│ ├── __init__.py
│ ├── callbacks.py
│ ├── dataset.py
│ ├── head.py
│ ├── knowledge.py
│ ├── linker.py
│ ├── model.py
│ ├── processors
│ │ ├── __init__.py
│ │ ├── data
│ │ │ └── vocab_glove_100d.npy
│ │ ├── head.py
│ │ ├── models
│ │ │ ├── __init__.py
│ │ │ ├── bert.py
│ │ │ ├── distilbert.py
│ │ │ ├── swem.py
│ │ │ └── utils.py
│ │ └── relation.py
│ ├── relation.py
│ └── utils.py
├── evaluation
│ ├── LICENSE
│ ├── __init__.py
│ ├── bert_score
│ │ ├── __init__.py
│ │ ├── bert_score.py
│ │ ├── score.py
│ │ └── utils.py
│ ├── bleu
│ │ ├── .gitignore
│ │ ├── LICENSE
│ │ ├── __init__.py
│ │ ├── bleu.py
│ │ └── bleu_scorer.py
│ ├── cider
│ │ ├── __init__.py
│ │ ├── cider.py
│ │ └── cider_scorer.py
│ ├── eval.py
│ ├── meteor
│ │ ├── __init__.py
│ │ └── meteor.py
│ └── rouge
│ │ ├── __init__.py
│ │ └── rouge.py
├── inference.py
├── linkers
│ └── deberta.py
└── models
│ ├── __init__.py
│ ├── bart
│ ├── __init__.py
│ ├── comet.py
│ ├── config.py
│ └── utils.py
│ ├── gpt2
│ ├── __init__.py
│ ├── comet.py
│ ├── utils.py
│ └── zeroshot.py
│ └── gpt3
│ ├── __init__.py
│ └── zeroshot.py
├── poetry.lock
└── pyproject.toml
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E203, W503
3 | max-line-length = 120
4 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | experiments/**/*.ipynb linguist-vendored
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .venv
106 | env/
107 | venv/
108 | ENV/
109 | env.bak/
110 | venv.bak/
111 |
112 | # Spyder project settings
113 | .spyderproject
114 | .spyproject
115 |
116 | # Rope project settings
117 | .ropeproject
118 |
119 | # mkdocs documentation
120 | /site
121 |
122 | # mypy
123 | .mypy_cache/
124 | .dmypy.json
125 | dmypy.json
126 |
127 | # Pyre type checker
128 | .pyre/
129 |
130 | # MacOS
131 | .DS_Store
132 |
133 | # kogito
134 | experiments/relation_modeling/output
135 | experiments/relation_modeling/models
136 | experiments/relation_modeling/hmodels
137 | experiments/relation_modeling/data
138 | examples/models
139 | wandb
140 | .vector_cache/
141 | lightning_logs/
142 | kogito-relation-matcher/
143 | *.bin
144 | __MACOSX
--------------------------------------------------------------------------------
/.mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | python_version = 3.7
3 |
4 | [mypy-setuptools]
5 | ignore_missing_imports = True
6 |
7 | [mypy-pytorch_lightning]
8 | ignore_missing_imports = True
9 |
10 | [mypy-pytorch_lightning.callbacks]
11 | ignore_missing_imports = True
12 |
13 | [mypy-pytorch_lightning.utilities]
14 | ignore_missing_imports = True
15 |
16 | [mypy-spacy]
17 | ignore_missing_imports = True
18 |
19 | [mypy-inflect]
20 | ignore_missing_imports = True
21 |
22 | [mypy-rouge_score]
23 | ignore_missing_imports = True
24 |
25 | [mypy-tqdm]
26 | ignore_missing_imports = True
27 |
28 | [mypy-pandas]
29 | ignore_missing_imports = True
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 |
3 | build:
4 | os: ubuntu-20.04
5 | tools:
6 | python: "3.8"
7 |
8 | python:
9 | install:
10 | - requirements: docs/requirements.txt
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # kogito
2 | A Python NLP Commonsense Knowledge Inference Toolkit
3 |
4 | System Description available here: https://arxiv.org/abs/2211.08451
5 |
6 | ## Installation
7 |
8 | ### Installation with pip
9 | **kogito** can be installed using pip.
10 |
11 | ```sh
12 | pip install kogito
13 | ```
14 |
15 | It requires a minimum ``python`` version of ``3.8``.
16 |
17 | ## Setup
18 |
19 | ### Inference
20 | **kogito** uses [spacy](https://spacy.io) under the hood for various text processing purposes, so a [spacy](https://spacy.io) language package has to be installed before running the inference module.
21 |
22 | ```sh
23 | python -m spacy download en_core_web_sm
24 | ```
25 | By default, ``CommonsenseInference`` module uses ``en_core_web_sm`` to initialize ``spacy`` pipeline, but a different language pipeline can be specified as well.
26 |
27 | ### Evaluation
28 | If you would also like to evaluate knowledge models using `METEOR` score, then you need to download the following ``nltk`` resources:
29 | ```python
30 | import nltk
31 |
32 | nltk.download("punkt")
33 | nltk.download("wordnet")
34 | nltk.download("omw-1.4")
35 | ```
36 |
37 | ## Quickstart
38 | **kogito** provides an easy interface to interact with knowledge inference or commonsense reasoning models such as [COMET](https://arxiv.org/abs/2010.05953) to generate inferences from a text input.
39 | Here is a sample usage of the library where you can initialize an inference module, a custom commonsense reasoning model, and generate a knowledge graph from text on the fly.
40 |
41 | ```python
42 | from kogito.models.bart.comet import COMETBART
43 | from kogito.inference import CommonsenseInference
44 |
45 | # Load pre-trained model from HuggingFace
46 | model = COMETBART.from_pretrained("mismayil/comet-bart-ai2")
47 |
48 | # Initialize inference module with a spacy language pipeline
49 | csi = CommonsenseInference(language="en_core_web_sm")
50 |
51 | # Run inference
52 | text = "PersonX becomes a great basketball player"
53 | kgraph = csi.infer(text, model)
54 |
55 | # Save output knowledge graph to a JSON Lines (JSONL) file
56 | kgraph.to_jsonl("kgraph.json")
57 | ```
58 |
59 | Here is an excerpt from the result of the above code sample:
60 |
61 | ```json
62 | {"head": "PersonX becomes a great basketball player", "relation": "Causes", "tails": [" PersonX practices every day.", " PersonX plays basketball every day", " PersonX practices every day"]}
63 | {"head": "basketball", "relation": "ObjectUse", "tails": [" play with friends", " play basketball with", " play basketball"]}
64 | {"head": "player", "relation": "CapableOf", "tails": [" play game", " win game", " play football"]}
65 | {"head": "great basketball player", "relation": "HasProperty", "tails": [" good at basketball", " good at sports", " very good"]}
66 | {"head": "become player", "relation": "isAfter", "tails": [" play game", " become coach", " play with"]}
67 | ```
68 | This is just one way to generate commonsense inferences and **kogito** offers much more. For complete documentation, check out the [kogito docs](https://kogito.readthedocs.io).
69 |
70 | ## Development
71 |
72 | ### Setup
73 | **kogito** uses [Poetry](https://python-poetry.org/) to manage its dependencies.
74 |
75 | Install poetry from the official repository first:
76 | ```sh
77 | curl -sSL https://install.python-poetry.org | python3 -
78 | ```
79 |
80 | Then run the following command to install package dependencies:
81 | ```sh
82 | poetry install
83 | ```
84 |
85 | ## Data
86 | If you need the ATOMIC2020 data to train your knowledge models, you can download it from AI2:
87 |
88 | For ATOMIC:
89 | ```sh
90 | wget https://storage.googleapis.com/ai2-mosaic/public/atomic/v1.0/atomic_data.tgz
91 | ```
92 |
93 | For ATOMIC 2020:
94 | ```sh
95 | wget https://ai2-atomic.s3-us-west-2.amazonaws.com/data/atomic2020_data-feb2021.zip
96 | ```
97 |
98 | ## Paper
99 | If you want to learn more about the library design, models and data used for this toolkit, check out our [paper](https://arxiv.org/abs/2211.08451). The paper can be cited as:
100 |
101 | ```
102 | @article{Ismayilzada2022kogito,
103 | title={kogito: A Commonsense Knowledge Inference Toolkit},
104 | author={Mete Ismayilzada and Antoine Bosselut},
105 | journal={ArXiv},
106 | volume={abs/2211.08451},
107 | year={2022}
108 | }
109 | ```
110 |
111 | If you work with knowledge models, consider citing the following papers:
112 |
113 | ```
114 | @article{Hwang2020COMETATOMIC,
115 | author = {Jena D. Hwang and Chandra Bhagavatula and Ronan Le Bras and Jeff Da and Keisuke Sakaguchi and Antoine Bosselut and Yejin Choi},
116 | booktitle = {Proceedings of the 35th AAAI Conference on Artificial Intelligence (AAAI)},
117 | title = {COMET-ATOMIC 2020: On Symbolic and Neural Commonsense Knowledge Graphs},
118 | year = {2021}
119 | }
120 |
121 | @inproceedings{Bosselut2019COMETCT,
122 | author = {Antoine Bosselut and Hannah Rashkin and Maarten Sap and Chaitanya Malaviya and Asli Çelikyilmaz and Yejin Choi},
123 | booktitle = {Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (ACL)},
124 | title = {COMET: Commonsense Transformers for Automatic Knowledge Graph Construction},
125 | year = {2019}
126 | }
127 | ```
128 |
129 | ## Acknowledgements
130 | A significant portion of the model training and evaluation code has been adapted from the original [codebase](https://github.com/allenai/comet-atomic-2020) for the paper [(Comet-) Atomic 2020: On Symbolic and Neural Commonsense Knowledge Graphs.](https://www.semanticscholar.org/paper/COMET-ATOMIC-2020%3A-On-Symbolic-and-Neural-Knowledge-Hwang-Bhagavatula/e39503e01ebb108c6773948a24ca798cd444eb62)
131 |
--------------------------------------------------------------------------------
/app/client/.dockerignore:
--------------------------------------------------------------------------------
1 | node_modules
2 | build
--------------------------------------------------------------------------------
/app/client/.env:
--------------------------------------------------------------------------------
1 | REACT_APP_SERVER_URL=https://proxy.kogito.live
--------------------------------------------------------------------------------
/app/client/.gitignore:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
2 |
3 | # dependencies
4 | /node_modules
5 | /.pnp
6 | .pnp.js
7 |
8 | # testing
9 | /coverage
10 |
11 | # production
12 | /build
13 |
14 | # misc
15 | .DS_Store
16 | .env.local
17 | .env.development.local
18 | .env.test.local
19 | .env.production.local
20 |
21 | npm-debug.log*
22 | yarn-debug.log*
23 | yarn-error.log*
24 |
--------------------------------------------------------------------------------
/app/client/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM node:18.5
2 |
3 | ENV HOME=/root
4 | ENV APP_DIR=${HOME}/kogito-client
5 | ENV REACT_APP_SERVER_URL=http://localhost:8080
6 |
7 | # Set default shell to /bin/bash
8 | SHELL ["/bin/bash", "-cu"]
9 |
10 | # Install dependencies
11 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends vim wget unzip
12 |
13 | RUN mkdir ${APP_DIR}
14 | WORKDIR ${APP_DIR}
15 |
16 | # Copy client files
17 | COPY . .
18 |
19 | # Setup app dependencies
20 | RUN npm install -g serve
21 | RUN npm install
22 | RUN npm run build
23 |
24 | EXPOSE 3000
25 |
26 | CMD ["serve", "-s", "build"]
--------------------------------------------------------------------------------
/app/client/README.md:
--------------------------------------------------------------------------------
1 | # Getting Started with Create React App
2 |
3 | This project was bootstrapped with [Create React App](https://github.com/facebook/create-react-app).
4 |
5 | ## Available Scripts
6 |
7 | In the project directory, you can run:
8 |
9 | ### `npm start`
10 |
11 | Runs the app in the development mode.\
12 | Open [http://localhost:3000](http://localhost:3000) to view it in your browser.
13 |
14 | The page will reload when you make changes.\
15 | You may also see any lint errors in the console.
16 |
17 | ### `npm test`
18 |
19 | Launches the test runner in the interactive watch mode.\
20 | See the section about [running tests](https://facebook.github.io/create-react-app/docs/running-tests) for more information.
21 |
22 | ### `npm run build`
23 |
24 | Builds the app for production to the `build` folder.\
25 | It correctly bundles React in production mode and optimizes the build for the best performance.
26 |
27 | The build is minified and the filenames include the hashes.\
28 | Your app is ready to be deployed!
29 |
30 | See the section about [deployment](https://facebook.github.io/create-react-app/docs/deployment) for more information.
31 |
32 | ### `npm run eject`
33 |
34 | **Note: this is a one-way operation. Once you `eject`, you can't go back!**
35 |
36 | If you aren't satisfied with the build tool and configuration choices, you can `eject` at any time. This command will remove the single build dependency from your project.
37 |
38 | Instead, it will copy all the configuration files and the transitive dependencies (webpack, Babel, ESLint, etc) right into your project so you have full control over them. All of the commands except `eject` will still work, but they will point to the copied scripts so you can tweak them. At this point you're on your own.
39 |
40 | You don't have to ever use `eject`. The curated feature set is suitable for small and middle deployments, and you shouldn't feel obligated to use this feature. However we understand that this tool wouldn't be useful if you couldn't customize it when you are ready for it.
41 |
42 | ## Learn More
43 |
44 | You can learn more in the [Create React App documentation](https://facebook.github.io/create-react-app/docs/getting-started).
45 |
46 | To learn React, check out the [React documentation](https://reactjs.org/).
47 |
48 | ### Code Splitting
49 |
50 | This section has moved here: [https://facebook.github.io/create-react-app/docs/code-splitting](https://facebook.github.io/create-react-app/docs/code-splitting)
51 |
52 | ### Analyzing the Bundle Size
53 |
54 | This section has moved here: [https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size](https://facebook.github.io/create-react-app/docs/analyzing-the-bundle-size)
55 |
56 | ### Making a Progressive Web App
57 |
58 | This section has moved here: [https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app](https://facebook.github.io/create-react-app/docs/making-a-progressive-web-app)
59 |
60 | ### Advanced Configuration
61 |
62 | This section has moved here: [https://facebook.github.io/create-react-app/docs/advanced-configuration](https://facebook.github.io/create-react-app/docs/advanced-configuration)
63 |
64 | ### Deployment
65 |
66 | This section has moved here: [https://facebook.github.io/create-react-app/docs/deployment](https://facebook.github.io/create-react-app/docs/deployment)
67 |
68 | ### `npm run build` fails to minify
69 |
70 | This section has moved here: [https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify](https://facebook.github.io/create-react-app/docs/troubleshooting#npm-run-build-fails-to-minify)
71 |
--------------------------------------------------------------------------------
/app/client/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "kogito",
3 | "version": "0.1.0",
4 | "homepage": "https://kogito.live",
5 | "private": true,
6 | "dependencies": {
7 | "@testing-library/jest-dom": "^5.16.4",
8 | "@testing-library/react": "^13.3.0",
9 | "@testing-library/user-event": "^13.5.0",
10 | "axios": "^0.27.2",
11 | "file-saver": "^2.0.5",
12 | "lodash": "^4.17.21",
13 | "react": "^18.2.0",
14 | "react-copy-to-clipboard": "^5.1.0",
15 | "react-dom": "^18.2.0",
16 | "react-scripts": "5.0.1",
17 | "semantic-ui-react": "^2.1.3",
18 | "semantic-ui-react-numberinput": "^1.5.1",
19 | "serve": "^14.0.1",
20 | "web-vitals": "^2.1.4"
21 | },
22 | "scripts": {
23 | "start": "react-scripts start",
24 | "build": "react-scripts build",
25 | "test": "react-scripts test",
26 | "eject": "react-scripts eject",
27 | "predeploy": "npm run build",
28 | "deploy": "gh-pages -d build"
29 | },
30 | "eslintConfig": {
31 | "extends": [
32 | "react-app",
33 | "react-app/jest"
34 | ]
35 | },
36 | "browserslist": {
37 | "production": [
38 | ">0.2%",
39 | "not dead",
40 | "not op_mini all"
41 | ],
42 | "development": [
43 | "last 1 chrome version",
44 | "last 1 firefox version",
45 | "last 1 safari version"
46 | ]
47 | },
48 | "devDependencies": {
49 | "gh-pages": "^4.0.0"
50 | }
51 | }
52 |
--------------------------------------------------------------------------------
/app/client/public/CNAME:
--------------------------------------------------------------------------------
1 | kogito.live
--------------------------------------------------------------------------------
/app/client/public/android-chrome-192x192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/android-chrome-192x192.png
--------------------------------------------------------------------------------
/app/client/public/android-chrome-512x512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/android-chrome-512x512.png
--------------------------------------------------------------------------------
/app/client/public/apple-touch-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/apple-touch-icon.png
--------------------------------------------------------------------------------
/app/client/public/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/favicon-16x16.png
--------------------------------------------------------------------------------
/app/client/public/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/favicon-32x32.png
--------------------------------------------------------------------------------
/app/client/public/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/favicon.ico
--------------------------------------------------------------------------------
/app/client/public/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
23 |
24 |
33 |
34 |
35 |
36 | Kogito
37 |
38 |
39 |
40 |
41 |
51 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
/app/client/public/lab-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/app/client/public/lab-logo.png
--------------------------------------------------------------------------------
/app/client/public/manifest.json:
--------------------------------------------------------------------------------
1 | {
2 | "short_name": "Kogito",
3 | "name": "Kogito: Knowledge Inference Tool",
4 | "icons": [
5 | {
6 | "src": "favicon.ico",
7 | "sizes": "64x64 32x32 24x24 16x16",
8 | "type": "image/x-icon"
9 | },
10 | {
11 | "src": "android-chrome-192x192.png",
12 | "type": "image/png",
13 | "sizes": "192x192"
14 | },
15 | {
16 | "src": "android-chrome-512x512.png",
17 | "type": "image/png",
18 | "sizes": "512x512"
19 | }
20 | ],
21 | "start_url": ".",
22 | "display": "standalone",
23 | "theme_color": "#000000",
24 | "background_color": "#ffffff"
25 | }
26 |
--------------------------------------------------------------------------------
/app/client/public/robots.txt:
--------------------------------------------------------------------------------
1 | # https://www.robotstxt.org/robotstxt.html
2 | User-agent: *
3 | Disallow:
4 |
--------------------------------------------------------------------------------
/app/client/src/App.css:
--------------------------------------------------------------------------------
1 | .logo {
2 | padding-top: 10px;
3 | font-size: 64px;
4 | font-family: 'Economica', sans-serif;
5 | margin-bottom: 10px;
6 | }
7 |
8 | .logo-k {
9 | background-color: #dce755;
10 | font-weight: bold;
11 | }
12 |
13 | .description {
14 | /* font-size: 20px; */
15 | }
16 |
17 | .cntr-label {
18 | padding-bottom: 5px;
19 | }
20 |
21 | .cntr {
22 | margin-top: 24px;
23 | margin-bottom: 36px;
24 | }
25 |
26 | .cntr-head {
27 | margin-top: 10px;
28 | }
29 |
30 | .ui.button.kbtn {
31 | background: #dce755;
32 | }
33 |
34 | .ui.button.kbtn:hover {
35 | background: #dce755;
36 | }
37 |
38 | .home-results-json-segment {
39 | padding: 0!important;
40 | border: none!important;
41 | }
--------------------------------------------------------------------------------
/app/client/src/App.test.js:
--------------------------------------------------------------------------------
1 | import { render, screen } from '@testing-library/react';
2 | import App from './App';
3 |
4 | test('renders learn react link', () => {
5 | render();
6 | const linkElement = screen.getByText(/learn react/i);
7 | expect(linkElement).toBeInTheDocument();
8 | });
9 |
--------------------------------------------------------------------------------
/app/client/src/api.js:
--------------------------------------------------------------------------------
import axios from 'axios'

// Base URL of the kogito inference server; falls back to the local
// dev server when REACT_APP_SERVER_URL is not set at build time.
const API_URI = process.env.REACT_APP_SERVER_URL || "http://localhost:8080"

const api = {
    inference: {
        // POST the inference request payload to the server's /inference
        // endpoint; returns the axios response promise.
        generate: (data) => axios.post(API_URI + '/inference', data)
    }
}

export default api
--------------------------------------------------------------------------------
/app/client/src/index.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
3 | 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue',
4 | sans-serif;
5 | -webkit-font-smoothing: antialiased;
6 | -moz-osx-font-smoothing: grayscale;
7 | }
8 |
9 | code {
10 | font-family: source-code-pro, Menlo, Monaco, Consolas, 'Courier New',
11 | monospace;
12 | }
13 |
--------------------------------------------------------------------------------
/app/client/src/index.js:
--------------------------------------------------------------------------------
1 | import React from 'react';
2 | import ReactDOM from 'react-dom/client';
3 | import './index.css';
4 | import App from './App';
5 | // import reportWebVitals from './reportWebVitals';
6 |
7 | const root = ReactDOM.createRoot(document.getElementById('root'));
8 | root.render(
9 | //
10 |
11 | //
12 | );
13 |
14 | // If you want to start measuring performance in your app, pass a function
15 | // to log results (for example: reportWebVitals(console.log))
16 | // or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals
17 | // reportWebVitals();
18 |
--------------------------------------------------------------------------------
/app/client/src/relations.js:
--------------------------------------------------------------------------------
// Commonsense knowledge relation types selectable in the client UI.
// NOTE(review): these appear to mirror the relation names known to the
// kogito server (KnowledgeRelation) — confirm against the server when
// adding or renaming entries.
const RELATIONS = [
    'AtLocation',
    'CapableOf',
    'Causes',
    'CausesDesire',
    'Desires',
    'HasProperty',
    'HasSubEvent',
    'HinderedBy',
    'MadeUpOf',
    'NotDesires',
    'ObjectUse',
    'IsAfter',
    'IsBefore',
    'oEffect',
    'oReact',
    'oWant',
    'xAttr',
    'xEffect',
    'xIntent',
    'xNeed',
    'xReact',
    'xReason',
    'xWant',
    'CreatedBy',
    'DefinedAs',
    'DesireOf',
    'HasA',
    'HasFirstSubevent',
    'HasLastSubevent',
    'HasPainCharacter',
    'HasPainIntensity',
    'HasPrerequisite',
    'InheritsFrom',
    'InstanceOf',
    'IsA',
    'LocatedNear',
    'LocationOfAction',
    'MadeOf',
    'NotHasA',
    'NotHasProperty',
    'NotIsA',
    'NotMadeOf',
    'MotivatedByGoal',
    'NotCapableOf',
    'PartOf',
    'ReceivesAction',
    'RelatedTo',
    'SymbolOf',
    'UsedFor',
]

export default RELATIONS
--------------------------------------------------------------------------------
/app/client/src/reportWebVitals.js:
--------------------------------------------------------------------------------
// Forward Core Web Vitals metrics to the given callback.
// Standard Create React App helper; 'web-vitals' is imported lazily so
// it is only loaded when a callback is actually supplied.
const reportWebVitals = onPerfEntry => {
  if (onPerfEntry && onPerfEntry instanceof Function) {
    import('web-vitals').then(({ getCLS, getFID, getFCP, getLCP, getTTFB }) => {
      getCLS(onPerfEntry);
      getFID(onPerfEntry);
      getFCP(onPerfEntry);
      getLCP(onPerfEntry);
      getTTFB(onPerfEntry);
    });
  }
};

export default reportWebVitals;
14 |
--------------------------------------------------------------------------------
/app/client/src/setupTests.js:
--------------------------------------------------------------------------------
1 | // jest-dom adds custom jest matchers for asserting on DOM nodes.
2 | // allows you to do things like:
3 | // expect(element).toHaveTextContent(/react/i)
4 | // learn more: https://github.com/testing-library/jest-dom
5 | import '@testing-library/jest-dom';
6 |
--------------------------------------------------------------------------------
/app/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: '2'
2 |
3 | services:
4 | server:
5 | container_name: kogito-server
6 | image: kogito-server
7 | build: server
8 | ports:
9 | - 8080:8080
10 | client:
11 | container_name: kogito-client
12 | image: kogito-client
13 | build: client
14 | ports:
15 | - 3000:3000
16 | depends_on:
17 | - server
--------------------------------------------------------------------------------
/app/proxy-nginx.conf:
--------------------------------------------------------------------------------
1 | server {
2 | listen [::]:443 ssl ipv6only=on; # managed by Certbot
3 | listen 443 ssl; # managed by Certbot
4 | ssl_certificate /etc/letsencrypt/live/proxy.kogito.live/fullchain.pem; # managed by Certbot
5 | ssl_certificate_key /etc/letsencrypt/live/proxy.kogito.live/privkey.pem; # managed by Certbot
6 | include /etc/letsencrypt/options-ssl-nginx.conf; # managed by Certbot
7 | ssl_dhparam /etc/letsencrypt/ssl-dhparams.pem; # managed by Certbot
8 | server_name proxy.kogito.live;
9 | proxy_read_timeout 300;
10 | proxy_connect_timeout 300;
11 | proxy_send_timeout 300;
12 |
13 | location / {
14 | proxy_pass http://76.50.42.128:49780;
15 | proxy_http_version 1.1;
16 | proxy_set_header Upgrade $http_upgrade;
17 | proxy_set_header Connection 'upgrade';
18 | proxy_set_header Host $host;
19 | proxy_cache_bypass $http_upgrade;
20 | proxy_set_header X-Real-IP $remote_addr;
21 | proxy_set_header X-Forwarded-Proto https;
22 | proxy_set_header X-Forwarded-For $remote_addr;
23 | proxy_set_header X-Forwarded-Host $host;
24 | }
25 | }
26 | server {
27 | if ($host = proxy.kogito.live) {
28 | return 301 https://$host$request_uri;
29 | } # managed by Certbot
30 |
31 |
32 | listen 80 ;
33 | listen [::]:80 ;
34 | server_name proxy.kogito.live;
35 | return 404; # managed by Certbot
36 |
37 |
38 | }
--------------------------------------------------------------------------------
/app/server/.dockerignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | lightning_logs
--------------------------------------------------------------------------------
/app/server/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.8
2 |
3 | ENV HOME=/root
4 | ENV APP_DIR=${HOME}/kogito/app/server
5 | ENV FLASK_APP=server
6 |
7 | # Set default shell to /bin/bash
8 | SHELL ["/bin/bash", "-cu"]
9 |
10 | # Install dependencies
11 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends vim wget unzip
12 |
13 | WORKDIR ${HOME}
14 | ARG GITHUB_PERSONAL_TOKEN
15 |
16 | # Clone kogito
17 | RUN git clone https://${GITHUB_PERSONAL_TOKEN}@github.com/epfl-nlp/kogito.git
18 |
19 | WORKDIR ${APP_DIR}
20 |
21 | # Setup app dependencies
22 | RUN pip3 install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu114
23 | RUN python3 -c "import nltk;nltk.download('punkt');nltk.download('wordnet');nltk.download('omw-1.4')"
24 | RUN python3 -m spacy download en_core_web_sm
25 |
26 | EXPOSE 8080
27 |
28 | # gunicorn -w 1 -b 0.0.0.0:8080 --log-file server.log --capture-output --timeout 3600 --pid pid.log server:app --daemon
29 | ENTRYPOINT ["flask", "run", "-h", "0.0.0.0", "-p", "8080"]
--------------------------------------------------------------------------------
/app/server/inference.py:
--------------------------------------------------------------------------------
import spacy

from kogito.models.bart.comet import COMETBART

# from kogito.models.gpt2.comet import COMETGPT2
# from kogito.models.gpt2.zeroshot import GPT2Zeroshot
from kogito.inference import CommonsenseInference
from kogito.core.relation import KnowledgeRelation
from kogito.core.processors.relation import (
    # SWEMRelationMatcher,
    DistilBERTRelationMatcher,
    # BERTRelationMatcher,
)
from kogito.linkers.deberta import DebertaLinker

# Knowledge models available to the API, keyed by the request payload's
# "model" value. Loaded eagerly at import time so the first request does
# not pay the model download/initialization cost.
MODEL_MAP = {
    "comet-bart": COMETBART.from_pretrained("mismayil/comet-bart-ai2"),
    # "comet-gpt2": COMETGPT2.from_pretrained("mismayil/comet-gpt2-ai2"),
    # "gpt2": GPT2Zeroshot("gpt2-xl")
}

# Tail linkers available to the API, keyed by name.
LINKER_MAP = {"deberta": DebertaLinker()}

# Optional processors the client may request by name in "headProcs" /
# "relProcs" (see infer below).
PROCESSOR_MAP = {
    # "swem_relation_matcher": SWEMRelationMatcher("swem_relation_matcher"),
    "distilbert_relation_matcher": DistilBERTRelationMatcher(
        "distilbert_relation_matcher"
    ),
    # "bert_relation_matcher": BERTRelationMatcher("bert_relation_matcher"),
}

# spacy pipeline used to lemmatize the input text for the response.
nlp = spacy.load("en_core_web_sm")

print("Ready for inference.")
35 |
36 |
def infer(data):
    """Run commonsense inference for a single API request.

    Args:
        data: Parsed JSON request body. Recognized keys: "text", "model",
            "heads", "relations", "extractHeads", "matchRelations",
            "dryRun", "headProcs", "relProcs", "context", "threshold".

    Returns:
        dict with "text" (lowercased lemmas of the input text, empty when
        no text was given) and "graph" (JSON-serialized knowledge tuples,
        empty when inference produced no graph).
    """
    text = data.get("text")
    model = MODEL_MAP.get(data.get("model"))
    heads = data.get("heads")
    relations = data.get("relations")
    extract_heads = data.get("extractHeads", True)
    match_relations = data.get("matchRelations", True)
    dry_run = data.get("dryRun", False)
    head_procs = data.get("headProcs", [])
    rel_procs = data.get("relProcs", [])
    context = data.get("context", "").strip()
    threshold = float(data.get("threshold", 0.5))

    csi = CommonsenseInference(language="en_core_web_sm")
    csi_head_procs = csi.processors["head"]
    csi_rel_procs = csi.processors["relation"]

    # Reconcile the default processor set with the one requested by the
    # client: remove defaults that were not requested, then add requested
    # processors that are not already present.
    for proc in set(csi_head_procs) - set(head_procs):
        csi.remove_processor(proc)

    for proc in set(csi_rel_procs) - set(rel_procs):
        csi.remove_processor(proc)

    for proc in set(head_procs) - set(csi_head_procs):
        csi.add_processor(PROCESSOR_MAP[proc])

    for proc in set(rel_procs) - set(csi_rel_procs):
        csi.add_processor(PROCESSOR_MAP[proc])

    # Convert relation names to KnowledgeRelation objects. Bind a new
    # list instead of mutating the caller's list in place.
    if relations:
        relations = [KnowledgeRelation.from_text(rel) for rel in relations]

    linker = LINKER_MAP["deberta"]

    output_graph = csi.infer(
        text=text,
        model=model,
        heads=heads,
        relations=relations,
        extract_heads=extract_heads,
        match_relations=match_relations,
        dry_run=dry_run,
        context=context,
        threshold=threshold,
        linker=linker,
    )

    result = {"text": [], "graph": []}

    if output_graph:
        result["graph"] = [kg.to_json() for kg in output_graph]

    if text:
        doc = nlp(text)
        result["text"] = [token.lemma_.lower() for token in doc]

    return result
95 |
--------------------------------------------------------------------------------
/app/server/requirements.txt:
--------------------------------------------------------------------------------
1 | Flask
2 | kogito
3 | flask-cors
4 | gunicorn
--------------------------------------------------------------------------------
/app/server/server.py:
--------------------------------------------------------------------------------
import os
import traceback

from flask import Flask, request, jsonify
from flask_cors import CORS
from inference import infer

app = Flask(__name__)
# Allow cross-origin requests from the separately-hosted client app.
CORS(app)
10 |
11 |
@app.route("/")
def heartbeat():
    """Health-check endpoint: returns a static string when the server is up."""
    return "Running"
15 |
16 |
@app.route("/inference", methods=["POST"])
def inference():
    """Run knowledge inference on the JSON request body.

    Returns the inference result as JSON, or the error message with a
    500 status if inference fails.
    """
    try:
        return jsonify(infer(request.json))
    except Exception as e:
        # Bug fix: traceback.print_exc(e) passed the exception as the
        # `limit` argument and raised a TypeError instead of logging.
        traceback.print_exc()
        return str(e), 500
24 |
25 |
def main():
    """Start the Flask development server on PORT (default 8080)."""
    port = int(os.environ.get("PORT", 8080))
    # Environment variables are strings, so any non-empty value —
    # including "0" or "false" — was previously truthy and enabled
    # debug mode. Parse the flag explicitly instead.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(debug=debug, host="0.0.0.0", port=port)
29 |
30 |
31 | if __name__ == "__main__":
32 | main()
33 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.5.1-cudnn8-runtime-ubuntu20.04@sha256:a6831f0d6328ea7301fa196ae2a376d2e67caae384af4ffd93fb196b527c0a0f
2 |
3 | ENV HOME=/root
4 | ENV CONDA_PREFIX=${HOME}/.conda
5 | ENV CONDA=${CONDA_PREFIX}/condabin/conda
6 | ENV KOGITO_DIR=${HOME}/kogito
7 |
8 | # Set default shell to /bin/bash
9 | SHELL ["/bin/bash", "-cu"]
10 |
11 | # Install dependencies
12 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends openssh-server vim wget unzip tmux git
13 |
14 | # Set up SSH server
15 | RUN mkdir /var/run/sshd
16 | RUN echo 'root:root' | chpasswd
17 | RUN sed -i 's/#*PermitRootLogin prohibit-password/PermitRootLogin yes/g' /etc/ssh/sshd_config
18 | # SSH login fix. Otherwise user is kicked off after login
19 | RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd
20 | ENV NOTVISIBLE="in users profile"
21 | RUN echo "export VISIBLE=now" >> /etc/profile
22 | EXPOSE 22
23 |
24 | WORKDIR ${HOME}
25 |
26 | # Cluster setup
27 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh -O anaconda.sh
28 | RUN bash anaconda.sh -b -p ${CONDA_PREFIX}
29 | RUN ${CONDA} config --set auto_activate_base false
30 | RUN ${CONDA} init bash
31 | RUN echo "export LANG=en_US.UTF-8" >> ~/.bashrc
32 | RUN ${CONDA} create --name kogito -y python=3.10
33 | RUN ${CONDA} install -n kogito ipykernel --force-reinstall
34 |
35 | WORKDIR ${KOGITO_DIR}
36 |
37 | ARG ENV_TAINT=0
38 |
39 | # Setup project dependencies
40 | RUN ${CONDA} run -n kogito pip install kogito --extra-index-url https://download.pytorch.org/whl/cu116
41 | RUN ${CONDA} run -n kogito python -c "import nltk;nltk.download('punkt');nltk.download('wordnet');nltk.download('omw-1.4')"
42 | RUN ${CONDA} run -n kogito python -m spacy download en_core_web_sm
43 |
44 | ARG VERSION_TAINT=0
45 |
46 | # Setup data
47 | COPY ./data .
48 | COPY ./train.py .
49 | COPY ./train.sh .
50 |
51 | CMD ["/usr/sbin/sshd", "-D"]
52 | # CMD ["./train.sh"]
--------------------------------------------------------------------------------
/docker/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | export GITHUB_PERSONAL_TOKEN=${1}
4 | docker build -f Dockerfile -t ic-registry.epfl.ch/nlp/kogito --build-arg GITHUB_PERSONAL_TOKEN .
--------------------------------------------------------------------------------
/docker/dev.Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:11.5.1-cudnn8-runtime-ubuntu20.04@sha256:a6831f0d6328ea7301fa196ae2a376d2e67caae384af4ffd93fb196b527c0a0f
2 |
3 | ENV HOME=/root
4 | ENV CONDA_PREFIX=${HOME}/.conda
5 | ENV CONDA=${CONDA_PREFIX}/condabin/conda
6 | ENV KOGITO_DIR=${HOME}/kogito
7 | ENV POETRY=${HOME}/.local/bin/poetry
8 |
9 | # Set default shell to /bin/bash
10 | SHELL ["/bin/bash", "-cu"]
11 |
12 | # Install dependencies
13 | RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-packages --no-install-recommends openssh-server vim wget unzip tmux
14 |
15 | WORKDIR ${HOME}
16 |
17 | # Cluster setup
18 | RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py38_4.11.0-Linux-x86_64.sh -O anaconda.sh
19 | RUN bash anaconda.sh -b -p ${CONDA_PREFIX}
20 | RUN ${CONDA} config --set auto_activate_base false
21 | RUN ${CONDA} init bash
22 | RUN git config --global user.name "Mete Ismayil"
23 | RUN git config --global user.email "mismayilza@gmail.com"
24 | RUN git config --global pull.rebase false
25 | RUN echo "export LANG=en_US.UTF-8" >> ~/.bashrc
26 |
27 | # Setup kogito env
28 | RUN ${CONDA} create --name kogito -y python=3.8
29 | RUN ${CONDA} run -n kogito curl -sSL https://install.python-poetry.org | python3 -
30 | RUN ${CONDA} install -n kogito -y pytorch cudatoolkit=11.5 -c pytorch
31 |
32 | ARG GITHUB_PERSONAL_TOKEN
33 |
34 | # Clone kogito
35 | RUN git clone https://${GITHUB_PERSONAL_TOKEN}@github.com/epfl-nlp/kogito.git
36 |
37 | WORKDIR ${KOGITO_DIR}
38 |
39 | # Setup kogito dependencies
40 | RUN ${CONDA} run -n kogito ${POETRY} install
41 | RUN ${CONDA} run -n kogito python -c "import nltk;nltk.download('punkt');nltk.download('wordnet');nltk.download('omw-1.4')"
42 | RUN ${CONDA} run -n kogito python -m spacy download en_core_web_sm
43 |
44 | # Install training data
45 | ENV KOGITO_DATA_DIR=${KOGITO_DIR}/data
46 | RUN mkdir ${KOGITO_DATA_DIR}
47 | RUN wget https://ai2-atomic.s3-us-west-2.amazonaws.com/data/atomic2020_data-feb2021.zip
48 | RUN unzip atomic2020_data-feb2021.zip -d ${KOGITO_DATA_DIR}
49 |
50 | # Set up SSH server
51 | RUN apt-get update && apt-get install -y openssh-server tmux vim
52 | RUN mkdir /var/run/sshd
53 | RUN echo 'root:root' | chpasswd
54 | RUN sed -i 's/#*PermitRootLogin prohibit-password/PermitRootLogin yes/g' /etc/ssh/sshd_config
55 | # SSH login fix. Otherwise user is kicked off after login
56 | RUN sed -i 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' /etc/pam.d/sshd
57 | ENV NOTVISIBLE="in users profile"
58 | RUN echo "export VISIBLE=now" >> /etc/profile
59 | EXPOSE 22
60 |
61 | COPY ./train.py .
62 | COPY ./entrypoint.sh .
63 |
64 | ENTRYPOINT ["/usr/sbin/sshd", "-D"]
65 | # ENTRYPOINT ["./entrypoint.sh"]
--------------------------------------------------------------------------------
/docker/entrypoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | ${CONDA} run -n kogito python train.py
--------------------------------------------------------------------------------
/docker/push.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | docker push ic-registry.epfl.ch/nlp/kogito
--------------------------------------------------------------------------------
/docker/train.py:
--------------------------------------------------------------------------------
from kogito.core.knowledge import KnowledgeGraph
from kogito.models.gpt2.comet import COMETGPT2
import os

if __name__ == "__main__":
    # Fine-tune a GPT-2-based COMET model on the ATOMIC 2020 dataset.
    # KOGITO_DATA_DIR is set by the Docker image (see dev.Dockerfile),
    # which downloads and unzips atomic2020_data-feb2021 there.
    data_dir = os.environ.get("KOGITO_DATA_DIR")
    model = COMETGPT2("gpt2-xl")
    # The ATOMIC TSV files have no header row and are tab-separated.
    train_graph = KnowledgeGraph.from_csv(
        f"{data_dir}/atomic2020_data-feb2021/train.tsv", header=None, sep="\t"
    )
    val_graph = KnowledgeGraph.from_csv(
        f"{data_dir}/atomic2020_data-feb2021/dev.tsv", header=None, sep="\t"
    )
    model.train(
        train_graph=train_graph,
        val_graph=val_graph,
        batch_size=16,
        output_dir="/scratch/mete/models/comet-gpt2",  # cluster scratch path
        log_wandb=True,
        lr=5e-5,
        epochs=1,
    )
23 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 | @import url('https://fonts.googleapis.com/css2?family=Inconsolata:wght@200;300;400;500;600;700;800;900&display=swap');
2 |
3 | .highlight-python .kn {
4 | font-weight: 900;
5 | }
6 |
7 | .highlight-python .nn {
8 | font-weight: 500;
9 | }
10 |
11 | .highlight-python .nc {
12 | font-weight: 500;
13 | }
14 |
15 | #logo-container h1::first-letter {
16 | background-color: #dce755;
17 | font-weight: 400;
18 | }
--------------------------------------------------------------------------------
/docs/api.rst:
--------------------------------------------------------------------------------
1 | =============
2 | API Reference
3 | =============
4 |
5 | Inference
6 | =========
7 |
8 | .. automodule:: kogito.inference
9 | :members:
10 | :special-members: __init__
11 |
12 | Head
13 | ====
14 |
15 | .. automodule:: kogito.core.head
16 | :members:
17 | :special-members: __init__
18 |
19 | Relation
20 | ========
21 |
22 | .. automodule:: kogito.core.relation
23 | :members:
24 | :special-members: __init__
25 |
26 | Knowledge
27 | =========
28 |
29 | .. automodule:: kogito.core.knowledge
30 | :members:
31 | :special-members: __init__
32 |
33 | Models
34 | ======
35 |
36 | .. automodule:: kogito.core.model
37 | :members:
38 | :special-members: __init__
39 |
40 | .. automodule:: kogito.models.bart.comet
41 | :members:
42 | :special-members: __init__
43 |
44 | .. automodule:: kogito.models.gpt2.comet
45 | :members:
46 | :special-members: __init__
47 |
48 | .. automodule:: kogito.models.gpt2.zeroshot
49 | :members:
50 | :special-members: __init__
51 |
52 | .. automodule:: kogito.models.gpt3.zeroshot
53 | :members:
54 | :special-members: __init__
55 |
56 | Processors
57 | ==========
58 |
59 | .. automodule:: kogito.core.processors.head
60 | :members:
61 | :special-members: __init__
62 |
63 | .. automodule:: kogito.core.processors.relation
64 | :members:
65 | :special-members: __init__
66 |
67 | Linkers
68 | =======
69 |
70 | .. automodule:: kogito.core.linker
71 | :members:
72 | :special-members: __init__
73 |
74 | .. automodule:: kogito.linkers.deberta
75 | :members:
76 | :special-members: __init__
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 | #
13 | import os
14 | import sys
15 |
16 | sys.path.insert(0, os.path.abspath(".."))
17 |
18 |
19 | # -- Project information -----------------------------------------------------
20 |
21 | project = "kogito"
22 | copyright = "2022, Mete Ismayil"
23 | author = "Mete Ismayil"
24 |
25 |
26 | # -- General configuration ---------------------------------------------------
27 |
28 | # Add any Sphinx extension module names here, as strings. They can be
29 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
30 | # ones.
31 | extensions = [
32 | "insegel",
33 | "sphinx.ext.autodoc",
34 | "sphinx.ext.coverage",
35 | "sphinx.ext.napoleon",
36 | ]
37 |
38 | # Add any paths that contain templates here, relative to this directory.
39 | templates_path = ["_templates"]
40 |
41 | # List of patterns, relative to source directory, that match files and
42 | # directories to ignore when looking for source files.
43 | # This pattern also affects html_static_path and html_extra_path.
44 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
45 |
46 |
47 | # -- Options for HTML output -------------------------------------------------
48 |
49 | # The theme to use for HTML and HTML Help pages. See the documentation for
50 | # a list of builtin themes.
51 | #
52 | html_theme = "insegel"
53 |
54 | # Add any paths that contain custom static files (such as style sheets) here,
55 | # relative to this directory. They are copied after the builtin static files,
56 | # so a file named "default.css" will overwrite the builtin "default.css".
57 | html_static_path = ["_static"]
58 |
59 | html_css_files = [
60 | "css/custom.css",
61 | ]
62 |
63 | autodoc_member_order = "bysource"
64 |
--------------------------------------------------------------------------------
/docs/getting_started.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | Getting Started
3 | ===============
4 |
5 | Installation
6 | ============
7 |
8 | Installation with pip
9 | *********************
10 | **kogito** can be installed using pip.
11 |
12 | .. code-block:: shell
13 |
14 | pip install kogito
15 |
16 | It requires a minimum ``python`` version of ``3.8``.
17 |
18 | Setup
19 | =====
20 |
21 | Inference
22 | *********
23 | **kogito** uses `spacy <https://spacy.io>`__ under the hood for various text processing purposes, so a `spacy <https://spacy.io>`__ language package has to be installed before running the inference module.
24 |
25 | .. code-block:: shell
26 |
27 | python -m spacy download en_core_web_sm
28 |
29 | By default, the ``CommonsenseInference`` module uses ``en_core_web_sm`` to initialize the `spacy <https://spacy.io>`__ pipeline, but a different language pipeline can be specified as well.
30 |
31 | Evaluation
32 | **********
33 | If you would also like to evaluate knowledge models using the `METEOR <https://en.wikipedia.org/wiki/METEOR>`_ score, then you need to download the following `nltk <https://www.nltk.org>`_ libraries:
34 |
35 | .. code-block:: python
36 |
37 | import nltk
38 |
39 | nltk.download("punkt")
40 | nltk.download("wordnet")
41 | nltk.download("omw-1.4")
42 |
43 |
44 | Quickstart
45 | ===========
46 | **kogito** provides an easy interface to interact with knowledge inference or commonsense reasoning models such as `COMET `__ to generate inferences from a text input.
47 | Here is a sample usage of the library where you can initialize an inference module, a custom commonsense reasoning model, and generate a knowledge graph from text on the fly.
48 |
49 | .. code-block:: python
50 |
51 | from kogito.models.bart.comet import COMETBART
52 | from kogito.inference import CommonsenseInference
53 |
54 | # Load pre-trained model from HuggingFace
55 | model = COMETBART.from_pretrained("mismayil/comet-bart-ai2")
56 |
57 | # Initialize inference module with a spacy language pipeline
58 | csi = CommonsenseInference(language="en_core_web_sm")
59 |
60 | # Run inference
61 | text = "PersonX becomes a great basketball player"
62 | kgraph = csi.infer(text, model)
63 |
64 | # Save output knowledge graph to JSON file
65 | kgraph.to_jsonl("kgraph.json")
66 |
67 |
68 | Here is an excerpt from the result of the above code sample:
69 |
70 | .. code-block:: json
71 |
72 | {"head": "PersonX becomes a great basketball player", "relation": "Causes", "tails": [" PersonX practices every day.", " PersonX plays basketball every day", " PersonX practices every day"]}
73 | {"head": "basketball", "relation": "ObjectUse", "tails": [" play with friends", " play basketball with", " play basketball"]}
74 | {"head": "player", "relation": "CapableOf", "tails": [" play game", " win game", " play football"]}
75 | {"head": "great basketball player", "relation": "HasProperty", "tails": [" good at basketball", " good at sports", " very good"]}
76 | {"head": "become player", "relation": "isAfter", "tails": [" play game", " become coach", " play with"]}
77 |
78 | This is just one way to generate commonsense inferences and **kogito** offers much more. For information on more use-cases and a complete API reference, you can check out the `User Guide `_ and `API Reference `_ pages.
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. kogito documentation master file, created by
2 | sphinx-quickstart on Wed Apr 13 11:36:35 2022.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | A Python Knowledge Inference Tool
7 | ======================================
8 |
9 | .. toctree::
10 | :maxdepth: 2
11 |
12 | getting_started
13 | userguide
14 | api
15 |
16 | .. Indices and tables
17 | .. ==================
18 |
19 | .. * :ref:`genindex`
20 | .. * :ref:`modindex`
21 | .. * :ref:`search`
22 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | %SPHINXBUILD% >NUL 2>NUL
14 | if errorlevel 9009 (
15 | echo.
16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
17 | echo.installed, then set the SPHINXBUILD environment variable to point
18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
19 | echo.may add the Sphinx directory to PATH.
20 | echo.
21 | echo.If you don't have Sphinx installed, grab it from
22 | echo.https://www.sphinx-doc.org/
23 | exit /b 1
24 | )
25 |
26 | if "%1" == "" goto help
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | insegel
2 | kogito
--------------------------------------------------------------------------------
/eacl2023/kogito-poster-eacl2023.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/eacl2023/kogito-poster-eacl2023.pdf
--------------------------------------------------------------------------------
/eacl2023/kogito-presentation-eacl2023.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/eacl2023/kogito-presentation-eacl2023.pdf
--------------------------------------------------------------------------------
/examples/data/atomic2020/README.md:
--------------------------------------------------------------------------------
1 | # ATOMIC 2020 Knowledge Graph
2 |
3 | All data `*.tsv` files are formatted as follows:
4 | - each line represents a distinct commonsense tuple
5 | - column 1: head node/concept
6 | - column 2: edge relation (e.g., xWant,xAttr,AtLocation)
7 | - column 3: tail node/concept
8 |
9 | `train.tsv`, `dev.tsv`, `test.tsv` correspond to train/dev/test splits.
10 |
11 |
12 | ## Paper
13 | Please cite the following work when using this data:
14 |
15 | > Jena D. Hwang, Chandra Bhagavatula, Ronan Le Bras, Jeff Da, Keisuke Sakaguchi, Antoine Bosselut, Yejin Choi (2021).
16 | > (Comet-) Atomic 2020: On Symbolic and Neural Commonsense Knowledge Graphs.
17 | > AAAI 2021
--------------------------------------------------------------------------------
/examples/data/atomic2020/sample_dev.tsv:
--------------------------------------------------------------------------------
1 | PersonX 'd better go oEffect none
2 | PersonX 'd better go oEffect none
3 | PersonX 'd better go oReact none
4 | PersonX 'd better go oReact none
5 | PersonX 'd better go oWant none
6 | PersonX 'd better go oWant none
7 | PersonX 'd better go oWant none
8 | PersonX 'd better go xAttr avoidant
9 | PersonX 'd better go xAttr weak
10 | PersonX 'd better go xAttr hurried
11 | PersonX 'd better go xAttr late
12 | PersonX 'd better go xAttr Tardy
13 | PersonX 'd better go xAttr busy
14 | PersonX 'd better go xEffect She ran to the bathroom
15 | PersonX 'd better go xEffect She finally made it
16 | PersonX 'd better go xEffect leaves
17 | PersonX 'd better go xEffect runs away
18 | PersonX 'd better go xIntent to go somewhere else more important.
19 | PersonX 'd better go xIntent none
20 | PersonX 'd better go xNeed none
21 | PersonX 'd better go xNeed none
22 | PersonX 'd better go xNeed none
23 | PersonX 'd better go xReact the person feels happy since he arrived at his destination.
24 | PersonX 'd better go xReact rushed, in a hurry
25 | PersonX 'd better go xWant to escape from him
26 | PersonX 'd better go xWant to resign his job
27 | PersonX 'd better go xWant to leave on time
28 | PersonX 'd better go xWant to arrive home
29 | PersonX 'd better go xWant to relax and unwind
30 | PersonX 'd better go xWant to walk away
31 | PersonX 'd better go xWant not speak to anyone
32 | PersonX accepts PersonX's diploma oEffect none
--------------------------------------------------------------------------------
/examples/data/atomic2020/sample_test.tsv:
--------------------------------------------------------------------------------
1 | PersonX abuses PersonX's power oEffect are told what to do
2 | PersonX abuses PersonX's power oEffect given unfair consequences or punishment
3 | PersonX abuses PersonX's power oEffect reach out for help
4 | PersonX abuses PersonX's power oEffect none
5 | PersonX abuses PersonX's power oReact humiliated
6 | PersonX abuses PersonX's power oReact sad
7 | PersonX abuses PersonX's power oReact angry
8 | PersonX abuses PersonX's power oReact cheated
9 | PersonX abuses PersonX's power oWant report PersonX to HR
10 | PersonX abuses PersonX's power oWant get PersonX fired
11 | PersonX abuses PersonX's power oWant to elect a new person
12 | PersonX abuses PersonX's power oWant to depose PersonX
13 | PersonX abuses PersonX's power oWant to impeach PrrsonX
14 | PersonX abuses PersonX's power oWant to punish person X
15 | PersonX abuses PersonX's power oWant to get rid of person X's atrocities
16 | PersonX abuses PersonX's power oWant to elect another leader
17 | PersonX abuses PersonX's power oWant to banish PersonX
18 | PersonX abuses PersonX's power oWant to remove PersonX
19 | PersonX abuses PersonX's power oWant to stop PersonX
20 | PersonX abuses PersonX's power oWant to depose PersonX
21 | PersonX abuses PersonX's power oWant to hurt PersonX
22 | PersonX abuses PersonX's power xAttr out of line
23 | PersonX abuses PersonX's power xAttr irresponsible
24 | PersonX abuses PersonX's power xAttr mean
25 | PersonX abuses PersonX's power xAttr confident
26 | PersonX abuses PersonX's power xAttr abusive
27 | PersonX abuses PersonX's power xAttr unreliable
28 | PersonX abuses PersonX's power xAttr abusive
29 | PersonX abuses PersonX's power xAttr careless
30 | PersonX abuses PersonX's power xEffect becomes authoratarian
31 | PersonX abuses PersonX's power xEffect is ostracized
32 | PersonX abuses PersonX's power xEffect is relieved of position
--------------------------------------------------------------------------------
/examples/data/atomic2020/sample_train.tsv:
--------------------------------------------------------------------------------
1 | PersonX abandons ___ altogether oEffect none
2 | PersonX abandons ___ altogether oEffect none
3 | PersonX abandons ___ altogether oReact dejected
4 | PersonX abandons ___ altogether oWant none
5 | PersonX abandons ___ altogether oWant none
6 | PersonX abandons ___ altogether oWant to find a new job for him
7 | PersonX abandons ___ altogether oWant to support him
8 | PersonX abandons ___ altogether xAttr impatient
9 | PersonX abandons ___ altogether xAttr decisive
10 | PersonX abandons ___ altogether xAttr undependable
11 | PersonX abandons ___ altogether xAttr fickle
12 | PersonX abandons ___ altogether xAttr destructed
13 | PersonX abandons ___ altogether xAttr sad
14 | PersonX abandons ___ altogether xEffect gets a reputation as a quitter
15 | PersonX abandons ___ altogether xEffect hangs head in shame
16 | PersonX abandons ___ altogether xEffect Begins the process of change
17 | PersonX abandons ___ altogether xEffect Turns over a new leaf
18 | PersonX abandons ___ altogether xIntent put a stop
19 | PersonX abandons ___ altogether xNeed Plows the field.
20 | PersonX abandons ___ altogether xNeed Gets exhausted from it.
21 | PersonX abandons ___ altogether xNeed none
22 | PersonX abandons ___ altogether xNeed to give a resignation letter
23 | PersonX abandons ___ altogether xNeed to get permission from his parents
24 | PersonX abandons ___ altogether xReact authoritative
25 | PersonX abandons ___ altogether xWant Sell his land.
26 | PersonX abandons ___ altogether xWant Was just city.
27 | PersonX abandons ___ altogether xWant to start something new
28 | PersonX abandons ___ altogether xWant to start fresh
29 | PersonX abandons ___ altogether xWant to find a new job
30 | PersonX abandons ___ altogether xWant to search for a new job
31 | PersonX abandons the ___ altogether oEffect none
32 | PersonX abandons the ___ altogether oEffect none
33 | PersonX abandons the ___ altogether oEffect none
34 | PersonX abandons the ___ altogether oReact defeat
35 | PersonX abandons the ___ altogether oWant none
36 | PersonX abandons the ___ altogether oWant to do something else as well
37 | PersonX abandons the ___ altogether oWant they find something better
38 | PersonX abandons the ___ altogether oWant none
39 | PersonX abandons the ___ altogether xAttr flaky
40 | PersonX abandons the ___ altogether xAttr irresponsible
41 | PersonX abandons the ___ altogether xAttr desperate
42 | PersonX abandons the ___ altogether xAttr convinced
43 | PersonX abandons the ___ altogether xAttr decisive
44 | PersonX abandons the ___ altogether xAttr frustrated
45 | PersonX abandons the ___ altogether xEffect eats all the cakes
46 | PersonX abandons the ___ altogether xEffect abandons his diets too
47 | PersonX abandons the ___ altogether xEffect repercussions for leaving all responsibilities
48 | PersonX abandons the ___ altogether xEffect they go home
49 | PersonX abandons the ___ altogether xEffect they try to form a different plan
50 | PersonX abandons the ___ altogether xEffect they search for a different alternative
51 | PersonX abandons the ___ altogether xIntent to appear not interested
52 | PersonX abandons the ___ altogether xNeed none
53 | PersonX abandons the ___ altogether xNeed to get frustrated
54 | PersonX abandons the ___ altogether xNeed to determine it's not worth it
55 | PersonX abandons the ___ altogether xNeed none
56 | PersonX abandons the ___ altogether xReact pressurized
57 | PersonX abandons the ___ altogether xWant to go out
58 | PersonX abandons the ___ altogether xWant to find other place
59 | PersonX abandons the ___ altogether xWant find something else to do
60 | PersonX abandons the ___ altogether xWant to do the project the best he can
61 | PersonX abandons the ___ altogether xWant sigh in relief
62 | PersonX abandons the ___ altogether xWant find another project
63 | PersonX abolishes ___ altogether oEffect none
64 | PersonX abolishes ___ altogether oEffect none
--------------------------------------------------------------------------------
/examples/evaluate_comet_bart.py:
--------------------------------------------------------------------------------
"""Evaluate a pretrained COMET-BART model on an ATOMIC2020 sample split."""
from kogito.core.knowledge import KnowledgeGraph
from kogito.models.bart.comet import COMETBART

# Read the evaluation tuples (JSON Lines format) into a knowledge graph.
eval_graph = KnowledgeGraph.from_jsonl("test_atomic2020_sample.json")

# Download the published COMET-BART checkpoint and score the graph;
# one generated sequence per input, large batches for throughput.
comet_bart = COMETBART.from_pretrained("mismayil/comet-bart-ai2")
metrics = comet_bart.evaluate(
    eval_graph, batch_size=256, num_return_sequences=1
)
print(metrics)
--------------------------------------------------------------------------------
/examples/evaluate_comet_gpt2.py:
--------------------------------------------------------------------------------
"""Evaluate a pretrained COMET-GPT2 model on the ATOMIC2020 test split."""
from kogito.core.knowledge import KnowledgeGraph
from kogito.models.gpt2.comet import COMETGPT2

# Load the evaluation tuples (JSON Lines format) into a knowledge graph.
input_graph = KnowledgeGraph.from_jsonl("test_atomic2020.json")

# BUG FIX: this script evaluates COMETGPT2 but previously loaded the BART
# checkpoint ("mismayil/comet-bart-ai2"), which cannot initialize a GPT-2
# model. Use the matching published GPT-2 checkpoint instead (cf. the
# sibling evaluate_comet_bart.py, which pairs COMETBART with the BART
# checkpoint).
model = COMETGPT2.from_pretrained("mismayil/comet-gpt2-ai2")
scores = model.evaluate(input_graph)

print(scores)
--------------------------------------------------------------------------------
/examples/results/cc_news/cc_news_samples_13.json:
--------------------------------------------------------------------------------
1 | {"head": "Washington", "relation": "ObjectUse", "tails": [" Washington, D.C.", " Washington, D.C", " Washington,D.C."]}
2 | {"head": "Washington", "relation": "AtLocation", "tails": [" city", " town", " cities"]}
3 | {"head": "Washington", "relation": "CapableOf", "tails": [" go to war", " go on vacation", " go to the beach"]}
4 | {"head": "Washington", "relation": "HasProperty", "tails": [" Washington, D.C.", " Washington, D.C", " Washington, D.C.,"]}
5 | {"head": "Washington", "relation": "MadeUpOf", "tails": [" Washington, D.C.", " Washington, D.C", " Washington, D.C.,"]}
6 | {"head": "Washington", "relation": "NotDesires", "tails": [" Washington, D.C.", " Washington, D.C", " Washington D.C."]}
7 | {"head": "Washington", "relation": "Desires", "tails": [" Washington, D.C.", " Washington, D.C", " Washington, D.C.,"]}
8 | {"head": "From Everett, Washington", "relation": "xAttr", "tails": [" adventurous", " sailing", " sailing ship"]}
9 | {"head": "From Everett, Washington", "relation": "xWant", "tails": [" to go to the beach", " to see the ocean", " to go home"]}
10 | {"head": "From Everett, Washington", "relation": "xNeed", "tails": [" to drive to the city", " to drive to the location", " to drive to the town"]}
11 | {"head": "From Everett, Washington", "relation": "xEffect", "tails": [" get to their destination", " get to the beach", " get to their destination."]}
12 | {"head": "From Everett, Washington", "relation": "HinderedBy", "tails": [" from there", " from a different city", " from a different state"]}
13 | {"head": "From Everett, Washington", "relation": "oWant", "tails": [" to go to the beach", " to see the sights", " to go to the movies"]}
14 | {"head": "From Everett, Washington", "relation": "xReact", "tails": [" happy", " happy to be there", " happy to be there."]}
15 | {"head": "From Everett, Washington", "relation": "oEffect", "tails": [" none", " travel", " travel to"]}
16 | {"head": "From Everett, Washington", "relation": "xIntent", "tails": [" to see the ocean", " to go to the beach", " to visit family"]}
17 | {"head": "From Everett, Washington", "relation": "oReact", "tails": [" none", " like they have a place to live", " like they have a place to go"]}
18 | {"head": "From Everett, Washington", "relation": "isBefore", "tails": [" from there to there", " from there", " from there to their home"]}
19 | {"head": "From Everett, Washington", "relation": "isAfter", "tails": [" from Seattle to Everett", " from there to there", " from there"]}
20 | {"head": "From Everett, Washington", "relation": "HasSubEvent", "tails": [" get to their destination", " get to the beach", " get off plane"]}
21 | {"head": "From Everett, Washington", "relation": "Causes", "tails": [" from there", " from a different state", " from a different place"]}
22 | {"head": "From Everett, Washington", "relation": "xReason", "tails": [" from there to there", " from there", " from the ocean"]}
23 | {"head": "Everett", "relation": "ObjectUse", "tails": [" get a new job", " have a good time", " get a job"]}
24 | {"head": "Everett", "relation": "AtLocation", "tails": [" motel room", " school", " town"]}
25 | {"head": "Everett", "relation": "CapableOf", "tails": [" go to sleep", " go to bed", " go to the store"]}
26 | {"head": "Everett", "relation": "HasProperty", "tails": [" none", " one of the two brothers", " one of the three brothers"]}
27 | {"head": "Everett", "relation": "MadeUpOf", "tails": [" ", " ", " "]}
28 | {"head": "Everett", "relation": "NotDesires", "tails": [" get in car accident", " none", " have to go to work"]}
29 | {"head": "Everett", "relation": "Desires", "tails": [" none", " get a new job", " get a job"]}
--------------------------------------------------------------------------------
/examples/results/cc_news/cc_news_samples_41.json:
--------------------------------------------------------------------------------
1 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xAttr", "tails": [" responsible", " energy-hungry", " energy-generating"]}
2 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xWant", "tails": [" to make money", " to sell the company", " to sell the business"]}
3 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xNeed", "tails": [" to invest in the company", " to invest in a company", " to invest in the business"]}
4 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xEffect", "tails": [" Gcl New Energy has a lot of debt.", " Gcl New Energy has a lot of money.", " Gcl New Energy has no money."]}
5 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "HinderedBy", "tails": [" Gcl New Energy is not a company.", " Gcl New Energy is not profitable.", " Gcl New Energy has no money."]}
6 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "oWant", "tails": [" none", " to make money", " to sell the company"]}
7 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xReact", "tails": [" good about themselves", " happy", " satisfied"]}
8 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "oEffect", "tails": [" none", " Gcl New Energy has a lot of debt.", " Gcl New Energy has no money."]}
9 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xIntent", "tails": [" to make money", " to make money.", " to make a profit"]}
10 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "oReact", "tails": [" none", " like they have something to do", " like they have a good company"]}
11 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "isBefore", "tails": [" Gcl New Energy", " Gcl New Energy Holdings", " Gcl New Energy is a company"]}
12 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "isAfter", "tails": [" Gcl New Energy", " Gcl New Energy Holdings", " Gcl New Energy Holdings Ltd"]}
13 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "HasSubEvent", "tails": [" Gcl New Energy gets a loan from the bank.", " Gcl New Energy gets a loan from a bank.", " Gcl New Energy gets a loan from the bank"]}
14 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "Causes", "tails": [" Gcl New Energy has no money.", " Gcl New Energy Holdings is not a company.", " Gcl New Energy Holdings is not a real company."]}
15 | {"head": "Gcl New Energy Holdings Ltd :", "relation": "xReason", "tails": [" Gcl New Energy is a private company.", " Gcl New Energy", " Gcl New Energy Holdings"]}
16 | {"head": "Gcl New Energy Holdings Ltd", "relation": "ObjectUse", "tails": [" to make money", " power plants", " to make a profit"]}
17 | {"head": "Gcl New Energy Holdings Ltd", "relation": "AtLocation", "tails": [" energy company", " electricity", " company"]}
18 | {"head": "Gcl New Energy Holdings Ltd", "relation": "CapableOf", "tails": [" get a loan", " power plants", " buy electricity"]}
19 | {"head": "Gcl New Energy Holdings Ltd", "relation": "HasProperty", "tails": [" Gcl New Energy Holdings", " Gcl New Energy", " Gcl New Energy is a company."]}
20 | {"head": "Gcl New Energy Holdings Ltd", "relation": "MadeUpOf", "tails": [" power plant", " electricity", " energy"]}
21 | {"head": "Gcl New Energy Holdings Ltd", "relation": "NotDesires", "tails": [" Gcl New Energy is a public company.", " Gcl New Energy is a private company.", " Gcl New Energy"]}
22 | {"head": "Gcl New Energy Holdings Ltd", "relation": "Desires", "tails": [" to make money", " none", " to make a profit"]}
--------------------------------------------------------------------------------
/examples/results/cc_news/cc_news_samples_65.json:
--------------------------------------------------------------------------------
1 | {"head": "LONDON LEAGUE", "relation": "xAttr", "tails": [" competitive", " athletic", " sporty"]}
2 | {"head": "LONDON LEAGUE", "relation": "xWant", "tails": [" to win the game", " to win the championship", " to play a game"]}
3 | {"head": "LONDON LEAGUE", "relation": "xNeed", "tails": [" find a team to play in", " find a place to play", " find a team to play for"]}
4 | {"head": "LONDON LEAGUE", "relation": "xEffect", "tails": [" wins the league", " wins the game", " wins the championship"]}
5 | {"head": "LONDON LEAGUE", "relation": "HinderedBy", "tails": [" LONDON LEAGUE", " LONDON LEAGUE ", " LONDON LEAGUE "]}
6 | {"head": "LONDON LEAGUE", "relation": "oWant", "tails": [" to win the game", " to play in the league", " to play a game"]}
7 | {"head": "LONDON LEAGUE", "relation": "xReact", "tails": [" happy", " proud", " good"]}
8 | {"head": "LONDON LEAGUE", "relation": "oEffect", "tails": [" none", " play in", " play football"]}
9 | {"head": "LONDON LEAGUE", "relation": "xIntent", "tails": [" to be competitive", " to play football", " to play in"]}
10 | {"head": "LONDON LEAGUE", "relation": "oReact", "tails": [" none", " happy", " happy."]}
11 | {"head": "LONDON LEAGUE", "relation": "isBefore", "tails": [" LONDON LEAGUE", " LONDON FOOTBALL LEAGUE", " LONDON LONDON LEAGUE"]}
12 | {"head": "LONDON LEAGUE", "relation": "isAfter", "tails": [" LONDON LEAGUE", " LONDON FOOTBALL LEAGUE", " LONDON LONDON LEAGUE"]}
13 | {"head": "LONDON LEAGUE", "relation": "HasSubEvent", "tails": [" win the game", " play in", " win the match"]}
14 | {"head": "LONDON LEAGUE", "relation": "Causes", "tails": [" LONDON LEAGUE", " LONDON LEAGUE.", " LONDON LEAGUE "]}
15 | {"head": "LONDON LEAGUE", "relation": "xReason", "tails": [" LONDON LEAGUE", " LONDON LONDON LEAGUE", " LONDON LEAGUE."]}
--------------------------------------------------------------------------------
/examples/results/cc_news/cc_news_samples_67.json:
--------------------------------------------------------------------------------
1 | {"head": "to find", "relation": "HinderedBy", "tails": [" to find ", " to find", " to find to find "]}
2 | {"head": "to find", "relation": "isBefore", "tails": [" to find", " to find something", " to find "]}
3 | {"head": "to find", "relation": "isAfter", "tails": [" to find something", " to find something to eat", " to find"]}
4 | {"head": "to find", "relation": "HasSubEvent", "tails": [" to find", " to find ", " to look for it"]}
5 | {"head": "to find", "relation": "Causes", "tails": [" to find", " to find ", " to find something"]}
6 | {"head": "to find", "relation": "xReason", "tails": [" to find", " to find something", " to find "]}
7 | {"head": "to decompose", "relation": "HinderedBy", "tails": [" to decompose", " to decompose ", " to decompose to decompose "]}
8 | {"head": "to decompose", "relation": "isBefore", "tails": [" to decompose", " to decompose in the ground", " to decompose in a lab"]}
9 | {"head": "to decompose", "relation": "isAfter", "tails": [" to decompose", " to decompose ", " decompose"]}
10 | {"head": "to decompose", "relation": "HasSubEvent", "tails": [" to decompose", " to decompose ", " to decompose in the ground"]}
11 | {"head": "to decompose", "relation": "Causes", "tails": [" to decompose", " to decompose ", " decomposing"]}
12 | {"head": "to decompose", "relation": "xReason", "tails": [" to decompose", " to decompose ", " to decompose in a lab"]}
13 | {"head": "body", "relation": "ObjectUse", "tails": [" have sex with", " body parts", " have sex"]}
14 | {"head": "body", "relation": "AtLocation", "tails": [" body", " medicine chest", " body parts"]}
15 | {"head": "body", "relation": "CapableOf", "tails": [" lie down on bed", " lie on bed", " lie down in bed"]}
16 | {"head": "body", "relation": "HasProperty", "tails": [" found in body", " body", " found in human body"]}
17 | {"head": "body", "relation": "MadeUpOf", "tails": [" body", " body part", " body parts"]}
18 | {"head": "body", "relation": "NotDesires", "tails": [" dead", " death", " body parts"]}
19 | {"head": "body", "relation": "Desires", "tails": [" body", " body parts", " feel good"]}
20 | {"head": "badly decomposed body", "relation": "ObjectUse", "tails": [" hide in the woods", " kill someone", " hide in a closet"]}
21 | {"head": "badly decomposed body", "relation": "AtLocation", "tails": [" dead body", " body bag", " corpse"]}
22 | {"head": "badly decomposed body", "relation": "CapableOf", "tails": [" dead body", " kill yourself", " bury in grave"]}
23 | {"head": "badly decomposed body", "relation": "HasProperty", "tails": [" embalmed", " dead body", " dead"]}
24 | {"head": "badly decomposed body", "relation": "MadeUpOf", "tails": [" embalmed", " dead body", " embalmed body"]}
25 | {"head": "badly decomposed body", "relation": "NotDesires", "tails": [" embalmed", " dead body", " bad smell"]}
26 | {"head": "badly decomposed body", "relation": "Desires", "tails": [" dead body", " death", " bad smell"]}
27 | {"head": "The badly decomposed body was found near...", "relation": "xAttr", "tails": [" dead", " morbid", " sadistic"]}
28 | {"head": "The badly decomposed body was found near...", "relation": "xWant", "tails": [" to bury the body", " to bury the body.", " to bury the dead body"]}
29 | {"head": "The badly decomposed body was found near...", "relation": "xNeed", "tails": [" find a body", " find a dead body", " none"]}
30 | {"head": "The badly decomposed body was found near...", "relation": "xEffect", "tails": [" is buried in a grave", " is buried", " the body is found"]}
31 | {"head": "The badly decomposed body was found near...", "relation": "HinderedBy", "tails": [" The body was found in the woods.", " The body was found in the middle of the road.", " The body was found in the middle of nowhere."]}
32 | {"head": "The badly decomposed body was found near...", "relation": "oWant", "tails": [" to bury the body", " none", " to bury the body."]}
33 | {"head": "The badly decomposed body was found near...", "relation": "xReact", "tails": [" sad", " dead", " scared"]}
34 | {"head": "The badly decomposed body was found near...", "relation": "oEffect", "tails": [" none", " the police are called to the scene", " the police find the body"]}
35 | {"head": "The badly decomposed body was found near...", "relation": "xIntent", "tails": [" none", " the body to be found", " the body to be buried"]}
36 | {"head": "The badly decomposed body was found near...", "relation": "oReact", "tails": [" none", " sad", " dead"]}
37 | {"head": "The badly decomposed body was found near...", "relation": "isBefore", "tails": [" the body to be buried", " the body to be found", " the body is buried"]}
38 | {"head": "The badly decomposed body was found near...", "relation": "isAfter", "tails": [" the body to be buried", " the body to be found", " the body was found"]}
39 | {"head": "The badly decomposed body was found near...", "relation": "HasSubEvent", "tails": [" the body to be found", " the body to be buried", " find a body"]}
40 | {"head": "The badly decomposed body was found near...", "relation": "Causes", "tails": [" the body to be found", " the body to be buried", " the body was found"]}
41 | {"head": "The badly decomposed body was found near...", "relation": "xReason", "tails": [" the body to be buried", " the body to be found", " the body was found"]}
--------------------------------------------------------------------------------
/examples/results/cc_news/cc_news_samples_7.json:
--------------------------------------------------------------------------------
1 | {"head": "Mylan NV:", "relation": "xAttr", "tails": [" hardworking", " hard-working", " hard working"]}
2 | {"head": "Mylan NV:", "relation": "xWant", "tails": [" to make money", " to sell the medicine", " to sell the product"]}
3 | {"head": "Mylan NV:", "relation": "xNeed", "tails": [" to get a prescription", " to buy the drug", " to have a prescription"]}
4 | {"head": "Mylan NV:", "relation": "xEffect", "tails": [" gets a prescription", " get a prescription", " get a prescription filled"]}
5 | {"head": "Mylan NV:", "relation": "HinderedBy", "tails": [" The pharmacy is out of stock.", " The drug is too expensive.", " The pharmacy is closed."]}
6 | {"head": "Mylan NV:", "relation": "oWant", "tails": [" none", " to get rid of it", " to get rid of the drug"]}
7 | {"head": "Mylan NV:", "relation": "xReact", "tails": [" happy", " satisfied", " good"]}
8 | {"head": "Mylan NV:", "relation": "oEffect", "tails": [" none", " they get a prescription for it", " they get a prescription for the medicine"]}
9 | {"head": "Mylan NV:", "relation": "xIntent", "tails": [" none", " to get rid of a disease", " to get rid of the pain"]}
10 | {"head": "Mylan NV:", "relation": "oReact", "tails": [" none", " happy", " happy."]}
11 | {"head": "Mylan NV:", "relation": "isBefore", "tails": [" drugstore", " drug company", " drug"]}
12 | {"head": "Mylan NV:", "relation": "isAfter", "tails": [" drugstore", " medicine", " drug store"]}
13 | {"head": "Mylan NV:", "relation": "HasSubEvent", "tails": [" get a prescription", " get a prescription filled", " get a prescription for it"]}
14 | {"head": "Mylan NV:", "relation": "Causes", "tails": [" the medicine is not working", " the drug is not working", " the medicine to be safe"]}
15 | {"head": "Mylan NV:", "relation": "xReason", "tails": [" have to pay for it", " have to pay for the medication", " have to pay for the medicine"]}
16 | {"head": "Mylan NV", "relation": "ObjectUse", "tails": [" make a pill", " make pills", " sell to customers"]}
17 | {"head": "Mylan NV", "relation": "AtLocation", "tails": [" medicine chest", " drugstore", " drug store"]}
18 | {"head": "Mylan NV", "relation": "CapableOf", "tails": [" sell to customers", " sell drugs", " sell to others"]}
19 | {"head": "Mylan NV", "relation": "HasProperty", "tails": [" sold in the United States", " sold in the US", " good for heart health"]}
20 | {"head": "Mylan NV", "relation": "MadeUpOf", "tails": [" drug", " medicine", " medicine chest"]}
21 | {"head": "Mylan NV", "relation": "NotDesires", "tails": [" have to pay for it", " have to pay for drugs", " have to pay for the product"]}
22 | {"head": "Mylan NV", "relation": "Desires", "tails": [" to be a drug company", " to be able to take medicine", " to be able to sell drugs"]}
--------------------------------------------------------------------------------
/examples/results/gpt3_prompt_2c976a77.txt:
--------------------------------------------------------------------------------
1 | What needs to be true for this event to take place?
2 |
3 | PersonX accepts PersonY appointment. Before, PersonX needs to to clear spot in schedule
4 |
5 | PersonX accepts PersonY thanks. Before, PersonX needs to to says thanks to Y
6 |
7 | PersonX abandons ___ altogether. Before, PersonX needs to Plows the field.
8 |
9 | PersonX accompanies PersonY far. Before, PersonX needs to to walk with PersonY
10 |
11 | PersonX always drank. Before, PersonX needs to to get a drink
12 |
13 | PersonX about to get married. Before, PersonX needs to meet someone
14 |
15 | PersonX accidentally fell. Before, PersonX needs to to walk too fast
16 |
17 | PersonX accidentally bumped. Before, PersonX needs to run or drive fast
18 |
19 | PersonX accepts PersonY's proposal. Before, PersonX needs to to date
20 |
21 | PersonX accomplishes PersonX's goal. Before, PersonX needs to to make a goal
22 |
23 | PersonX aces PersonX's exam. Before, PersonX needs to
24 |
25 | What needs to be true for this event to take place?
26 |
27 | PersonX accepts PersonY appointment. Before, PersonX needs to to clear spot in schedule
28 |
29 | PersonX accepts PersonY thanks. Before, PersonX needs to to says thanks to Y
30 |
31 | PersonX abandons ___ altogether. Before, PersonX needs to Plows the field.
32 |
33 | PersonX accompanies PersonY far. Before, PersonX needs to to walk with PersonY
34 |
35 | PersonX always drank. Before, PersonX needs to to get a drink
36 |
37 | PersonX about to get married. Before, PersonX needs to meet someone
38 |
39 | PersonX accidentally fell. Before, PersonX needs to to walk too fast
40 |
41 | PersonX accidentally bumped. Before, PersonX needs to run or drive fast
42 |
43 | PersonX accepts PersonY's proposal. Before, PersonX needs to to date
44 |
45 | PersonX accomplishes PersonX's goal. Before, PersonX needs to to make a goal
46 |
47 | PersonX accuses PersonY of cheating. Before, PersonX needs to
--------------------------------------------------------------------------------
/examples/results/gpt3_prompt_56fcb840.txt:
--------------------------------------------------------------------------------
1 | What needs to be true for this event to take place?
2 |
3 | PersonX accompanies PersonY far. Before, PersonX needs to to walk with PersonY
4 |
5 | PersonX about to get married. Before, PersonX needs to meet someone
6 |
7 | PersonX always drank. Before, PersonX needs to to get a drink
8 |
9 | PersonX accidentally fell. Before, PersonX needs to to walk too fast
10 |
11 | PersonX accepts PersonY's proposal. Before, PersonX needs to to date
12 |
13 | PersonX accepts PersonY appointment. Before, PersonX needs to to clear spot in schedule
14 |
15 | PersonX accepts PersonY thanks. Before, PersonX needs to to says thanks to Y
16 |
17 | PersonX abandons ___ altogether. Before, PersonX needs to Plows the field.
18 |
19 | PersonX accidentally bumped. Before, PersonX needs to run or drive fast
20 |
21 | PersonX accomplishes PersonX's goal. Before, PersonX needs to to make a goal
22 |
23 | PersonX sees PersonY's point. Before, PersonX needs to
24 |
25 | What needs to be true for this event to take place?
26 |
27 | PersonX accompanies PersonY far. Before, PersonX needs to to walk with PersonY
28 |
29 | PersonX about to get married. Before, PersonX needs to meet someone
30 |
31 | PersonX always drank. Before, PersonX needs to to get a drink
32 |
33 | PersonX accidentally fell. Before, PersonX needs to to walk too fast
34 |
35 | PersonX accepts PersonY's proposal. Before, PersonX needs to to date
36 |
37 | PersonX accepts PersonY appointment. Before, PersonX needs to to clear spot in schedule
38 |
39 | PersonX accepts PersonY thanks. Before, PersonX needs to to says thanks to Y
40 |
41 | PersonX abandons ___ altogether. Before, PersonX needs to Plows the field.
42 |
43 | PersonX accidentally bumped. Before, PersonX needs to run or drive fast
44 |
45 | PersonX accomplishes PersonX's goal. Before, PersonX needs to to make a goal
46 |
47 | PersonX makes a huge mistake. Before, PersonX needs to
--------------------------------------------------------------------------------
/examples/results/gpt3_prompt_99e17fae.txt:
--------------------------------------------------------------------------------
1 | How does this situation affect each character's wishes?
2 |
3 | Situation 1: PersonX calls PersonY
4 | Wishes: As a result, PersonX wishes to have a long chat
5 |
6 | Situation 2: PersonX bleeds a lot
7 | Wishes: As a result, PersonX wishes to see a doctor
8 |
9 | Situation 3: PersonX works as a cashier
10 | Wishes: As a result, PersonX wishes to be a store manager
11 |
12 | Situation 4: PersonX stays up all night studying
13 | Wishes: As a result, PersonX wishes to sleep all day
14 |
15 | Situation 5: PersonX makes his own costume
16 | Wishes: As a result, PersonX wishes to go to a costume party
17 |
18 | Situation 6: PersonX tells PersonY a secret
19 | Wishes: As a result, PersonX wishes to get PersonY's advice
20 |
21 | Situation 7: PersonX gets dirty
22 | Wishes: As a result, PersonX wishes to clean up
23 |
24 | Situation 8: PersonX ends a friendship
25 | Wishes: As a result, PersonX wishes to meet new people
26 |
27 | Situation 9: PersonX gets PersonY's autograph
28 | Wishes: As a result, PersonX wishes to have a relationship with PersonY
29 |
30 | Situation 10: PersonX mows the lawn
31 | Wishes: As a result, PersonX wishes to get a new lawnmower
32 |
33 | Situation 11: PersonX sees PersonY's point
34 | Wishes: As a result, PersonX wishes
35 |
36 | How does this situation affect each character's wishes?
37 |
38 | Situation 1: PersonX calls PersonY
39 | Wishes: As a result, PersonX wishes to have a long chat
40 |
41 | Situation 2: PersonX bleeds a lot
42 | Wishes: As a result, PersonX wishes to see a doctor
43 |
44 | Situation 3: PersonX works as a cashier
45 | Wishes: As a result, PersonX wishes to be a store manager
46 |
47 | Situation 4: PersonX stays up all night studying
48 | Wishes: As a result, PersonX wishes to sleep all day
49 |
50 | Situation 5: PersonX makes his own costume
51 | Wishes: As a result, PersonX wishes to go to a costume party
52 |
53 | Situation 6: PersonX tells PersonY a secret
54 | Wishes: As a result, PersonX wishes to get PersonY's advice
55 |
56 | Situation 7: PersonX gets dirty
57 | Wishes: As a result, PersonX wishes to clean up
58 |
59 | Situation 8: PersonX ends a friendship
60 | Wishes: As a result, PersonX wishes to meet new people
61 |
62 | Situation 9: PersonX gets PersonY's autograph
63 | Wishes: As a result, PersonX wishes to have a relationship with PersonY
64 |
65 | Situation 10: PersonX mows the lawn
66 | Wishes: As a result, PersonX wishes to get a new lawnmower
67 |
68 | Situation 11: PersonX makes a huge mistake
69 | Wishes: As a result, PersonX wishes
--------------------------------------------------------------------------------
/examples/results/kgraph_dry_run.json:
--------------------------------------------------------------------------------
1 | {"head": "Gabby always brought cookies to school.", "relation": "xWant", "tails": []}
2 | {"head": "Gabby always brought cookies to school.", "relation": "HinderedBy", "tails": []}
3 | {"head": "school", "relation": "HasProperty", "tails": []}
4 | {"head": "cookies", "relation": "CapableOf", "tails": []}
5 | {"head": "bring cookies", "relation": "isAfter", "tails": []}
6 | {"head": "Gabby", "relation": "xNeed", "tails": []}
7 | {"head": "Gabby", "relation": "AtLocation", "tails": []}
8 | {"head": "school", "relation": "Causes", "tails": []}
9 | {"head": "Gabby", "relation": "MadeUpOf", "tails": []}
10 | {"head": "to bring", "relation": "Causes", "tails": []}
11 | {"head": "Gabby always brought cookies to school.", "relation": "xAttr", "tails": []}
12 | {"head": "to bring", "relation": "isFilledBy", "tails": []}
13 | {"head": "school", "relation": "ObjectUse", "tails": []}
14 | {"head": "cookies", "relation": "MadeUpOf", "tails": []}
15 | {"head": "to bring", "relation": "xReason", "tails": []}
16 | {"head": "cookies", "relation": "AtLocation", "tails": []}
17 | {"head": "cookies", "relation": "xNeed", "tails": []}
18 | {"head": "school", "relation": "NotDesires", "tails": []}
19 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []}
20 | {"head": "Gabby always brought cookies to school.", "relation": "isFilledBy", "tails": []}
21 | {"head": "school", "relation": "CapableOf", "tails": []}
22 | {"head": "to bring", "relation": "isBefore", "tails": []}
23 | {"head": "Gabby always brought cookies to school.", "relation": "xReact", "tails": []}
24 | {"head": "Gabby always brought cookies to school.", "relation": "xReason", "tails": []}
25 | {"head": "bring cookies", "relation": "HinderedBy", "tails": []}
26 | {"head": "to bring", "relation": "HasSubEvent", "tails": []}
27 | {"head": "Gabby always brought cookies to school.", "relation": "isBefore", "tails": []}
28 | {"head": "school", "relation": "AtLocation", "tails": []}
29 | {"head": "school", "relation": "xNeed", "tails": []}
30 | {"head": "Gabby always brought cookies to school.", "relation": "oReact", "tails": []}
31 | {"head": "to bring", "relation": "xNeed", "tails": []}
32 | {"head": "school", "relation": "MadeUpOf", "tails": []}
33 | {"head": "to bring", "relation": "isAfter", "tails": []}
34 | {"head": "Gabby always brought cookies to school.", "relation": "HasSubEvent", "tails": []}
35 | {"head": "Gabby", "relation": "Desires", "tails": []}
36 | {"head": "Gabby always brought cookies to school.", "relation": "xIntent", "tails": []}
37 | {"head": "Gabby always brought cookies to school.", "relation": "xNeed", "tails": []}
38 | {"head": "Gabby", "relation": "HasProperty", "tails": []}
39 | {"head": "Gabby always brought cookies to school.", "relation": "xEffect", "tails": []}
40 | {"head": "bring cookies", "relation": "Causes", "tails": []}
41 | {"head": "bring cookies", "relation": "isFilledBy", "tails": []}
42 | {"head": "Gabby always brought cookies to school.", "relation": "isAfter", "tails": []}
43 | {"head": "bring cookies", "relation": "xReason", "tails": []}
44 | {"head": "Gabby", "relation": "Causes", "tails": []}
45 | {"head": "cookies", "relation": "Desires", "tails": []}
46 | {"head": "bring cookies", "relation": "isBefore", "tails": []}
47 | {"head": "cookies", "relation": "HasProperty", "tails": []}
48 | {"head": "Gabby", "relation": "ObjectUse", "tails": []}
49 | {"head": "cookies", "relation": "Causes", "tails": []}
50 | {"head": "Gabby", "relation": "NotDesires", "tails": []}
51 | {"head": "Gabby always brought cookies to school.", "relation": "oEffect", "tails": []}
52 | {"head": "bring cookies", "relation": "xNeed", "tails": []}
53 | {"head": "cookies", "relation": "ObjectUse", "tails": []}
54 | {"head": "to bring", "relation": "HinderedBy", "tails": []}
55 | {"head": "bring cookies", "relation": "HasSubEvent", "tails": []}
56 | {"head": "Gabby", "relation": "CapableOf", "tails": []}
57 | {"head": "cookies", "relation": "NotDesires", "tails": []}
58 | {"head": "school", "relation": "Desires", "tails": []}
59 | {"head": "Gabby always brought cookies to school.", "relation": "oWant", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_dry_run_2.json:
--------------------------------------------------------------------------------
1 | {"head": "he didnt listen to me", "relation": "xReason", "tails": []}
2 | {"head": "he didnt listen to me", "relation": "HinderedBy", "tails": []}
3 | {"head": "to want", "relation": "xNeed", "tails": []}
4 | {"head": "he didnt listen to me", "relation": "xNeed", "tails": []}
5 | {"head": "to listen", "relation": "HinderedBy", "tails": []}
6 | {"head": "I wanted to feed him.", "relation": "oEffect", "tails": []}
7 | {"head": "I wanted to feed him.", "relation": "xIntent", "tails": []}
8 | {"head": "to listen", "relation": "xReason", "tails": []}
9 | {"head": "I wanted to feed him.", "relation": "isAfter", "tails": []}
10 | {"head": "I", "relation": "HasProperty", "tails": []}
11 | {"head": "to listen", "relation": "xNeed", "tails": []}
12 | {"head": "I wanted to feed him.", "relation": "oReact", "tails": []}
13 | {"head": "he didnt listen to me", "relation": "oEffect", "tails": []}
14 | {"head": "I", "relation": "MadeUpOf", "tails": []}
15 | {"head": "I", "relation": "ObjectUse", "tails": []}
16 | {"head": "I", "relation": "NotDesires", "tails": []}
17 | {"head": "to want", "relation": "isAfter", "tails": []}
18 | {"head": "he didnt listen to me", "relation": "xIntent", "tails": []}
19 | {"head": "I wanted to feed him.", "relation": "HasSubEvent", "tails": []}
20 | {"head": "I wanted to feed him.", "relation": "xEffect", "tails": []}
21 | {"head": "I wanted to feed him.", "relation": "isFilledBy", "tails": []}
22 | {"head": "he didnt listen to me", "relation": "isAfter", "tails": []}
23 | {"head": "he didnt listen to me", "relation": "oReact", "tails": []}
24 | {"head": "I", "relation": "CapableOf", "tails": []}
25 | {"head": "I", "relation": "Desires", "tails": []}
26 | {"head": "to listen", "relation": "isAfter", "tails": []}
27 | {"head": "to want", "relation": "HasSubEvent", "tails": []}
28 | {"head": "he didnt listen to me", "relation": "xEffect", "tails": []}
29 | {"head": "to want", "relation": "isFilledBy", "tails": []}
30 | {"head": "I wanted to feed him.", "relation": "isBefore", "tails": []}
31 | {"head": "he didnt listen to me", "relation": "HasSubEvent", "tails": []}
32 | {"head": "he didnt listen to me", "relation": "isFilledBy", "tails": []}
33 | {"head": "I", "relation": "AtLocation", "tails": []}
34 | {"head": "to want", "relation": "isBefore", "tails": []}
35 | {"head": "I wanted to feed him.", "relation": "xReact", "tails": []}
36 | {"head": "to listen", "relation": "HasSubEvent", "tails": []}
37 | {"head": "to listen", "relation": "isFilledBy", "tails": []}
38 | {"head": "I", "relation": "Causes", "tails": []}
39 | {"head": "he didnt listen to me", "relation": "isBefore", "tails": []}
40 | {"head": "I wanted to feed him.", "relation": "oWant", "tails": []}
41 | {"head": "I wanted to feed him.", "relation": "Causes", "tails": []}
42 | {"head": "to listen", "relation": "isBefore", "tails": []}
43 | {"head": "I wanted to feed him.", "relation": "xAttr", "tails": []}
44 | {"head": "I wanted to feed him.", "relation": "xWant", "tails": []}
45 | {"head": "he didnt listen to me", "relation": "xReact", "tails": []}
46 | {"head": "to want", "relation": "Causes", "tails": []}
47 | {"head": "he didnt listen to me", "relation": "oWant", "tails": []}
48 | {"head": "he didnt listen to me", "relation": "Causes", "tails": []}
49 | {"head": "I wanted to feed him.", "relation": "xReason", "tails": []}
50 | {"head": "I wanted to feed him.", "relation": "HinderedBy", "tails": []}
51 | {"head": "I", "relation": "xNeed", "tails": []}
52 | {"head": "he didnt listen to me", "relation": "xAttr", "tails": []}
53 | {"head": "to listen", "relation": "Causes", "tails": []}
54 | {"head": "he didnt listen to me", "relation": "xWant", "tails": []}
55 | {"head": "to want", "relation": "xReason", "tails": []}
56 | {"head": "to want", "relation": "HinderedBy", "tails": []}
57 | {"head": "I wanted to feed him.", "relation": "xNeed", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_gpt3.json:
--------------------------------------------------------------------------------
1 | {"head": "PersonX aces PersonX's exam", "relation": "xNeed", "tails": [" study"]}
2 | {"head": "PersonX accuses PersonY of cheating", "relation": "xNeed", "tails": [" to find out for sure"]}
--------------------------------------------------------------------------------
/examples/results/kgraph_gpt3_custom_relation.json:
--------------------------------------------------------------------------------
1 | {"head": "PersonX sees PersonY's point", "relation": "xNeed", "tails": [" to listen to PersonY"]}
2 | {"head": "PersonX sees PersonY's point", "relation": "xWishes", "tails": [" to apologize to PersonY"]}
3 | {"head": "PersonX makes a huge mistake", "relation": "xNeed", "tails": [""]}
4 | {"head": "PersonX makes a huge mistake", "relation": "xWishes", "tails": [" to fix the mistake"]}
--------------------------------------------------------------------------------
/examples/results/kgraph_gpt3_dry_run.json:
--------------------------------------------------------------------------------
1 | {"head": "PersonX accuses PersonY of cheating", "relation": "xNeed", "tails": ["\n\n1. Check if there is any rule in the game that PersonY"]}
2 | {"head": "PersonX aces PersonX's exam", "relation": "xNeed", "tails": ["\n\nstudy."]}
--------------------------------------------------------------------------------
/examples/results/kgraph_manual.json:
--------------------------------------------------------------------------------
1 | {"head": "Gabby always brought cookies to school.", "relation": "Desires", "tails": []}
2 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_manual_heads.json:
--------------------------------------------------------------------------------
1 | {"head": "post office", "relation": "xReason", "tails": []}
2 | {"head": "to get out of the room", "relation": "xWant", "tails": []}
3 | {"head": "post office", "relation": "isBefore", "tails": []}
4 | {"head": "to get out of the room", "relation": "HinderedBy", "tails": []}
5 | {"head": "to get out of the room", "relation": "oWant", "tails": []}
6 | {"head": "to get out of the room", "relation": "xAttr", "tails": []}
7 | {"head": "post office", "relation": "oReact", "tails": []}
8 | {"head": "post office", "relation": "xIntent", "tails": []}
9 | {"head": "post office", "relation": "HasSubEvent", "tails": []}
10 | {"head": "post office", "relation": "xNeed", "tails": []}
11 | {"head": "to get out of the room", "relation": "Causes", "tails": []}
12 | {"head": "post office", "relation": "xEffect", "tails": []}
13 | {"head": "to get out of the room", "relation": "isFilledBy", "tails": []}
14 | {"head": "to get out of the room", "relation": "xReact", "tails": []}
15 | {"head": "post office", "relation": "isAfter", "tails": []}
16 | {"head": "to get out of the room", "relation": "xReason", "tails": []}
17 | {"head": "to get out of the room", "relation": "isBefore", "tails": []}
18 | {"head": "post office", "relation": "Causes", "tails": []}
19 | {"head": "to get out of the room", "relation": "oReact", "tails": []}
20 | {"head": "post office", "relation": "oEffect", "tails": []}
21 | {"head": "post office", "relation": "oWant", "tails": []}
22 | {"head": "to get out of the room", "relation": "HasSubEvent", "tails": []}
23 | {"head": "to get out of the room", "relation": "xIntent", "tails": []}
24 | {"head": "to get out of the room", "relation": "xNeed", "tails": []}
25 | {"head": "to get out of the room", "relation": "xEffect", "tails": []}
26 | {"head": "post office", "relation": "xWant", "tails": []}
27 | {"head": "post office", "relation": "HinderedBy", "tails": []}
28 | {"head": "to get out of the room", "relation": "isAfter", "tails": []}
29 | {"head": "post office", "relation": "xAttr", "tails": []}
30 | {"head": "to get out of the room", "relation": "oEffect", "tails": []}
31 | {"head": "post office", "relation": "isFilledBy", "tails": []}
32 | {"head": "post office", "relation": "xReact", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_modelbased_relations.json:
--------------------------------------------------------------------------------
1 | {"head": "banana", "relation": "ObjectUse", "tails": []}
2 | {"head": "banana", "relation": "CapableOf", "tails": []}
3 | {"head": "banana", "relation": "MadeUpOf", "tails": []}
4 | {"head": "banana", "relation": "HasProperty", "tails": []}
5 | {"head": "banana", "relation": "Desires", "tails": []}
6 | {"head": "banana", "relation": "NotDesires", "tails": []}
7 | {"head": "banana", "relation": "AtLocation", "tails": []}
8 | {"head": "love another", "relation": "Causes", "tails": []}
9 | {"head": "love another", "relation": "HinderedBy", "tails": []}
10 | {"head": "love another", "relation": "xReason", "tails": []}
11 | {"head": "love another", "relation": "isAfter", "tails": []}
12 | {"head": "love another", "relation": "isBefore", "tails": []}
13 | {"head": "love another", "relation": "HasSubEvent", "tails": []}
14 | {"head": "love another", "relation": "isFilledBy", "tails": []}
15 | {"head": "love another", "relation": "xIntent", "tails": []}
16 | {"head": "love another", "relation": "xReact", "tails": []}
17 | {"head": "love another", "relation": "oReact", "tails": []}
18 | {"head": "love another", "relation": "xAttr", "tails": []}
19 | {"head": "love another", "relation": "xEffect", "tails": []}
20 | {"head": "love another", "relation": "xNeed", "tails": []}
21 | {"head": "love another", "relation": "xWant", "tails": []}
22 | {"head": "love another", "relation": "oEffect", "tails": []}
23 | {"head": "love another", "relation": "oWant", "tails": []}
24 | {"head": "Student gets a card", "relation": "Causes", "tails": []}
25 | {"head": "Student gets a card", "relation": "HinderedBy", "tails": []}
26 | {"head": "Student gets a card", "relation": "xReason", "tails": []}
27 | {"head": "Student gets a card", "relation": "isAfter", "tails": []}
28 | {"head": "Student gets a card", "relation": "isBefore", "tails": []}
29 | {"head": "Student gets a card", "relation": "HasSubEvent", "tails": []}
30 | {"head": "Student gets a card", "relation": "isFilledBy", "tails": []}
31 | {"head": "Student gets a card", "relation": "xIntent", "tails": []}
32 | {"head": "Student gets a card", "relation": "xReact", "tails": []}
33 | {"head": "Student gets a card", "relation": "oReact", "tails": []}
34 | {"head": "Student gets a card", "relation": "xAttr", "tails": []}
35 | {"head": "Student gets a card", "relation": "xEffect", "tails": []}
36 | {"head": "Student gets a card", "relation": "xNeed", "tails": []}
37 | {"head": "Student gets a card", "relation": "xWant", "tails": []}
38 | {"head": "Student gets a card", "relation": "oEffect", "tails": []}
39 | {"head": "Student gets a card", "relation": "oWant", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_modelbased_relations_bert.json:
--------------------------------------------------------------------------------
1 | {"head": "Student gets a card", "relation": "Causes", "tails": []}
2 | {"head": "banana", "relation": "CapableOf", "tails": []}
3 | {"head": "Student gets a card", "relation": "isFilledBy", "tails": []}
4 | {"head": "love another", "relation": "isAfter", "tails": []}
5 | {"head": "Student gets a card", "relation": "xReact", "tails": []}
6 | {"head": "Student gets a card", "relation": "xReason", "tails": []}
7 | {"head": "Student gets a card", "relation": "oEffect", "tails": []}
8 | {"head": "Student gets a card", "relation": "isBefore", "tails": []}
9 | {"head": "love another", "relation": "Causes", "tails": []}
10 | {"head": "love another", "relation": "isFilledBy", "tails": []}
11 | {"head": "Student gets a card", "relation": "oWant", "tails": []}
12 | {"head": "love another", "relation": "HasSubEvent", "tails": []}
13 | {"head": "banana", "relation": "AtLocation", "tails": []}
14 | {"head": "banana", "relation": "Desires", "tails": []}
15 | {"head": "love another", "relation": "xReason", "tails": []}
16 | {"head": "Student gets a card", "relation": "oReact", "tails": []}
17 | {"head": "banana", "relation": "MadeUpOf", "tails": []}
18 | {"head": "love another", "relation": "isBefore", "tails": []}
19 | {"head": "Student gets a card", "relation": "xWant", "tails": []}
20 | {"head": "banana", "relation": "HasProperty", "tails": []}
21 | {"head": "Student gets a card", "relation": "HinderedBy", "tails": []}
22 | {"head": "Student gets a card", "relation": "HasSubEvent", "tails": []}
23 | {"head": "Student gets a card", "relation": "xIntent", "tails": []}
24 | {"head": "Student gets a card", "relation": "xAttr", "tails": []}
25 | {"head": "Student gets a card", "relation": "xNeed", "tails": []}
26 | {"head": "banana", "relation": "ObjectUse", "tails": []}
27 | {"head": "Student gets a card", "relation": "xEffect", "tails": []}
28 | {"head": "love another", "relation": "HinderedBy", "tails": []}
29 | {"head": "banana", "relation": "NotDesires", "tails": []}
30 | {"head": "Student gets a card", "relation": "isAfter", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_modelbased_relations_dbert.json:
--------------------------------------------------------------------------------
1 | {"head": "love another", "relation": "xEffect", "tails": []}
2 | {"head": "banana", "relation": "CapableOf", "tails": []}
3 | {"head": "Student gets a card", "relation": "ObjectUse", "tails": []}
4 | {"head": "love another", "relation": "isAfter", "tails": []}
5 | {"head": "love another", "relation": "Causes", "tails": []}
6 | {"head": "love another", "relation": "isFilledBy", "tails": []}
7 | {"head": "Student gets a card", "relation": "NotDesires", "tails": []}
8 | {"head": "Student gets a card", "relation": "HasProperty", "tails": []}
9 | {"head": "banana", "relation": "AtLocation", "tails": []}
10 | {"head": "love another", "relation": "xReact", "tails": []}
11 | {"head": "Student gets a card", "relation": "CapableOf", "tails": []}
12 | {"head": "banana", "relation": "Desires", "tails": []}
13 | {"head": "love another", "relation": "oEffect", "tails": []}
14 | {"head": "banana", "relation": "MadeUpOf", "tails": []}
15 | {"head": "love another", "relation": "xReason", "tails": []}
16 | {"head": "love another", "relation": "isBefore", "tails": []}
17 | {"head": "love another", "relation": "oWant", "tails": []}
18 | {"head": "banana", "relation": "HasProperty", "tails": []}
19 | {"head": "love another", "relation": "oReact", "tails": []}
20 | {"head": "banana", "relation": "ObjectUse", "tails": []}
21 | {"head": "Student gets a card", "relation": "AtLocation", "tails": []}
22 | {"head": "Student gets a card", "relation": "Desires", "tails": []}
23 | {"head": "love another", "relation": "xWant", "tails": []}
24 | {"head": "love another", "relation": "HinderedBy", "tails": []}
25 | {"head": "Student gets a card", "relation": "MadeUpOf", "tails": []}
26 | {"head": "banana", "relation": "NotDesires", "tails": []}
27 | {"head": "love another", "relation": "HasSubEvent", "tails": []}
28 | {"head": "love another", "relation": "xAttr", "tails": []}
29 | {"head": "love another", "relation": "xNeed", "tails": []}
30 | {"head": "love another", "relation": "xIntent", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_modelbased_relations_swem.json:
--------------------------------------------------------------------------------
1 | {"head": "love another", "relation": "xEffect", "tails": []}
2 | {"head": "Student gets a card", "relation": "Causes", "tails": []}
3 | {"head": "banana", "relation": "CapableOf", "tails": []}
4 | {"head": "Student gets a card", "relation": "isFilledBy", "tails": []}
5 | {"head": "love another", "relation": "isAfter", "tails": []}
6 | {"head": "Student gets a card", "relation": "xReact", "tails": []}
7 | {"head": "Student gets a card", "relation": "xReason", "tails": []}
8 | {"head": "Student gets a card", "relation": "isBefore", "tails": []}
9 | {"head": "banana", "relation": "AtLocation", "tails": []}
10 | {"head": "love another", "relation": "oEffect", "tails": []}
11 | {"head": "banana", "relation": "MadeUpOf", "tails": []}
12 | {"head": "Student gets a card", "relation": "oReact", "tails": []}
13 | {"head": "love another", "relation": "oWant", "tails": []}
14 | {"head": "Student gets a card", "relation": "HasSubEvent", "tails": []}
15 | {"head": "Student gets a card", "relation": "xIntent", "tails": []}
16 | {"head": "Student gets a card", "relation": "xNeed", "tails": []}
17 | {"head": "love another", "relation": "xWant", "tails": []}
18 | {"head": "Student gets a card", "relation": "xEffect", "tails": []}
19 | {"head": "love another", "relation": "HinderedBy", "tails": []}
20 | {"head": "Student gets a card", "relation": "isAfter", "tails": []}
21 | {"head": "love another", "relation": "xAttr", "tails": []}
22 | {"head": "Student gets a card", "relation": "oEffect", "tails": []}
23 | {"head": "love another", "relation": "Causes", "tails": []}
24 | {"head": "love another", "relation": "isFilledBy", "tails": []}
25 | {"head": "Student gets a card", "relation": "oWant", "tails": []}
26 | {"head": "love another", "relation": "xReact", "tails": []}
27 | {"head": "banana", "relation": "Desires", "tails": []}
28 | {"head": "love another", "relation": "xReason", "tails": []}
29 | {"head": "love another", "relation": "isBefore", "tails": []}
30 | {"head": "Student gets a card", "relation": "xWant", "tails": []}
31 | {"head": "banana", "relation": "HasProperty", "tails": []}
32 | {"head": "Student gets a card", "relation": "HinderedBy", "tails": []}
33 | {"head": "Student gets a card", "relation": "xAttr", "tails": []}
34 | {"head": "love another", "relation": "oReact", "tails": []}
35 | {"head": "banana", "relation": "ObjectUse", "tails": []}
36 | {"head": "banana", "relation": "NotDesires", "tails": []}
37 | {"head": "love another", "relation": "HasSubEvent", "tails": []}
38 | {"head": "love another", "relation": "xNeed", "tails": []}
39 | {"head": "love another", "relation": "xIntent", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_no_head_extract.json:
--------------------------------------------------------------------------------
1 | {"head": "Gabby always brought cookies to school.", "relation": "xWant", "tails": []}
2 | {"head": "Gabby always brought cookies to school.", "relation": "HinderedBy", "tails": []}
3 | {"head": "Gabby always brought cookies to school.", "relation": "HasSubEvent", "tails": []}
4 | {"head": "Gabby always brought cookies to school.", "relation": "xIntent", "tails": []}
5 | {"head": "Gabby always brought cookies to school.", "relation": "xNeed", "tails": []}
6 | {"head": "Gabby always brought cookies to school.", "relation": "xEffect", "tails": []}
7 | {"head": "Gabby always brought cookies to school.", "relation": "xAttr", "tails": []}
8 | {"head": "Gabby always brought cookies to school.", "relation": "isAfter", "tails": []}
9 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []}
10 | {"head": "Gabby always brought cookies to school.", "relation": "isFilledBy", "tails": []}
11 | {"head": "Gabby always brought cookies to school.", "relation": "xReact", "tails": []}
12 | {"head": "Gabby always brought cookies to school.", "relation": "xReason", "tails": []}
13 | {"head": "Gabby always brought cookies to school.", "relation": "oEffect", "tails": []}
14 | {"head": "Gabby always brought cookies to school.", "relation": "isBefore", "tails": []}
15 | {"head": "Gabby always brought cookies to school.", "relation": "oWant", "tails": []}
16 | {"head": "Gabby always brought cookies to school.", "relation": "oReact", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_no_match_subset.json:
--------------------------------------------------------------------------------
1 | {"head": "to bring", "relation": "Desires", "tails": []}
2 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []}
3 | {"head": "cookies", "relation": "Causes", "tails": []}
4 | {"head": "Gabby", "relation": "Causes", "tails": []}
5 | {"head": "school", "relation": "Causes", "tails": []}
6 | {"head": "Gabby", "relation": "Desires", "tails": []}
7 | {"head": "bring cookies", "relation": "Causes", "tails": []}
8 | {"head": "to bring", "relation": "Causes", "tails": []}
9 | {"head": "Gabby always brought cookies to school.", "relation": "Desires", "tails": []}
10 | {"head": "school", "relation": "Desires", "tails": []}
11 | {"head": "cookies", "relation": "Desires", "tails": []}
12 | {"head": "bring cookies", "relation": "Desires", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_rel_subset.json:
--------------------------------------------------------------------------------
1 | {"head": "Gabby", "relation": "xNeed", "tails": []}
2 | {"head": "school", "relation": "Causes", "tails": []}
3 | {"head": "Gabby always brought cookies to school.", "relation": "xNeed", "tails": []}
4 | {"head": "bring cookies", "relation": "Causes", "tails": []}
5 | {"head": "to bring", "relation": "Causes", "tails": []}
6 | {"head": "school", "relation": "ObjectUse", "tails": []}
7 | {"head": "Gabby", "relation": "Causes", "tails": []}
8 | {"head": "cookies", "relation": "xNeed", "tails": []}
9 | {"head": "Gabby", "relation": "ObjectUse", "tails": []}
10 | {"head": "Gabby always brought cookies to school.", "relation": "Causes", "tails": []}
11 | {"head": "cookies", "relation": "Causes", "tails": []}
12 | {"head": "cookies", "relation": "ObjectUse", "tails": []}
13 | {"head": "school", "relation": "xNeed", "tails": []}
14 | {"head": "bring cookies", "relation": "xNeed", "tails": []}
15 | {"head": "to bring", "relation": "xNeed", "tails": []}
--------------------------------------------------------------------------------
/examples/results/kgraph_with_context.json:
--------------------------------------------------------------------------------
1 | {"head": "PersonX wraps gifts", "relation": "oReact", "tails": [" happy", " grateful"]}
2 | {"head": "gifts", "relation": "AtLocation", "tails": [" store", " gift box"]}
3 | {"head": "gifts", "relation": "MadeUpOf", "tails": [" give as a gift", " gift", " give as gift"]}
4 | {"head": "PersonX wraps gifts", "relation": "HasSubEvent", "tails": [" PersonX receives a gift", " PersonX receives a gift."]}
5 | {"head": "PersonX wraps gifts", "relation": "xIntent", "tails": [" to be nice", " to give a gift", " to be generous"]}
6 | {"head": "PersonX wraps gifts", "relation": "xNeed", "tails": [" to buy wrapping paper", " to buy the wrapping paper", " to buy wrapping paper."]}
7 | {"head": "PersonX wraps gifts", "relation": "xEffect", "tails": [" PersonX receives a gift in return.", " PersonX receives a gift.", " PersonX receives a gift in return"]}
8 | {"head": "PersonX wraps gifts", "relation": "isAfter", "tails": [" PersonX buys a gift", " PersonX buys a gift for PersonY"]}
9 | {"head": "wraps gifts", "relation": "Desires", "tails": [" wrap presents", " wrap gift", " give to person"]}
10 | {"head": "wraps gifts", "relation": "HasProperty", "tails": [" wrapping paper", " wrapped", " wrapped paper"]}
11 | {"head": "wraps", "relation": "Desires", "tails": [" wrap gift", " wrap presents", " wrap presents in"]}
12 | {"head": "PersonX wraps gifts", "relation": "oEffect", "tails": [" PersonY opens the gift.", " PersonY opens the gift"]}
13 | {"head": "wraps gifts", "relation": "ObjectUse", "tails": [" wrap presents", " wrap presents in", " give to person Y"]}
14 | {"head": "wraps", "relation": "HasProperty", "tails": [" wrapping paper", " wrapped", " wrapped in paper"]}
15 | {"head": "PersonX wraps gifts", "relation": "oWant", "tails": [" to thank PersonX", " to open the gift"]}
16 | {"head": "wraps", "relation": "ObjectUse", "tails": [" wrap the gift in", " wrap a gift in", " wrap a present in"]}
17 | {"head": "gifts", "relation": "Desires", "tails": [" give as a gift", " give to someone", " give to person"]}
18 | {"head": "PersonX wraps gifts", "relation": "xWant", "tails": [" to give the gifts to the person", " to give the gift to the person", " to give the gifts to their friends"]}
19 | {"head": "wraps gifts", "relation": "CapableOf", "tails": [" wrap gift", " wrap presents", " wrap"]}
20 | {"head": "PersonX wraps gifts", "relation": "HinderedBy", "tails": [" PersonX doesn't have any wrapping paper.", " PersonX has no wrapping paper.", " PersonX doesn't know how to wrap."]}
21 | {"head": "gifts", "relation": "HasProperty", "tails": [" given as gifts", " gifts", " give to person"]}
22 | {"head": "PersonX wraps gifts", "relation": "xAttr", "tails": [" thoughtful", " generous", " giving"]}
23 | {"head": "wraps", "relation": "CapableOf", "tails": [" wrap gift", " wrap gift in", " wrap presents in"]}
24 | {"head": "gifts", "relation": "ObjectUse", "tails": [" give as a gift", " give as a birthday gift", " give to someone"]}
25 | {"head": "wraps gifts", "relation": "MadeUpOf", "tails": [" wrapping paper", " wrap paper", " wrap"]}
26 | {"head": "wraps gifts", "relation": "AtLocation", "tails": [" box"]}
27 | {"head": "PersonX wraps gifts", "relation": "Causes", "tails": [" PersonX wants to be nice.", " PersonX wants to be generous.", " PersonX wants to be nice"]}
28 | {"head": "PersonX wraps gifts", "relation": "isFilledBy", "tails": [" gift bag", " presents", " gifts"]}
29 | {"head": "PersonX wraps gifts", "relation": "xReact", "tails": [" happy"]}
30 | {"head": "wraps", "relation": "MadeUpOf", "tails": [" wrapping paper", " wrap"]}
31 | {"head": "gifts", "relation": "CapableOf", "tails": [" give to person", " give to people", " give to persony"]}
32 | {"head": "PersonX wraps gifts", "relation": "xReason", "tails": [" PersonX buys a gift", " PersonX buys a gift for PersonY"]}
33 | {"head": "PersonX wraps gifts", "relation": "isBefore", "tails": [" PersonX puts the gifts under the tree", " PersonX puts the gifts in the gift bag", " PersonX gives the gifts to the kids"]}
--------------------------------------------------------------------------------
/examples/results/kgraph_without_context.json:
--------------------------------------------------------------------------------
1 | {"head": "PersonX wraps gifts", "relation": "oReact", "tails": [" none", " happy", " grateful"]}
2 | {"head": "gifts", "relation": "AtLocation", "tails": [" gift shop", " store", " gift box"]}
3 | {"head": "gifts", "relation": "MadeUpOf", "tails": [" give as a gift", " gift", " give as gift"]}
4 | {"head": "PersonX wraps gifts", "relation": "HasSubEvent", "tails": [" PersonX opens the gift", " PersonX receives a gift", " PersonX receives a gift."]}
5 | {"head": "PersonX wraps gifts", "relation": "xIntent", "tails": [" to be nice", " to give a gift", " to be generous"]}
6 | {"head": "PersonX wraps gifts", "relation": "xNeed", "tails": [" to buy wrapping paper", " to buy the wrapping paper", " to buy wrapping paper."]}
7 | {"head": "PersonX wraps gifts", "relation": "xEffect", "tails": [" PersonX receives a gift in return.", " PersonX receives a gift.", " PersonX receives a gift in return"]}
8 | {"head": "PersonX wraps gifts", "relation": "isAfter", "tails": [" PersonX goes to the store", " PersonX buys a gift", " PersonX buys a gift for PersonY"]}
9 | {"head": "wraps gifts", "relation": "Desires", "tails": [" wrap presents", " wrap gift", " give to person"]}
10 | {"head": "wraps gifts", "relation": "HasProperty", "tails": [" wrapping paper", " wrapped", " wrapped paper"]}
11 | {"head": "wraps", "relation": "Desires", "tails": [" wrap gift", " wrap presents", " wrap presents in"]}
12 | {"head": "PersonX wraps gifts", "relation": "oEffect", "tails": [" none", " PersonY opens the gift.", " PersonY opens the gift"]}
13 | {"head": "wraps gifts", "relation": "ObjectUse", "tails": [" wrap presents", " wrap presents in", " give to person Y"]}
14 | {"head": "wraps", "relation": "HasProperty", "tails": [" wrapping paper", " wrapped", " wrapped in paper"]}
15 | {"head": "PersonX wraps gifts", "relation": "oWant", "tails": [" none", " to thank PersonX", " to open the gift"]}
16 | {"head": "wraps gifts", "relation": "NotDesires", "tails": [" wrapping paper", " wrap presents", " wrap gift"]}
17 | {"head": "wraps", "relation": "ObjectUse", "tails": [" wrap the gift in", " wrap a gift in", " wrap a present in"]}
18 | {"head": "gifts", "relation": "Desires", "tails": [" give as a gift", " give to someone", " give to person"]}
19 | {"head": "PersonX wraps gifts", "relation": "xWant", "tails": [" to give the gifts to the person", " to give the gift to the person", " to give the gifts to their friends"]}
20 | {"head": "wraps", "relation": "NotDesires", "tails": [" wrap around body", " wrap up a gift", " wrap up a present"]}
21 | {"head": "wraps gifts", "relation": "CapableOf", "tails": [" wrap gift", " wrap presents", " wrap"]}
22 | {"head": "PersonX wraps gifts", "relation": "HinderedBy", "tails": [" PersonX doesn't have any wrapping paper.", " PersonX has no wrapping paper.", " PersonX doesn't know how to wrap."]}
23 | {"head": "gifts", "relation": "HasProperty", "tails": [" given as gifts", " gifts", " give to person"]}
24 | {"head": "PersonX wraps gifts", "relation": "xAttr", "tails": [" thoughtful", " generous", " giving"]}
25 | {"head": "wraps", "relation": "CapableOf", "tails": [" wrap gift", " wrap gift in", " wrap presents in"]}
26 | {"head": "gifts", "relation": "ObjectUse", "tails": [" give as a gift", " give as a birthday gift", " give to someone"]}
27 | {"head": "wraps gifts", "relation": "MadeUpOf", "tails": [" wrapping paper", " wrap paper", " wrap"]}
28 | {"head": "wraps gifts", "relation": "AtLocation", "tails": [" gift shop", " mail box", " box"]}
29 | {"head": "gifts", "relation": "NotDesires", "tails": [" give as a gift", " give to person", " give to someone"]}
30 | {"head": "PersonX wraps gifts", "relation": "Causes", "tails": [" PersonX wants to be nice.", " PersonX wants to be generous.", " PersonX wants to be nice"]}
31 | {"head": "PersonX wraps gifts", "relation": "isFilledBy", "tails": [" gift bag", " presents", " gifts"]}
32 | {"head": "wraps", "relation": "AtLocation", "tails": [" store", " box", " mail box"]}
33 | {"head": "PersonX wraps gifts", "relation": "xReact", "tails": [" happy", " good about themselves", " good about themselves."]}
34 | {"head": "wraps", "relation": "MadeUpOf", "tails": [" wrapping paper", " wrap", " wrap cloth"]}
35 | {"head": "gifts", "relation": "CapableOf", "tails": [" give to person", " give to people", " give to persony"]}
36 | {"head": "PersonX wraps gifts", "relation": "xReason", "tails": [" PersonX buys a lot of wrapping paper", " PersonX buys a gift", " PersonX buys a gift for PersonY"]}
37 | {"head": "PersonX wraps gifts", "relation": "isBefore", "tails": [" PersonX puts the gifts under the tree", " PersonX puts the gifts in the gift bag", " PersonX gives the gifts to the kids"]}
--------------------------------------------------------------------------------
/examples/results/test_atomic2020_res_cometgpt2_sample.json:
--------------------------------------------------------------------------------
1 | {"head": "PersonX takes things for granted", "relation": "xNeed", "tails": ["to have a lot of work to do"]}
2 | {"head": "PersonX calls PersonY ambulance", "relation": "xNeed", "tails": ["to pick up the phone"]}
3 | {"head": "PersonX pleases ___ to make", "relation": "xWant", "tails": ["to be successful"]}
4 | {"head": "PersonX shoves PersonY back", "relation": "xEffect", "tails": ["none"]}
5 | {"head": "PersonX dates for years", "relation": "xWant", "tails": ["to get married"]}
6 | {"head": "PersonX covers every aspect", "relation": "isAfter", "tails": ["PersonX is a lawyer"]}
7 | {"head": "PersonX wants to go", "relation": "isAfter", "tails": ["PersonX is in the hospital"]}
8 | {"head": "PersonX hits by lightning", "relation": "xEffect", "tails": ["is in pain"]}
9 | {"head": "PersonX finally meet PersonY", "relation": "xNeed", "tails": ["to go to the meeting"]}
10 | {"head": "chain", "relation": "ObjectUse", "tails": ["tie the dog up"]}
--------------------------------------------------------------------------------
/examples/results/test_atomic2020_res_zeroshot_sample.jsonl:
--------------------------------------------------------------------------------
1 | {"head": "PersonX takes things for granted", "relation": "xNeed", "tails": ["be able to do something that is not possible with other people. Now, PersonX needs to be able to do something that is not possible with other people", "be able to do something that is not possible with other people. Now, PersonX needs to be able to do something that is not possible with other people", "be able to do something that is not possible with other people. Now, PersonX needs to be able to do something that is not possible with other people"]}
2 | {"head": "PersonX calls PersonY ambulance", "relation": "xNeed", "tails": ["get a call from PersonY.", "get a call from PersonY.", "get a call from PersonY."]}
3 | {"head": "PersonX pleases ___ to make", "relation": "xWant", "tails": ["____ the person who is not a member of the group.", "____ the person who is not a member of the group.", "____ the person who is not a member of the group."]}
4 | {"head": "PersonX shoves PersonY back", "relation": "xEffect", "tails": ["increased by 1.5x.", "increased by 1.5x.", "increased by 1.5x."]}
5 | {"head": "PersonX dates for years", "relation": "xWant", "tails": ["have a relationship with you.", "have a relationship with you.", "have a relationship with you."]}
6 | {"head": "PersonX covers every aspect", "relation": "isAfter", "tails": ["I've covered the basics of how to use the app, how to use the app's built-in camera, how to use the app's built-", "I've covered the basics of how to use the app, how to use the app's built-in camera, how to use the app's built-", "I've covered the basics of how to use the app, how to use the app's built-in camera, how to use the app's built-"]}
7 | {"head": "PersonX wants to go", "relation": "isAfter", "tails": ["he's been working on his own project.", "he's been working on his own project.", "he's been working on his own project."]}
8 | {"head": "PersonX hits by lightning", "relation": "xEffect", "tails": ["increased by 1% per level.", "increased by 1% per level.", "increased by 1% per level."]}
9 | {"head": "PersonX finally meet PersonY", "relation": "xNeed", "tails": ["get a job and PersonY needs to get a job.", "get a job and PersonY needs to get a job.", "get a job and PersonY needs to get a job."]}
10 | {"head": "chain", "relation": "ObjectUse", "tails": ["the following:", "the following:", "the same purpose."]}
--------------------------------------------------------------------------------
/examples/sample_graph.jsonl:
--------------------------------------------------------------------------------
1 | {"source": "PersonX buys lunch", "rel": "xNeed", "tails": ["bring a wallet"]}
2 | {"source": "Throwing a party", "rel": "Causes", "tails": ["have fun"]}
--------------------------------------------------------------------------------
/examples/sample_graph.tsv:
--------------------------------------------------------------------------------
1 | PersonX always drank xNeed to get a drink
2 | PersonX abandons ___ altogether xNeed Plows the field.
3 | PersonX about to get married xNeed meet someone
4 | PersonX accepts PersonY appointment xNeed to clear spot in schedule
5 | PersonX accepts PersonY thanks xNeed to says thanks to Y
6 | PersonX accepts PersonY's proposal xNeed to date
7 | PersonX accidentally bumped xNeed run or drive fast
8 | PersonX accidentally fell xNeed to walk too fast
9 | PersonX accompanies PersonY far xNeed to walk with PersonY
10 | PersonX accomplishes PersonX's goal xNeed to make a goal
--------------------------------------------------------------------------------
/examples/sample_graph2.tsv:
--------------------------------------------------------------------------------
1 | PersonX always drank xNeed to get a drink
2 | PersonX abandons ___ altogether xNeed Plows the field.
3 | PersonX about to get married xNeed meet someone
4 | PersonX accepts PersonY appointment xNeed to clear spot in schedule
5 | PersonX accepts PersonY thanks xNeed to says thanks to Y
6 | PersonX accepts PersonY's proposal xNeed to date
7 | PersonX accidentally bumped xNeed run or drive fast
8 | PersonX accidentally fell xNeed to walk too fast
9 | PersonX accompanies PersonY far xNeed to walk with PersonY
10 | PersonX accomplishes PersonX's goal xNeed to make a goal
11 | PersonX is at a party xWishes to drink beer and dance
12 | PersonX bleeds a lot xWishes to see a doctor
13 | PersonX works as a cashier xWishes to be a store manager
14 | PersonX gets dirty xWishes to clean up
15 | PersonX stays up all night studying xWishes to sleep all day
16 | PersonX gets PersonY's autograph xWishes to have a relationship with PersonY
17 | PersonX ends a friendship xWishes to meet new people
18 | PersonX makes his own costume xWishes to go to a costume party
19 | PersonX calls PersonY xWishes to have a long chat
20 | PersonX tells PersonY a secret xWishes to get PersonY's advice
21 | PersonX mows the lawn xWishes to get a new lawnmower
22 |
--------------------------------------------------------------------------------
/examples/sample_graph3.jsonl:
--------------------------------------------------------------------------------
1 | {"head": "PersonX accepts PersonY appointment", "relation": "xNeed", "tails": ["to clear spot in schedule"]}
2 | {"head": "PersonX accepts PersonY thanks", "relation": "xNeed", "tails": ["to says thanks to Y"]}
3 | {"head": "PersonX abandons ___ altogether", "relation": "xNeed", "tails": ["Plows the field."]}
4 | {"head": "PersonX accompanies PersonY far", "relation": "xNeed", "tails": ["to walk with PersonY"]}
5 | {"head": "PersonX always drank", "relation": "xNeed", "tails": ["to get a drink"]}
6 | {"head": "PersonX about to get married", "relation": "xNeed", "tails": ["meet someone"]}
7 | {"head": "PersonX accidentally fell", "relation": "xNeed", "tails": ["to walk too fast"]}
8 | {"head": "PersonX accidentally bumped", "relation": "xNeed", "tails": ["run or drive fast"]}
9 | {"head": "PersonX accepts PersonY's proposal", "relation": "xNeed", "tails": ["to date"]}
10 | {"head": "PersonX accomplishes PersonX's goal", "relation": "xNeed", "tails": ["to make a goal"]}
--------------------------------------------------------------------------------
/examples/sample_linking_graph.csv:
--------------------------------------------------------------------------------
1 | PersonX drives ___ to work | xWant | his friend to get to work on time
2 | PersonX drives ___ fast | xWant | to get out of the car
3 | drive | HasSubEvent | get into car
--------------------------------------------------------------------------------
/examples/snapshots/custom-relation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/examples/snapshots/custom-relation.png
--------------------------------------------------------------------------------
/examples/snapshots/kg-concepts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/examples/snapshots/kg-concepts.png
--------------------------------------------------------------------------------
/examples/snapshots/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/examples/snapshots/pipeline.png
--------------------------------------------------------------------------------
/examples/snapshots/quickstart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/examples/snapshots/quickstart.png
--------------------------------------------------------------------------------
/examples/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from tqdm import tqdm
3 | import pathlib
4 |
5 | from kogito.inference import CommonsenseInference
6 | from kogito.models.bart.comet import COMETBART
7 |
8 |
9 | def main():  # CLI entry point: batch commonsense inference over a line-delimited text file
10 |     parser = argparse.ArgumentParser()
11 |     parser.add_argument("--datapath", type=str)  # input text file; each line is inferred on separately
12 |     parser.add_argument("--output-dir", type=str)  # directory receiving one JSON file per input line
13 | 
14 |     args = parser.parse_args()
15 | 
16 |     with open(args.datapath) as f:  # lines keep their trailing newline; presumably tolerated by csi.infer — TODO confirm
17 |         data = f.readlines()
18 | 
19 |     csi = CommonsenseInference()  # default inference pipeline (no custom processors configured here)
20 |     model = COMETBART.from_pretrained()  # downloads/loads the default pretrained COMET-BART checkpoint
21 | 
22 |     for i, line in tqdm(
23 |         enumerate(data), total=len(data), desc="Generating", position=0, leave=True
24 |     ):
25 |         kgraph = csi.infer(line, model)  # returns a knowledge graph for this single input text
26 |         kgraph.to_jsonl(
27 |             f"{args.output_dir}/{pathlib.Path(args.datapath).stem}_{i+1}.json"  # 1-based suffix, e.g. story_samples_1.json
28 |         )
29 | 
30 | 
31 | if __name__ == "__main__":
32 |     main()
33 |
--------------------------------------------------------------------------------
/examples/test_atomic2020_sample.json:
--------------------------------------------------------------------------------
1 | {"relation": "xNeed", "head": "PersonX takes things for granted", "tails": ["to have been lazy at work", "to have wasted resources", "none", "to have used their money on unnecessary things"]}
2 | {"relation": "xNeed", "head": "PersonX calls PersonY ambulance", "tails": ["get the phone", "recognize the person needs an ambulance", "to see Person Y is hurt.", "to get a phone.", "none"]}
3 | {"relation": "xWant", "head": "PersonX pleases ___ to make", "tails": ["to say thanks for making it", "to eat it", "to produce something", "to be in their good graces"]}
4 | {"relation": "xEffect", "head": "PersonX shoves PersonY back", "tails": ["is physical", "none", "is violent"]}
5 | {"relation": "xWant", "head": "PersonX dates for years", "tails": ["to keep in contact for 2nd dates", "to propose", "to go to the movies", "to go to dinner", "to get married"]}
6 | {"relation": "isAfter", "head": "PersonX covers every aspect", "tails": ["PersonX researches an article about current events"]}
7 | {"relation": "isAfter", "head": "PersonX wants to go", "tails": ["PersonX sees a flyer for a festival"]}
8 | {"relation": "xEffect", "head": "PersonX hits by lightning", "tails": ["Has heart attack", "goes to the hospital", "Has hair burned", "gets hurt"]}
9 | {"relation": "xNeed", "head": "PersonX finally meet PersonY", "tails": ["ask friend to be introduced to PersonY", "Get dressed", "Find person", "to search for person Y", "travels to meet PersonY", "to be away from person Y for sometime"]}
10 | {"relation": "ObjectUse", "head": "chain", "tails": ["hold the swing", "put around it", "lock up a vicious pet.", "make spooky noises in a haunted house", "attach the mooring", "lock the bike", "lock up the bike while they look around", "hold something on the wall", "to chain bicycle to the bicycle rack", "take with them", "decorate the house", "hold down a garage door without locks.", "take off", "keep pick-pockets from stealing their wallet", "pull the car", "hold their spectacles on their neck", "hold the key", "catch pants on", "connect the truck and the car", "lock it up", "hook it up too", "tug on a truck"]}
--------------------------------------------------------------------------------
/examples/test_comet_bart.py:
--------------------------------------------------------------------------------
1 | from kogito.core.knowledge import KnowledgeGraph
2 | from kogito.models.bart.comet import COMETBART
3 | 
4 | input_graph = KnowledgeGraph.from_jsonl("test_atomic2020_sample.json")  # heads/relations to complete (JSONL despite .json extension)
5 | 
6 | model = COMETBART.from_pretrained("mismayil/comet-bart-ai2")  # pretrained COMET-BART checkpoint from the HF hub
7 | output_graph = model.generate(
8 |     input_graph, batch_size=256, num_return_sequences=1  # one generated tail per input tuple
9 | )
10 | output_graph.to_jsonl("results.json")  # NOTE(review): writes to cwd, unlike sibling scripts that write under results/ — confirm intended
--------------------------------------------------------------------------------
/examples/test_comet_gpt2.py:
--------------------------------------------------------------------------------
1 | from kogito.core.knowledge import KnowledgeGraph
2 | from kogito.models.gpt2.comet import COMETGPT2
3 | 
4 | model = COMETGPT2.from_pretrained("mismayil/comet-gpt2-ai2")  # pretrained COMET-GPT2 checkpoint from the HF hub
5 | input_graph = KnowledgeGraph.from_jsonl("./test_atomic2020_sample.json")  # heads/relations to complete (JSONL format)
6 | output_graph = model.generate(input_graph)  # default generation settings
7 | output_graph.to_jsonl("results/test_atomic2020_res_cometgpt2_sample.json")  # assumes results/ already exists — TODO confirm
8 | 
--------------------------------------------------------------------------------
/examples/test_gpt2.json:
--------------------------------------------------------------------------------
1 | {"head": "The researcher tweets about a new library", "relation": "xAttr", "tails": ["a fan of the library. PersonY is a critic. PersonZ is a neutral observer.", "a fan of the library. PersonY is a critic. PersonZ is a neutral observer.", "a fan of the library. PersonY is a critic. PersonZ is a neutral observer."]}
2 | {"head": "The researcher tweets about a new library", "relation": "xWant", "tails": ["find out more about the library. PersonX will then tweet about the library. PersonX will then tweet about the library again. PersonX will then tweet", "find out more about the library. PersonX will then tweet about the library. PersonX will then tweet about the library again. PersonX will then tweet", "find out more about the library. PersonX will then tweet about the library. PersonX will then tweet about the library again. PersonX will then tweet"]}
3 | {"head": "The researcher tweets about a new library", "relation": "xNeed", "tails": ["find the library and then find the person who owns it. Now, PersonX can find the library and then find the person who owns it. This is", "find the library and then find the person who owns it. Now, PersonX can find the library and then find the person who owns it. This is", "find the library and then find the person who owns it. Now, PersonX can find the library and then find the person who owns it. This is"]}
4 | {"head": "The researcher tweets about a new library", "relation": "xEffect", "tails": ["that they will be more likely to visit the library.", "that they will be more likely to visit the library.", "that they will be more likely to visit the library."]}
5 | {"head": "The researcher tweets about a new library", "relation": "HinderedBy", "tails": ["the library was not open to the public.", "the library was not open to the public.", "the library was not open to the public."]}
6 | {"head": "The researcher tweets about a new library", "relation": "oWant", "tails": ["use it. The researcher tweets about a new library. After, others will want to use it. The researcher tweets about a new library. After, others", "use it. The researcher tweets about a new library. After, others will want to use it. The researcher tweets about a new library. After, others", "use it. The researcher tweets about a new library. After, others will want to use it. The researcher tweets about a new library. After, others"]}
7 | {"head": "The researcher tweets about a new library", "relation": "xReact", "tails": ["a library for people who are interested in the humanities. PersonY will be a library for people who are interested in the sciences.", "a library for people who are interested in the humanities. PersonY will be a library for people who are interested in the sciences.", "a library for people who are interested in the humanities. PersonY will be a library for people who are interested in the sciences."]}
8 | {"head": "The researcher tweets about a new library", "relation": "oEffect", "tails": ["positive or negative depending on the social network.", "positive or negative depending on the social network.", "positive or negative depending on the social network."]}
9 | {"head": "The researcher tweets about a new library", "relation": "xIntent", "tails": ["the library. PersonY did this to the library. PersonZ did this to the library.", "the library. PersonY did this to the library. PersonZ did this to the library.", "the library. PersonY did this to the library. PersonZ did this to the library."]}
10 | {"head": "The researcher tweets about a new library", "relation": "oReact", "tails": ["that they have to follow suit.", "that they have to follow suit.", "that they have to follow suit."]}
11 | {"head": "The researcher tweets about a new library", "relation": "isBefore", "tails": ["the library's Twitter account tweets about the researcher.", "the library's Twitter account tweets about the researcher.", "the library's Twitter account tweets about the researcher."]}
12 | {"head": "The researcher tweets about a new library", "relation": "isAfter", "tails": ["he tweets about a new library. Before that, he tweets about a new library. Before that, he tweets about a new library. Before that, he", "he tweets about a new library. Before that, he tweets about a new library. Before that, he tweets about a new library. Before that, he", "he tweets about a new library. Before that, he tweets about a new library. Before that, he tweets about a new library. Before that, he"]}
13 | {"head": "The researcher tweets about a new library", "relation": "HasSubEvent", "tails": ["be able to see the library's name, location, and the number of books in the library.", "be able to see the library's name, location, and the number of books in the library.", "be able to see the library's name, location, and the number of books in the library."]}
14 | {"head": "The researcher tweets about a new library", "relation": "Causes", "tails": ["a flurry of activity on Twitter. \u00a0The library is then mentioned in the news, and the library's website is mentioned in the news.", "a flurry of activity on Twitter. \u00a0The library is then mentioned in the news, and the library's website is mentioned in the news.", "a flurry of activity on Twitter. \u00a0The library is then mentioned in the news, and the library's website is mentioned in the news."]}
15 | {"head": "The researcher tweets about a new library", "relation": "xReason", "tails": ["he was bored. PersonY did this because he was bored. PersonZ did this because he was bored.", "he was bored. PersonY did this because he was bored. PersonZ did this because he was bored.", "he was bored. PersonY did this because he was bored. PersonZ did this because he was bored."]}
--------------------------------------------------------------------------------
/examples/test_zeroshot.py:
--------------------------------------------------------------------------------
1 | from kogito.core.knowledge import KnowledgeGraph
2 | from kogito.models.gpt2.zeroshot import GPT2Zeroshot
3 | 
4 | input_graph = KnowledgeGraph.from_jsonl("./test_atomic2020_sample.json")  # same sample input as the COMET scripts
5 | 
6 | model = GPT2Zeroshot()  # vanilla GPT-2 used zero-shot, no fine-tuned checkpoint loaded
7 | output_graph = model.generate(input_graph)
8 | output_graph.to_jsonl("results/test_atomic2020_res_zeroshot_sample.jsonl")  # assumes results/ already exists — TODO confirm
9 | 
--------------------------------------------------------------------------------
/examples/train_comet_bart.py:
--------------------------------------------------------------------------------
1 | from kogito.core.knowledge import KnowledgeGraph
2 | from kogito.models.bart.comet import COMETBART, COMETBARTConfig
3 | 
4 | 
5 | config = COMETBARTConfig(  # hyperparameters for fine-tuning BART into COMET
6 |     output_dir="models/comet-bart",
7 |     task="summarization",
8 |     n_val=100,  # presumably number of validation examples used — verify against COMETBARTConfig
9 |     num_workers=1,
10 |     learning_rate=1e-5,
11 |     gpus=0,  # 0 => CPU-only run; presumably forwarded to the trainer — TODO confirm
12 |     sortish_sampler=True,
13 |     atomic=True,  # presumably enables ATOMIC-specific data handling — verify
14 |     train_batch_size=32,
15 |     eval_batch_size=32,
16 |     max_epochs=1,
17 |     pretrained_model="facebook/bart-large",  # base checkpoint to fine-tune
18 | )
19 | model = COMETBART(config)
20 | train_graph = KnowledgeGraph.from_csv(  # headerless, tab-separated ATOMIC-2020 splits
21 |     "data/atomic2020_data-feb2021/train.tsv", header=None, sep="\t"
22 | )
23 | val_graph = KnowledgeGraph.from_csv(
24 |     "data/atomic2020_data-feb2021/dev.tsv", header=None, sep="\t"
25 | )
26 | test_graph = KnowledgeGraph.from_csv(
27 |     "data/atomic2020_data-feb2021/test.tsv", header=None, sep="\t"
28 | )
29 | model.train(
30 |     train_graph=train_graph,
31 |     val_graph=val_graph,
32 |     test_graph=test_graph,
33 |     logger_name="wandb",  # log metrics to Weights & Biases; requires wandb to be configured
34 | )
35 | 
--------------------------------------------------------------------------------
/examples/train_comet_gpt2.py:
--------------------------------------------------------------------------------
1 | from kogito.core.knowledge import KnowledgeGraph
2 | from kogito.models.gpt2.comet import COMETGPT2
3 | 
4 | model = COMETGPT2("gpt2-xl")  # start from the base gpt2-xl checkpoint (not a pretrained COMET)
5 | train_graph = KnowledgeGraph.from_csv(  # headerless, tab-separated ATOMIC-2020 samples
6 |     "data/atomic2020/sample_train.tsv", header=None, sep="\t"
7 | )
8 | val_graph = KnowledgeGraph.from_csv(
9 |     "data/atomic2020/sample_dev.tsv", header=None, sep="\t"
10 | )
11 | model.train(
12 |     train_graph=train_graph,
13 |     val_graph=val_graph,
14 |     batch_size=32,
15 |     output_dir="models/comet-gpt2",  # fine-tuned checkpoint destination
16 |     epochs=1,
17 |     lr=5e-5,
18 | )
19 | 
--------------------------------------------------------------------------------
/experiments/relation_modeling/spacy/base_config.cfg:
--------------------------------------------------------------------------------
1 | # This is an auto-generated partial config. To use it with 'spacy train'
2 | # you can run spacy init fill-config to auto-fill all default settings:
3 | # python -m spacy init fill-config ./base_config.cfg ./config.cfg
4 | [paths]
5 | train = "data/train.spacy"
6 | dev = "data/dev.spacy"
7 | vectors = "en_core_web_lg"
8 | [system]
9 | gpu_allocator = null
10 |
11 | [nlp]
12 | lang = "en"
13 | pipeline = ["tok2vec","textcat_multilabel"]
14 | batch_size = 1000
15 |
16 | [components]
17 |
18 | [components.tok2vec]
19 | factory = "tok2vec"
20 |
21 | [components.tok2vec.model]
22 | @architectures = "spacy.Tok2Vec.v2"
23 |
24 | [components.tok2vec.model.embed]
25 | @architectures = "spacy.MultiHashEmbed.v2"
26 | width = ${components.tok2vec.model.encode.width}
27 | attrs = ["ORTH", "SHAPE"]
28 | rows = [5000, 2500]
29 | include_static_vectors = true
30 |
31 | [components.tok2vec.model.encode]
32 | @architectures = "spacy.MaxoutWindowEncoder.v2"
33 | width = 256
34 | depth = 8
35 | window_size = 1
36 | maxout_pieces = 3
37 |
38 | [components.textcat_multilabel]
39 | factory = "textcat_multilabel"
40 |
41 | [components.textcat_multilabel.model]
42 | @architectures = "spacy.TextCatEnsemble.v2"
43 | nO = null
44 |
45 | [components.textcat_multilabel.model.tok2vec]
46 | @architectures = "spacy.Tok2VecListener.v1"
47 | width = ${components.tok2vec.model.encode.width}
48 |
49 | [components.textcat_multilabel.model.linear_model]
50 | @architectures = "spacy.TextCatBOW.v2"
51 | exclusive_classes = false
52 | ngram_size = 1
53 | no_output_layer = false
54 |
55 | [corpora]
56 |
57 | [corpora.train]
58 | @readers = "spacy.Corpus.v1"
59 | path = ${paths.train}
60 | max_length = 0
61 |
62 | [corpora.dev]
63 | @readers = "spacy.Corpus.v1"
64 | path = ${paths.dev}
65 | max_length = 0
66 |
67 | [training]
68 | dev_corpus = "corpora.dev"
69 | train_corpus = "corpora.train"
70 |
71 | [training.optimizer]
72 | @optimizers = "Adam.v1"
73 |
74 | [training.batcher]
75 | @batchers = "spacy.batch_by_words.v1"
76 | discard_oversize = false
77 | tolerance = 0.2
78 |
79 | [training.batcher.size]
80 | @schedules = "compounding.v1"
81 | start = 100
82 | stop = 1000
83 | compound = 1.001
84 |
85 | [initialize]
86 | vectors = ${paths.vectors}
--------------------------------------------------------------------------------
/experiments/relation_modeling/spacy/config.cfg:
--------------------------------------------------------------------------------
1 | [paths]
2 | train = "data/train.spacy"
3 | dev = "data/dev.spacy"
4 | vectors = "en_core_web_lg"
5 | init_tok2vec = null
6 |
7 | [system]
8 | gpu_allocator = null
9 | seed = 0
10 |
11 | [nlp]
12 | lang = "en"
13 | pipeline = ["tok2vec","textcat_multilabel"]
14 | batch_size = 1000
15 | disabled = []
16 | before_creation = null
17 | after_creation = null
18 | after_pipeline_creation = null
19 | tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
20 |
21 | [components]
22 |
23 | [components.textcat_multilabel]
24 | factory = "textcat_multilabel"
25 | scorer = {"@scorers":"spacy.textcat_multilabel_scorer.v1"}
26 | threshold = 0.5
27 |
28 | [components.textcat_multilabel.model]
29 | @architectures = "spacy.TextCatEnsemble.v2"
30 | nO = null
31 |
32 | [components.textcat_multilabel.model.linear_model]
33 | @architectures = "spacy.TextCatBOW.v2"
34 | exclusive_classes = false
35 | ngram_size = 1
36 | no_output_layer = false
37 | nO = null
38 |
39 | [components.textcat_multilabel.model.tok2vec]
40 | @architectures = "spacy.Tok2VecListener.v1"
41 | width = ${components.tok2vec.model.encode.width}
42 | upstream = "*"
43 |
44 | [components.tok2vec]
45 | factory = "tok2vec"
46 |
47 | [components.tok2vec.model]
48 | @architectures = "spacy.Tok2Vec.v2"
49 |
50 | [components.tok2vec.model.embed]
51 | @architectures = "spacy.MultiHashEmbed.v2"
52 | width = ${components.tok2vec.model.encode.width}
53 | attrs = ["ORTH","SHAPE"]
54 | rows = [5000,2500]
55 | include_static_vectors = true
56 |
57 | [components.tok2vec.model.encode]
58 | @architectures = "spacy.MaxoutWindowEncoder.v2"
59 | width = 256
60 | depth = 8
61 | window_size = 1
62 | maxout_pieces = 3
63 |
64 | [corpora]
65 |
66 | [corpora.dev]
67 | @readers = "spacy.Corpus.v1"
68 | path = ${paths.dev}
69 | max_length = 0
70 | gold_preproc = false
71 | limit = 0
72 | augmenter = null
73 |
74 | [corpora.train]
75 | @readers = "spacy.Corpus.v1"
76 | path = ${paths.train}
77 | max_length = 0
78 | gold_preproc = false
79 | limit = 0
80 | augmenter = null
81 |
82 | [training]
83 | dev_corpus = "corpora.dev"
84 | train_corpus = "corpora.train"
85 | seed = ${system.seed}
86 | gpu_allocator = ${system.gpu_allocator}
87 | dropout = 0.1
88 | accumulate_gradient = 1
89 | patience = 1600
90 | max_epochs = 0
91 | max_steps = 20000
92 | eval_frequency = 200
93 | frozen_components = []
94 | annotating_components = []
95 | before_to_disk = null
96 |
97 | [training.batcher]
98 | @batchers = "spacy.batch_by_words.v1"
99 | discard_oversize = false
100 | tolerance = 0.2
101 | get_length = null
102 |
103 | [training.batcher.size]
104 | @schedules = "compounding.v1"
105 | start = 100
106 | stop = 1000
107 | compound = 1.001
108 | t = 0.0
109 |
110 | [training.logger]
111 | @loggers = "spacy.ConsoleLogger.v1"
112 | progress_bar = false
113 |
114 | [training.optimizer]
115 | @optimizers = "Adam.v1"
116 | beta1 = 0.9
117 | beta2 = 0.999
118 | L2_is_weight_decay = true
119 | L2 = 0.01
120 | grad_clip = 1.0
121 | use_averages = false
122 | eps = 0.00000001
123 | learn_rate = 0.001
124 |
125 | [training.score_weights]
126 | cats_score = 1.0
127 | cats_score_desc = null
128 | cats_micro_p = null
129 | cats_micro_r = null
130 | cats_micro_f = null
131 | cats_macro_p = null
132 | cats_macro_r = null
133 | cats_macro_f = null
134 | cats_macro_auc = null
135 | cats_f_per_type = null
136 | cats_macro_auc_per_type = null
137 |
138 | [pretraining]
139 |
140 | [initialize]
141 | vectors = ${paths.vectors}
142 | init_tok2vec = ${paths.init_tok2vec}
143 | vocab_data = null
144 | lookups = null
145 | before_init = null
146 | after_init = null
147 |
148 | [initialize.components]
149 |
150 | [initialize.tokenizer]
--------------------------------------------------------------------------------
/experiments/relation_modeling/spacy/relation_model_spacy.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from relation_modeling_utils import load_data\n",
10 | "\n",
11 | "train_data = load_data(\"data/atomic2020_data-feb2021/train.tsv\")\n",
12 | "dev_data = load_data(\"data/atomic2020_data-feb2021/dev.tsv\")"
13 | ]
14 | },
15 | {
16 | "cell_type": "code",
17 | "execution_count": 2,
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "data": {
22 | "text/plain": [
23 | "(53891, 4823)"
24 | ]
25 | },
26 | "execution_count": 2,
27 | "metadata": {},
28 | "output_type": "execute_result"
29 | }
30 | ],
31 | "source": [
32 | "len(train_data), len(dev_data)"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 4,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "import spacy\n",
42 | "from tqdm import tqdm\n",
43 | "nlp = spacy.load(\"en_core_web_lg\")\n",
44 | "\n",
45 | "label_map = ['physical', 'event', 'social']\n",
46 | "\n",
47 | "def make_docs(data):\n",
48 | " \"\"\"\n",
49 | " this will take a list of texts and labels \n",
50 | " and transform them in spacy documents\n",
51 | " \n",
52 | " data: list(tuple(text, label))\n",
53 | " \n",
54 | " returns: List(spacy.Doc.doc)\n",
55 | " \"\"\"\n",
56 | " \n",
57 | " docs = []\n",
58 | " \n",
59 | " for doc, label in tqdm(nlp.pipe(data, as_tuples=True), total=len(data)):\n",
60 | " \n",
61 | " for label_txt in label_map:\n",
62 | " doc.cats[label_txt] = 0\n",
63 | "\n",
64 | " doc.cats[label_map[label]] = 1\n",
65 | " \n",
66 | " # put them into a nice list\n",
67 | " docs.append(doc)\n",
68 | " \n",
69 | " return docs"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 10,
75 | "metadata": {},
76 | "outputs": [
77 | {
78 | "name": "stderr",
79 | "output_type": "stream",
80 | "text": [
81 | "100%|██████████| 53891/53891 [00:34<00:00, 1570.54it/s]\n",
82 | "100%|██████████| 4823/4823 [00:03<00:00, 1498.61it/s]\n"
83 | ]
84 | }
85 | ],
86 | "source": [
87 | "from spacy.tokens import DocBin\n",
88 | "\n",
89 | "train_docs = make_docs(list(train_data.itertuples(index=False, name=None)))\n",
90 | "dev_docs = make_docs(list(dev_data.itertuples(index=False, name=None)))\n",
91 | "\n",
92 | "# then we save it in a binary file to disc\n",
93 | "train_doc_bin = DocBin(docs=train_docs)\n",
94 | "train_doc_bin.to_disk(\"./data/train.spacy\")\n",
95 | "\n",
96 | "dev_doc_bin = DocBin(docs=dev_docs)\n",
97 | "dev_doc_bin.to_disk(\"./data/dev.spacy\")"
98 | ]
99 | }
100 | ],
101 | "metadata": {
102 | "interpreter": {
103 | "hash": "7c3b128559c7e8fd624042ca8b6c93b33cd59aca7b58d05c9d4cd21ec1a84d35"
104 | },
105 | "kernelspec": {
106 | "display_name": "Python 3.8.12 ('kogito')",
107 | "language": "python",
108 | "name": "python3"
109 | },
110 | "language_info": {
111 | "codemirror_mode": {
112 | "name": "ipython",
113 | "version": 3
114 | },
115 | "file_extension": ".py",
116 | "mimetype": "text/x-python",
117 | "name": "python",
118 | "nbconvert_exporter": "python",
119 | "pygments_lexer": "ipython3",
120 | "version": "3.8.12"
121 | },
122 | "orig_nbformat": 4
123 | },
124 | "nbformat": 4,
125 | "nbformat_minor": 2
126 | }
127 |
--------------------------------------------------------------------------------
/experiments/relation_modeling/swem/relation_modeling_swem.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import DataLoader
3 | import pytorch_lightning as pl
4 | from pytorch_lightning.loggers import WandbLogger
5 | import wandb
6 |
7 | from kogito.core.processors.models.swem import (
8 | SWEMConfig,
9 | SWEMClassifier,
10 | SWEMHeadDataset,
11 | )
12 | from relation_modeling_utils import load_fdata, load_data, get_timestamp
13 |
# Experiment configuration.
DATASET_TYPE = "n1"
NUM_EPOCHS = 20
LR_RATE = 1e-4
FREEZE_EMB = False
BATCH_SIZE = 128
POOLING = "avg"

# GloVe-100d vocabulary, stored as a pickled dict inside an .npy file.
VOCAB = np.load("data/vocab_glove_100d.npy", allow_pickle=True).item()

if __name__ == "__main__":
    # Data: OOD train/test splits plus the ATOMIC-2020 dev set for validation.
    df_train = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/train_{DATASET_TYPE}.csv")
    df_val = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True)
    df_test = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/test_{DATASET_TYPE}.csv")

    # Wrap each split in a SWEM dataset backed by the shared vocabulary.
    ds_train = SWEMHeadDataset(df_train, vocab=VOCAB)
    ds_val = SWEMHeadDataset(df_val, vocab=VOCAB)
    ds_test = SWEMHeadDataset(df_test, vocab=VOCAB)

    dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
    dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE)
    dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=True)

    emb_txt = "frozen" if FREEZE_EMB else "finetune"
    timestamp = get_timestamp()

    # Weights & Biases experiment tracking.
    wandb_logger = WandbLogger(
        project="kogito-relation-matcher", name=f"swem_{emb_txt}_{DATASET_TYPE}"
    )
    wandb_logger.experiment.config["epochs"] = NUM_EPOCHS
    wandb_logger.experiment.config["batch_size"] = BATCH_SIZE

    model = SWEMClassifier(
        SWEMConfig(pooling=POOLING, freeze_emb=FREEZE_EMB, learning_rate=LR_RATE)
    )
    trainer = pl.Trainer(
        default_root_dir="models/swem",
        max_epochs=NUM_EPOCHS,
        logger=wandb_logger,
        accelerator="gpu",
        devices=[0],
    )
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    trainer.test(model, dataloaders=dl_test)
    trainer.save_checkpoint(
        f"models/swem/swem_{emb_txt}_{DATASET_TYPE}_{timestamp}.ckpt", weights_only=True
    )
    model.save_pretrained("hmodels/swem")
    wandb.finish()
62 |
--------------------------------------------------------------------------------
/experiments/relation_modeling/transformer/relation_modeling_bert.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import pytorch_lightning as pl
3 | from relation_modeling_utils import load_fdata, get_timestamp, load_data
4 | from pytorch_lightning.loggers import WandbLogger
5 | import wandb
6 |
7 | from kogito.core.processors.models.bert import (
8 | BERTConfig,
9 | BERTHeadDataset,
10 | BERTClassifier,
11 | )
12 |
# Experiment configuration.
MODEL_TYPE = "uncased"
NUM_EPOCHS = 3
BATCH_SIZE = 2
FREEZE_EMB = True
DATASET_TYPE = "n1"
LR_RATE = 1e-4

if __name__ == "__main__":
    # Data: OOD train/test splits plus the ATOMIC-2020 dev set for validation.
    df_train = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/train_{DATASET_TYPE}.csv")
    df_val = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True)
    df_test = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/test_{DATASET_TYPE}.csv")

    # Tokenize each split with the BERT tokenizer matching MODEL_TYPE.
    ds_train = BERTHeadDataset(df_train, tokenizer_type=MODEL_TYPE)
    ds_val = BERTHeadDataset(df_val, tokenizer_type=MODEL_TYPE)
    ds_test = BERTHeadDataset(df_test, tokenizer_type=MODEL_TYPE)

    dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
    dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE)
    dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=True)

    emb_txt = "frozen" if FREEZE_EMB else "finetune"
    timestamp = get_timestamp()

    # Weights & Biases experiment tracking.
    wandb_logger = WandbLogger(
        project="kogito-relation-matcher",
        name=f"bert_{emb_txt}_{MODEL_TYPE}_{DATASET_TYPE}",
    )
    wandb_logger.experiment.config["epochs"] = NUM_EPOCHS
    wandb_logger.experiment.config["batch_size"] = BATCH_SIZE

    model = BERTClassifier(
        BERTConfig(learning_rate=LR_RATE, model_case=MODEL_TYPE, freeze_emb=FREEZE_EMB)
    )
    trainer = pl.Trainer(
        default_root_dir="models/bert",
        max_epochs=NUM_EPOCHS,
        logger=wandb_logger,
        accelerator="gpu",
        devices=[0],
    )
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    trainer.test(model, dataloaders=dl_test)
    trainer.save_checkpoint(
        f"models/bert/bert_model_{emb_txt}_{MODEL_TYPE}_{DATASET_TYPE}_{timestamp}.ckpt",
        weights_only=True,
    )
    model.save_pretrained("hmodels/bert")
    wandb.finish()
62 |
--------------------------------------------------------------------------------
/experiments/relation_modeling/transformer/relation_modeling_distillbert.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | import pytorch_lightning as pl
3 |
4 | from relation_modeling_utils import get_timestamp, load_fdata, load_data
5 | from pytorch_lightning.loggers import WandbLogger
6 | import wandb
7 |
8 | from kogito.core.processors.models.distilbert import (
9 | DistilBERTConfig,
10 | DistilBERTHeadDataset,
11 | DistilBERTClassifier,
12 | )
13 |
# Experiment configuration.
MODEL_TYPE = "uncased"
NUM_EPOCHS = 3
BATCH_SIZE = 64
DATASET_TYPE = "n1"
FREEZE_EMB = False
LR_RATE = 1e-4


if __name__ == "__main__":
    # Data: OOD train/test splits plus the ATOMIC-2020 dev set for validation.
    df_train = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/train_{DATASET_TYPE}.csv")
    df_val = load_data("data/atomic2020_data-feb2021/dev.tsv", multi_label=True)
    df_test = load_fdata(f"data/atomic_ood2/{DATASET_TYPE}/test_{DATASET_TYPE}.csv")

    # Tokenize each split with the DistilBERT tokenizer matching MODEL_TYPE.
    ds_train = DistilBERTHeadDataset(df_train, tokenizer_type=MODEL_TYPE)
    ds_val = DistilBERTHeadDataset(df_val, tokenizer_type=MODEL_TYPE)
    ds_test = DistilBERTHeadDataset(df_test, tokenizer_type=MODEL_TYPE)

    dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
    dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE)
    dl_test = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=True)

    emb_txt = "frozen" if FREEZE_EMB else "finetune"
    timestamp = get_timestamp()

    # Weights & Biases experiment tracking.
    wandb_logger = WandbLogger(
        project="kogito-relation-matcher",
        name=f"distilbert_{emb_txt}_{MODEL_TYPE}_{DATASET_TYPE}",
    )
    wandb_logger.experiment.config["epochs"] = NUM_EPOCHS
    wandb_logger.experiment.config["batch_size"] = BATCH_SIZE

    model = DistilBERTClassifier(
        DistilBERTConfig(
            learning_rate=LR_RATE, model_case=MODEL_TYPE, freeze_emb=FREEZE_EMB
        )
    )
    trainer = pl.Trainer(
        default_root_dir="models/distilbert",
        max_epochs=NUM_EPOCHS,
        logger=wandb_logger,
        accelerator="gpu",
        devices=[1],
    )
    trainer.fit(model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    trainer.test(model, dataloaders=dl_test)
    trainer.save_checkpoint(
        f"models/distilbert/distilbert_model_{emb_txt}_{MODEL_TYPE}_{DATASET_TYPE}_{timestamp}.ckpt",
        weights_only=True,
    )
    model.save_pretrained("hmodels/distilbert")
    wandb.finish()
64 |
--------------------------------------------------------------------------------
/kogito/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/__init__.py
--------------------------------------------------------------------------------
/kogito/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/core/__init__.py
--------------------------------------------------------------------------------
/kogito/core/callbacks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import pytorch_lightning as pl
7 | import torch
8 | from pytorch_lightning.callbacks import ModelCheckpoint
9 | from pytorch_lightning.utilities import rank_zero_only
10 | from pytorch_lightning.utilities import rank_zero_info
11 |
12 |
def count_trainable_parameters(model):
    """Return the total number of trainable parameters in *model*.

    Args:
        model: Any object exposing ``parameters()`` (e.g. a torch.nn.Module).

    Returns:
        Total element count over all parameters with ``requires_grad`` set.
    """
    # Single generator expression instead of filter(lambda) + intermediate list.
    return sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
17 |
18 |
# Module-level logger namespaced to this module.
logger = logging.getLogger(__name__)
20 |
21 |
class Seq2SeqLoggingCallback(pl.Callback):
    """Logs seq2seq metrics to the trainer logger and writes results/generations to disk."""

    @rank_zero_only
    def _write_logs(
        self,
        trainer: pl.Trainer,
        pl_module: pl.LightningModule,
        type_path: str,
        save_generations=True,
    ) -> None:
        """Emit callback metrics (minus bookkeeping keys) and persist them.

        Args:
            trainer: Active Lightning trainer (source of callback_metrics).
            pl_module: Module being trained; supplies config.output_dir.
            type_path: Split name used in file naming (e.g. "test").
            save_generations: Also write metrics["preds"] to a generations file.
        """
        logger.info(
            f"***** {type_path} results at step {trainer.global_step:05d} *****"
        )
        metrics = trainer.callback_metrics
        trainer.logger.log_metrics(
            {
                k: v
                for k, v in metrics.items()
                if k not in ["log", "progress_bar", "preds"]
            }
        )
        # Log results
        od = Path(pl_module.config.output_dir)
        if type_path == "test":
            results_file = od / "test_results.txt"
            generations_file = od / "test_generations.txt"
        else:
            # this never gets hit. I prefer not to save intermediate generations, and results are in metrics.json
            # If people want this it will be easy enough to add back.
            results_file = od / f"{type_path}_results/{trainer.global_step:05d}.txt"
            generations_file = (
                od / f"{type_path}_generations/{trainer.global_step:05d}.txt"
            )
        results_file.parent.mkdir(exist_ok=True)
        generations_file.parent.mkdir(exist_ok=True)
        # Opened in append mode, so repeated runs accumulate in one file.
        with open(results_file, "a+") as writer:
            for key in sorted(metrics):
                if key in ["log", "progress_bar", "preds"]:
                    continue
                val = metrics[key]
                # Unwrap tensors so the :.6f float format below works.
                if isinstance(val, torch.Tensor):
                    val = val.item()
                msg = f"{key}: {val:.6f}\n"
                writer.write(msg)

        if not save_generations:
            return

        if "preds" in metrics:
            content = "\n".join(metrics["preds"])
            generations_file.open("w+").write(content)

    @rank_zero_only
    def on_train_start(self, trainer, pl_module):
        # The underlying HF model may be wrapped one level deep
        # (pl_module.model.model) or not (pl_module.model).
        try:
            npars = pl_module.model.model.num_parameters()
        except AttributeError:
            npars = pl_module.model.num_parameters()

        n_trainable_pars = count_trainable_parameters(pl_module)
        # mp stands for million parameters
        trainer.logger.log_metrics(
            {"n_params": npars, "mp": npars / 1e6, "grad_mp": n_trainable_pars / 1e6}
        )

    @rank_zero_only
    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        # Persist test metrics (and generations, if present) once testing ends.
        return self._write_logs(trainer, pl_module, "test")
89 |
90 |
91 | def get_checkpoint_callback(output_dir, metric):
92 | """Saves the best model by validation ROUGE2 score."""
93 | if metric == "rouge2":
94 | exp = "{val_avg_rouge2:.4f}-{step_count}"
95 | elif metric == "bleu":
96 | exp = "{val_avg_bleu:.4f}-{step_count}"
97 | else:
98 | raise NotImplementedError(
99 | f"seq2seq callbacks only support rouge2 and bleu, got {metric},"
100 | "You can make your own by adding to this function."
101 | )
102 |
103 | checkpoint_callback = ModelCheckpoint(
104 | dirpath=output_dir,
105 | filename=exp,
106 | monitor=f"val_{metric}",
107 | mode="max",
108 | save_top_k=1,
109 | )
110 | return checkpoint_callback
111 |
112 |
class LoggingCallback(pl.Callback):
    """Prints validation/test metrics on rank zero and saves test results to a file."""

    def on_batch_end(self, trainer, pl_module):
        # Log the current learning rate of every scheduler param group.
        # NOTE(review): `self.lr_scheduler` is never assigned anywhere on this
        # class, so this hook looks like it would raise AttributeError if it
        # fires — confirm where lr_scheduler is expected to be attached.
        lrs = {f"lr_group_{i}": lr for i, lr in enumerate(self.lr_scheduler.get_lr())}
        pl_module.logger.log_metrics(lrs)

    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Print all validation callback metrics (rank zero only)."""
        rank_zero_info("***** Validation results *****")
        metrics = trainer.callback_metrics
        # Log results
        for key in sorted(metrics):
            if key not in ["log", "progress_bar"]:
                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        """Print test metrics and persist them to test_results.txt in the output dir."""
        rank_zero_info("***** Test results *****")
        metrics = trainer.callback_metrics
        # Log and save results to file
        output_test_results_file = os.path.join(
            pl_module.base_config.output_dir, "test_results.txt"
        )
        with open(output_test_results_file, "w") as writer:
            for key in sorted(metrics):
                if key not in ["log", "progress_bar"]:
                    rank_zero_info("{} = {}\n".format(key, str(metrics[key])))
                    writer.write("{} = {}\n".format(key, str(metrics[key])))
138 |
--------------------------------------------------------------------------------
/kogito/core/head.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional
2 | from enum import Enum
3 |
4 |
class KnowledgeHeadType(Enum):
    """
    Type of a Knowledge Head
    """

    TEXT = "text"  # default: raw free-form text
    SENTENCE = "sentence"  # a single sentence (see SentenceHeadExtractor)
    NOUN_PHRASE = "noun_phrase"  # noun token or noun chunk (see NounPhraseHeadExtractor)
    VERB_PHRASE = "verb_phrase"  # verb-based phrase (see VerbPhraseHeadExtractor)
13 | VERB_PHRASE = "verb_phrase"
14 |
15 |
class KnowledgeHead:
    """
    Represents a concept of Knowledge Head.
    """

    def __init__(
        self,
        text: str,
        type: KnowledgeHeadType = KnowledgeHeadType.TEXT,
        entity: Any = None,
        verbalizer: Optional[Callable] = None,
    ) -> None:
        """Initialize a Knowledge Head.

        Args:
            text (str): Head text; surrounding whitespace is stripped.
            type (KnowledgeHeadType, optional): Kind of head. Defaults to KnowledgeHeadType.TEXT.
            entity (Any, optional): External Knowledge head entity. Defaults to None.
            verbalizer (Optional[Callable], optional): Function that turns the head text
                into natural text. Defaults to None.
        """
        self.text = text.strip()
        self.type = type
        self.entity = entity
        self.verbalizer = verbalizer

    def __eq__(self, other: object) -> bool:
        # Heads are identified purely by their (stripped) text.
        if not isinstance(other, KnowledgeHead):
            return False
        return self.text == other.text

    def __ne__(self, other) -> bool:
        return not (self == other)

    def __hash__(self) -> int:
        # Consistent with __eq__: hash only the text.
        return hash(self.text)

    def verbalize(self) -> Optional[str]:
        """Render the head as natural text via its verbalizer, if one is set.

        Returns:
            Optional[str]: Verbalized head, or None when no verbalizer is set.
        """
        return self.verbalizer(self.text) if self.verbalizer else None

    def __repr__(self) -> str:
        return str(self.text)

    def copy(self) -> "KnowledgeHead":
        """Create a shallow copy of this head.

        Returns:
            KnowledgeHead: Copied knowledge head
        """
        return KnowledgeHead(
            text=self.text,
            type=self.type,
            entity=self.entity,
            verbalizer=self.verbalizer,
        )
75 |
76 |
def head_verbalizer(head: str):
    """Identity verbalizer: return the head text unchanged."""
    return head
79 |
--------------------------------------------------------------------------------
/kogito/core/linker.py:
--------------------------------------------------------------------------------
1 | from typing import List, Union, Tuple
2 | from abc import ABC, abstractmethod, abstractclassmethod
3 |
4 | from kogito.core.knowledge import KnowledgeGraph
5 |
6 |
class KnowledgeLinker(ABC):
    """Base Knowledge Linker"""

    @abstractmethod
    def save_pretrained(self, save_path: str) -> None:
        """Save linker as a pretrained model

        Args:
            save_path (str): Directory to save the linker to.

        Raises:
            NotImplementedError: This method has to be implemented by
                                 concrete subclasses.
        """
        raise NotImplementedError

    # `abstractclassmethod` is deprecated since Python 3.3; the stacked
    # classmethod + abstractmethod pair is the supported equivalent.
    @classmethod
    @abstractmethod
    def from_pretrained(cls, model_name_or_path: str) -> "KnowledgeLinker":
        """Load model from a pretrained model path
        This method can load linkers either from HuggingFace by model name
        or from disk by model path.

        Args:
            model_name_or_path (str): HuggingFace model name or local model path to load from.

        Raises:
            NotImplementedError: This method has to be implemented by
                                 concrete subclasses.

        Returns:
            KnowledgeLinker: Loaded knowledge linker.
        """
        raise NotImplementedError

    @abstractmethod
    def link(
        self, input_graph: KnowledgeGraph, context: Union[List[str], str]
    ) -> List[List[float]]:
        """Link given knowledge graph with the context.
        This method computes a relevance probability for each knowledge in the graph
        with respect to the given context and returns these probabilities in a list
        in the same order as the knowledge tuples are in the given graph. Note that returned object
        is a list of list of numbers because a knowledge tuple might have multiple tails and the probability
        is calculated for each combination.

        Args:
            input_graph (KnowledgeGraph): Input graph to link.
            context (Union[List[str], str]): Context text. Can be either given as a list of
                                             sentences or as a string, in which case, it will be
                                             split into sentences using spacy engine.

        Returns:
            List[List[float]]: List of relevance probabilities for each tail
        """
        raise NotImplementedError

    def filter(
        self,
        input_graph: KnowledgeGraph,
        context: Union[List[str], str],
        threshold: float = 0.5,
        return_probs: bool = False,
    ) -> Union[KnowledgeGraph, Tuple[KnowledgeGraph, List[List[float]]]]:
        """Filter given graph based on context relevancy.
        This method under the hood links the graph to the context and then filters knowledge tuples from the graph
        which have a relevance probability lower than the given threshold. Filtered knowledge tuples
        are returned as a new knowledge graph. If there are multiple tails for a given knowledge, these tails will be
        filtered as well.

        Args:
            input_graph (KnowledgeGraph): Input graph to filter.
            context (Union[List[str], str]): Context text. Can be either given as a list of
                                             sentences or as a string, in which case, it will be
                                             split into sentences using spacy engine.
            threshold (float, optional): Relevance probability used for filtering. Defaults to 0.5.
            return_probs (bool, optional): Whether to return all the relevancy probs for the input graph.
                                           Defaults to False.
        Returns:
            Union[KnowledgeGraph, Tuple[KnowledgeGraph, List[List[float]]]]:
                Filtered knowledge graph based on the relevancy scores and optionally, the relevancy scores.
        """
        probs = self.link(input_graph, context)
        filtered_kgs = []

        for kg, tail_probs in zip(input_graph, probs):
            filtered_tails = []

            # Keep only tails whose relevance clears the threshold.
            for i, prob in enumerate(tail_probs):
                if prob >= threshold:
                    filtered_tails.append(kg.tails[i])

            if filtered_tails:
                # NOTE(review): this assigns to kg.tails in place, so tuples
                # from the *input* graph are mutated — confirm callers do not
                # rely on input_graph staying untouched.
                kg.tails = filtered_tails
                filtered_kgs.append(kg)

        output_graph = KnowledgeGraph(filtered_kgs)

        if return_probs:
            return output_graph, probs

        return output_graph
108 |
--------------------------------------------------------------------------------
/kogito/core/model.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from abc import ABC, abstractmethod, abstractclassmethod
4 | from kogito.core.knowledge import KnowledgeGraph
5 | from kogito.evaluation.eval import topk_eval, METRIC_MAP
6 |
7 |
class KnowledgeModel(ABC):
    """
    Base class to represent a Knowledge Model.
    """

    @abstractmethod
    def train(self, train_graph: KnowledgeGraph, *args, **kwargs) -> "KnowledgeModel":
        """Train a knowledge model

        Args:
            train_graph (KnowledgeGraph): Training dataset

        Raises:
            NotImplementedError: This method has to be implemented
                                 by concrete subclasses.

        Returns:
            KnowledgeModel: Trained knowledge model
        """
        raise NotImplementedError

    @abstractmethod
    def generate(self, input_graph: KnowledgeGraph, *args, **kwargs) -> KnowledgeGraph:
        """Generate inferences from knowledge model

        Args:
            input_graph (KnowledgeGraph): Input dataset

        Raises:
            NotImplementedError: This method has to be implemented by
                                 concrete subclasses.

        Returns:
            KnowledgeGraph: Input graph with tails generated
        """
        raise NotImplementedError

    @abstractmethod
    def save_pretrained(self, save_path: str) -> None:
        """Save model as a pretrained model

        Args:
            save_path (str): Directory to save the model to.

        Raises:
            NotImplementedError: This method has to be implemented by
                                 concrete subclasses.
        """
        raise NotImplementedError

    # `abstractclassmethod` is deprecated since Python 3.3; the stacked
    # classmethod + abstractmethod pair is the supported equivalent.
    @classmethod
    @abstractmethod
    def from_pretrained(cls, model_name_or_path: str) -> "KnowledgeModel":
        """Load model from a pretrained model path
        This method can load models either from HuggingFace by model name
        or from disk by model path.

        Args:
            model_name_or_path (str): HuggingFace model name or local model path to load from.

        Raises:
            NotImplementedError: This method has to be implemented by
                                 concrete subclasses.

        Returns:
            KnowledgeModel: Loaded knowledge model.
        """
        raise NotImplementedError

    def evaluate(
        self,
        input_graph: KnowledgeGraph,
        metrics: List[str] = ["bleu", "meteor", "rouge", "cider", "bert-score"],
        top_k: int = 1,
        *args,
        **kwargs,
    ) -> dict:
        """Evaluate model on various metrics.
        Input graph should contain the reference tails, so that it can be used to score the model generations
        on the same input graph. Any arguments provided aside from the ones accepted by this method will be
        passed onto the ``KnowledgeModel.generate`` method.

        Args:
            input_graph (KnowledgeGraph): Input graph to evaluate. Should contain the ground truth tails.
            metrics (List[str], optional): Metrics to compute.
                                           Defaults to ["bleu", "meteor", "rouge", "cider", "bert-score"].
            top_k (int, optional): Top k generations to evaluate. Defaults to 1.
            *args (optional): Extra arguments for `KnowledgeModel.generate` method.
            **kwargs (optional): Extra keyword arguments for ``KnowledgeModel.generate`` method.

        Returns:
            dict: Dictionary of scores
        """
        # Delegates to the module-level evaluate() helper.
        return evaluate(self, input_graph, metrics, top_k=top_k, *args, **kwargs)
101 |
102 |
def evaluate(
    model: KnowledgeModel,
    input_graph: KnowledgeGraph,
    metrics: List[str] = ["bleu", "meteor", "rouge", "cider", "bert-score"],
    top_k: int = 1,
    *args,
    **kwargs,
):
    """Generate tails for ``input_graph`` with ``model`` and score them.

    The input graph must carry reference tails; generations are compared
    against them with the requested metrics via top-k evaluation.
    """
    unknown_metrics = set(metrics) - set(METRIC_MAP.keys())
    if unknown_metrics:
        raise ValueError(f"Invalid evaluation metrics found: {unknown_metrics}")

    output_graph = model.generate(input_graph=input_graph, *args, **kwargs)

    evaluation_data = []
    for reference_kg, generated_kg in zip(input_graph, output_graph):
        # Generations must line up one-to-one with the reference tuples.
        assert reference_kg.head == generated_kg.head
        assert reference_kg.relation == generated_kg.relation
        assert len(reference_kg.tails) > 0
        evaluation_data.append((generated_kg, reference_kg.tails))

    return topk_eval(evaluation_data, metrics, k=top_k)
127 |
--------------------------------------------------------------------------------
/kogito/core/processors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/core/processors/__init__.py
--------------------------------------------------------------------------------
/kogito/core/processors/data/vocab_glove_100d.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/core/processors/data/vocab_glove_100d.npy
--------------------------------------------------------------------------------
/kogito/core/processors/head.py:
--------------------------------------------------------------------------------
1 | import string
2 | from abc import ABC, abstractmethod
3 | from typing import List, Optional
4 |
5 | from spacy.tokens import Doc
6 | from spacy.language import Language
7 | from spacy.lang.en.stop_words import STOP_WORDS
8 |
9 | from kogito.core.head import KnowledgeHead, KnowledgeHeadType
10 | from kogito.core.utils import IGNORE_WORDS
11 |
12 |
class KnowledgeHeadExtractor(ABC):
    """Abstract base for strategies that pull knowledge heads out of raw text."""

    def __init__(self, name: str, lang: Optional[Language] = None) -> None:
        """Create a head extractor.

        Args:
            name (str): Unique head extractor name
            lang (Optional[Language], optional): Spacy language pipeline to use. Defaults to None.
        """
        self.name = name
        self.lang = lang

    @abstractmethod
    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Extract knowledge heads from the given text.

        Args:
            text (str): Text to extract from
            doc (Optional[Doc], optional): Pre-parsed spacy doc to reuse. Defaults to None.

        Raises:
            NotImplementedError: This method has to be implemented by
                                 concrete subclasses.

        Returns:
            List[KnowledgeHead]: List of extracted knowledge heads.
        """
        raise NotImplementedError
42 |
43 |
class SentenceHeadExtractor(KnowledgeHeadExtractor):
    """Treats every non-empty sentence of the text as a knowledge head."""

    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        # Parse lazily: only run the pipeline when no doc was supplied.
        doc = doc if doc else self.lang(text)

        return [
            KnowledgeHead(
                text=sentence.text,
                type=KnowledgeHeadType.SENTENCE,
                entity=sentence,
            )
            for sentence in doc.sents
            if sentence.text.strip()
        ]
64 |
65 |
class NounPhraseHeadExtractor(KnowledgeHeadExtractor):
    """Extracts noun phrases as heads from text"""

    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Extract single nouns and cleaned noun chunks as knowledge heads.

        Args:
            text (str): Text to extract from
            doc (Optional[Doc], optional): Pre-parsed spacy doc to reuse. Defaults to None.

        Returns:
            List[KnowledgeHead]: Deduplicated noun-phrase heads.
        """
        if not doc:
            doc = self.lang(text)

        heads: List[KnowledgeHead] = []
        head_texts = set()
        # Hoisted out of the loops: the stop-word union is loop-invariant and
        # was previously rebuilt for every single token.
        ignored = STOP_WORDS.union(IGNORE_WORDS)

        for token in doc:
            if token.text.strip().lower() not in ignored and token.pos_ == "NOUN":
                token_text = token.text.strip(string.punctuation + " ")
                if token_text not in head_texts and len(token_text) > 1:
                    head_texts.add(token_text)
                    heads.append(
                        KnowledgeHead(
                            # Store the cleaned text so the head matches the
                            # dedup key (previously the raw token text could
                            # keep punctuation the dedup key had stripped).
                            text=token_text,
                            type=KnowledgeHeadType.NOUN_PHRASE,
                            entity=token,
                        )
                    )

        for phrase in doc.noun_chunks:
            # Re-tokenize the chunk so stop words inside it can be dropped.
            phrase_doc = self.lang(phrase.text)
            clean_phrase = [
                token.text
                for token in phrase_doc
                if token.text.strip().lower() not in ignored
            ]
            clean_text = " ".join(clean_phrase).strip(string.punctuation + " ")

            if clean_text and clean_text not in head_texts and len(clean_text) > 1:
                head_texts.add(clean_text)
                heads.append(
                    KnowledgeHead(
                        text=clean_text,
                        type=KnowledgeHeadType.NOUN_PHRASE,
                        entity=phrase,
                    )
                )

        return heads
113 |
114 |
class VerbPhraseHeadExtractor(KnowledgeHeadExtractor):
    """Extracts verb phrases as heads from text"""

    def extract(self, text: str, doc: Optional[Doc] = None) -> List[KnowledgeHead]:
        """Collect infinitive verbs and verb + object/attribute pairs as heads."""
        if not doc:
            doc = self.lang(text)

        heads: List[KnowledgeHead] = []
        seen = set()

        def _add(head_text, entity):
            # Record a head only once per surface form.
            if head_text not in seen:
                seen.add(head_text)
                heads.append(
                    KnowledgeHead(
                        text=head_text,
                        type=KnowledgeHeadType.VERB_PHRASE,
                        entity=entity,
                    )
                )

        for token in doc:
            if token.pos_ != "VERB":
                continue

            # Bare infinitive form, e.g. "to run".
            _add(f"to {token.lemma_}", token)

            # Verb plus its direct object or attribute complement, e.g. "run race".
            for child in token.children:
                if child.dep_ in ("attr", "dobj"):
                    _add(f"{token.lemma_} {child.text}", [token, child])

        return heads
153 |
--------------------------------------------------------------------------------
/kogito/core/processors/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/core/processors/models/__init__.py
--------------------------------------------------------------------------------
/kogito/core/processors/models/bert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | from torch.utils.data import Dataset
5 | from torch.optim import Adam
6 | from torch import nn
7 | import pytorch_lightning as pl
8 | from transformers import BertTokenizer, BertModel, PretrainedConfig, PreTrainedModel
9 |
10 | from kogito.core.processors.models.utils import Evaluator
11 |
12 |
class BERTHeadDataset(Dataset):
    """Pre-tokenized dataset of head texts for the BERT head classifier."""

    def __init__(self, data, tokenizer_type="uncased"):
        """Tokenize all texts up front; keep labels when a DataFrame is given.

        Args:
            data: Either a DataFrame with "text"/"label" columns, or a plain
                iterable of texts (unlabelled, e.g. for inference).
            tokenizer_type: BERT casing variant, e.g. "uncased" or "cased".
        """
        self.tokenizer = BertTokenizer.from_pretrained(f"bert-base-{tokenizer_type}")
        is_frame = isinstance(data, pd.DataFrame)
        self.labels = np.asarray(data["label"].to_list()) if is_frame else None
        texts = data["text"] if is_frame else data
        # Fixed-length (32 tokens) padded/truncated tensors, one per sample.
        self.features = [
            self.tokenizer(
                text,
                padding="max_length",
                max_length=32,
                truncation=True,
                return_tensors="pt",
            )
            for text in texts
        ]

    def __len__(self):
        """Number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return (features, label) when labelled, else features only."""
        if self.labels is None:
            return self.features[idx]
        return self.features[idx], self.labels[idx]
40 |
41 |
class BERTConfig(PretrainedConfig):
    """Configuration for :class:`BERTClassifier`."""

    def __init__(
        self,
        num_classes=3,
        dropout=0.5,
        learning_rate=1e-4,
        freeze_emb=False,
        model_case="uncased",
        **kwargs,
    ):
        """Store classifier hyperparameters and forward the rest to HF config.

        Args:
            num_classes: Number of output classes for the linear head.
            dropout: Dropout probability before the linear head.
            learning_rate: Adam learning rate.
            freeze_emb: If True, the BERT encoder weights are frozen.
            model_case: BERT casing variant ("uncased" or "cased").
        """
        self.num_classes = num_classes
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.freeze_emb = freeze_emb
        self.model_case = model_case
        super().__init__(**kwargs)
58 |
59 |
class BERTClassifier(PreTrainedModel, Evaluator, pl.LightningModule):
    """BERT-based head classifier trained with BCE-with-logits loss.

    The train/validation/test steps share one private helper so the
    batch-unpacking and loss logic exists in exactly one place.
    """

    config_class = BERTConfig

    def __init__(self, config: BERTConfig):
        super().__init__(config)
        self.bert = BertModel.from_pretrained(f"bert-base-{config.model_case}")
        self.dropout = nn.Dropout(config.dropout)
        self.linear = nn.Linear(768, config.num_classes)

        if config.freeze_emb:
            # Keep the pretrained encoder fixed; only the linear head trains,
            # so dropout is omitted from the classifier.
            for parameter in self.bert.parameters():
                parameter.requires_grad = False
            self.classifier = nn.Sequential(self.linear)
        else:
            self.classifier = nn.Sequential(self.dropout, self.linear)

        self.criterion = nn.BCEWithLogitsLoss()
        self.learning_rate = config.learning_rate

        self.save_hyperparameters(config.to_dict(), ignore="config")

    def forward(self, input_ids, mask):
        """Return classification logits from BERT's pooled output."""
        _, pooled = self.bert(
            input_ids=input_ids, attention_mask=mask, return_dict=False
        )
        return self.classifier(pooled)

    def _shared_step(self, batch):
        """Run one labelled batch; return (loss, sigmoid preds, targets)."""
        X, y = batch
        mask = X["attention_mask"]
        input_ids = X["input_ids"].squeeze(1)
        outputs = self.forward(input_ids, mask)
        loss = self.criterion(outputs, y.float())
        preds = torch.sigmoid(outputs)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        train_loss, preds, y = self._shared_step(batch)
        self.log("train_loss", train_loss, on_epoch=True)
        self.log_metrics(preds, y, type="train")
        return train_loss

    def validation_step(self, batch, batch_idx):
        val_loss, preds, y = self._shared_step(batch)
        self.log("val_loss", val_loss, on_epoch=True)
        self.log_metrics(preds, y, type="val")
        return val_loss

    def test_step(self, batch, batch_idx):
        test_loss, preds, y = self._shared_step(batch)
        self.log("test_loss", test_loss, on_epoch=True)
        self.log_metrics(preds, y, type="test")
        return test_loss

    def predict_step(self, batch, batch_idx):
        """Return sigmoid probabilities for an unlabelled batch."""
        mask = batch["attention_mask"]
        input_ids = batch["input_ids"].squeeze(1)
        return torch.sigmoid(self.forward(input_ids, mask))

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)
132 |
--------------------------------------------------------------------------------
/kogito/core/processors/models/distilbert.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pandas as pd
4 | from torch.utils.data import Dataset
5 | from torch.optim import Adam
6 | from torch import nn
7 | import pytorch_lightning as pl
8 | from transformers import (
9 | DistilBertTokenizer,
10 | DistilBertModel,
11 | PretrainedConfig,
12 | PreTrainedModel,
13 | )
14 |
15 | from kogito.core.processors.models.utils import Evaluator
16 |
17 |
class DistilBERTHeadDataset(Dataset):
    """Pre-tokenized dataset of head texts for the DistilBERT head classifier."""

    def __init__(self, data, tokenizer_type="uncased"):
        """Tokenize all texts up front; keep labels when a DataFrame is given.

        Args:
            data: Either a DataFrame with "text"/"label" columns, or a plain
                iterable of texts (unlabelled, e.g. for inference).
            tokenizer_type: DistilBERT casing variant, e.g. "uncased" or "cased".
        """
        self.tokenizer = DistilBertTokenizer.from_pretrained(
            f"distilbert-base-{tokenizer_type}"
        )
        is_frame = isinstance(data, pd.DataFrame)
        self.labels = np.asarray(data["label"].to_list()) if is_frame else None
        texts = data["text"] if is_frame else data
        # Fixed-length (32 tokens) padded/truncated tensors, one per sample.
        self.features = [
            self.tokenizer(
                text,
                padding="max_length",
                max_length=32,
                truncation=True,
                return_tensors="pt",
            )
            for text in texts
        ]

    def __len__(self):
        """Number of samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return (features, label) when labelled, else features only."""
        if self.labels is None:
            return self.features[idx]
        return self.features[idx], self.labels[idx]
47 |
48 |
class DistilBERTConfig(PretrainedConfig):
    """Configuration for :class:`DistilBERTClassifier`."""

    def __init__(
        self,
        num_classes=3,
        dropout=0.5,
        learning_rate=1e-4,
        freeze_emb=False,
        model_case="uncased",
        **kwargs,
    ):
        """Store classifier hyperparameters and forward the rest to HF config.

        Args:
            num_classes: Number of output classes for the linear head.
            dropout: Dropout probability before the linear head.
            learning_rate: Adam learning rate.
            freeze_emb: If True, the DistilBERT encoder weights are frozen.
            model_case: DistilBERT casing variant ("uncased" or "cased").
        """
        self.num_classes = num_classes
        self.dropout = dropout
        self.learning_rate = learning_rate
        self.freeze_emb = freeze_emb
        self.model_case = model_case
        super().__init__(**kwargs)
65 |
66 |
class DistilBERTClassifier(PreTrainedModel, Evaluator, pl.LightningModule):
    """DistilBERT-based head classifier trained with BCE-with-logits loss.

    The train/validation/test steps share one private helper so the
    batch-unpacking and loss logic exists in exactly one place (mirrors
    BERTClassifier).
    """

    config_class = DistilBERTConfig

    def __init__(self, config: DistilBERTConfig):
        super().__init__(config)
        self.distilbert = DistilBertModel.from_pretrained(
            f"distilbert-base-{config.model_case}"
        )
        self.dropout = nn.Dropout(config.dropout)
        self.linear = nn.Linear(768, config.num_classes)

        if config.freeze_emb:
            # Keep the pretrained encoder fixed; only the linear head trains,
            # so dropout is omitted from the classifier.
            for parameter in self.distilbert.parameters():
                parameter.requires_grad = False
            self.classifier = nn.Sequential(self.linear)
        else:
            self.classifier = nn.Sequential(self.dropout, self.linear)

        self.criterion = nn.BCEWithLogitsLoss()
        self.learning_rate = config.learning_rate

        self.save_hyperparameters(config.to_dict(), ignore="config")

    def forward(self, input_ids, mask):
        """Return classification logits from the first ([CLS]) token state."""
        hidden = self.distilbert(
            input_ids=input_ids, attention_mask=mask, return_dict=False
        )
        return self.classifier(hidden[0][:, 0, :])

    def _shared_step(self, batch):
        """Run one labelled batch; return (loss, sigmoid preds, targets)."""
        X, y = batch
        mask = X["attention_mask"]
        input_ids = X["input_ids"].squeeze(1)
        outputs = self.forward(input_ids, mask)
        loss = self.criterion(outputs, y.float())
        preds = torch.sigmoid(outputs)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        train_loss, preds, y = self._shared_step(batch)
        self.log("train_loss", train_loss, on_epoch=True)
        self.log_metrics(preds, y, type="train")
        return train_loss

    def validation_step(self, batch, batch_idx):
        val_loss, preds, y = self._shared_step(batch)
        self.log("val_loss", val_loss, on_epoch=True)
        self.log_metrics(preds, y, type="val")
        return val_loss

    def test_step(self, batch, batch_idx):
        test_loss, preds, y = self._shared_step(batch)
        self.log("test_loss", test_loss, on_epoch=True)
        self.log_metrics(preds, y, type="test")
        return test_loss

    def predict_step(self, batch, batch_idx):
        """Return sigmoid probabilities for an unlabelled batch."""
        mask = batch["attention_mask"]
        input_ids = batch["input_ids"].squeeze(1)
        return torch.sigmoid(self.forward(input_ids, mask))

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.learning_rate)
141 |
--------------------------------------------------------------------------------
/kogito/core/processors/models/utils.py:
--------------------------------------------------------------------------------
1 | import torchmetrics
2 | import numpy as np
3 | import spacy
4 |
5 |
def text_to_embedding(text, vocab, embedding_matrix, pooling="max", lang=None):
    """Embed ``text`` by pooling the per-token rows of ``embedding_matrix``.

    Tokens absent from ``vocab`` are skipped. Returns None when no token is
    covered by the vocabulary. ``pooling`` is "max" (element-wise maximum)
    or anything else for the mean.
    """
    if not lang:
        # Fall back to a default English pipeline when none is supplied.
        lang = spacy.load("en_core_web_sm")

    vectors = [
        embedding_matrix[vocab[token.text]]
        for token in lang(text)
        if token.text in vocab
    ]

    if not vectors:
        return None

    if pooling == "max":
        return np.amax(np.array(vectors, dtype=np.float32), axis=0)
    return np.mean(vectors, axis=0, dtype=np.float32)
20 |
21 |
class Evaluator:
    """Mixin that owns torchmetrics collections for train/val/test splits.

    Metric keys follow ``{split}_{metric}{suffix}`` where the suffix encodes
    the averaging mode: "" (weighted), "_micro", "_macro", or "_class"
    (per-class, average="none"). The dict was previously 36 hand-copied
    constructor calls; it is now generated, preserving the exact key names
    and ordering.
    """

    # Number of target classes used by every metric.
    NUM_CLASSES = 3

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        common = dict(task="multiclass", num_classes=self.NUM_CLASSES)
        metrics = {}
        for split in ("train", "val", "test"):
            metrics[f"{split}_accuracy"] = torchmetrics.Accuracy(**common)
            for average, suffix in (
                ("weighted", ""),
                ("micro", "_micro"),
                ("macro", "_macro"),
                ("none", "_class"),  # per-class vector of scores
            ):
                metrics[f"{split}_precision{suffix}"] = torchmetrics.Precision(
                    average=average, **common
                )
                metrics[f"{split}_recall{suffix}"] = torchmetrics.Recall(
                    average=average, **common
                )
                metrics[f"{split}_f1{suffix}"] = torchmetrics.F1Score(
                    average=average, **common
                )
        self.metrics = metrics

    def log_metrics(self, preds, y, type):
        """Update and log every metric whose name starts with ``type``.

        Args:
            preds: Model predictions (moved to CPU before the update).
            y: Targets (moved to CPU before the update).
            type: Split prefix, one of "train", "val", "test".
        """
        for metric_name, metric in self.metrics.items():
            if metric_name.startswith(type):
                metric(preds.cpu(), y.cpu())
                value = metric.compute()
                if len(value.shape) > 0:
                    # Per-class metrics return a vector; log each class
                    # under its own "<name>_<class idx>" key.
                    for idx, val in enumerate(value):
                        self.log(
                            f"{metric_name}_{idx}", val, on_epoch=True, on_step=False
                        )
                else:
                    self.log(metric_name, value, on_epoch=True, on_step=False)
97 |
--------------------------------------------------------------------------------
/kogito/evaluation/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Manish Joshi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/kogito/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/evaluation/__init__.py
--------------------------------------------------------------------------------
/kogito/evaluation/bert_score/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/evaluation/bert_score/__init__.py
--------------------------------------------------------------------------------
/kogito/evaluation/bert_score/bert_score.py:
--------------------------------------------------------------------------------
1 | from kogito.evaluation.bert_score.score import score
2 |
3 | # Code for BertScore reused from original implementation: https://github.com/Tiiiger/bert_score
4 |
5 |
class BertScore:
    """Corpus-level BERTScore wrapper (see https://github.com/Tiiiger/bert_score)."""

    def __init__(self):
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res):
        """Score hypotheses in ``res`` against references in ``gts``.

        Returns:
            Tuple of (mean aggregated F1 over all ids, per-id F1 list).
        """
        assert gts.keys() == res.keys()

        hyp_input = []
        ref_input = []
        same_indices = []
        for img_id in gts.keys():
            hypo = res[img_id]
            ref = gts[img_id]

            # Sanity check: one hypothesis, at least one reference.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) >= 1

            # Repeat the hypothesis once per reference so score() sees aligned pairs.
            hyp_input.extend([hypo[0]] * len(ref))
            ref_input.extend(ref)
            same_indices.append(len(ref_input))

        p, r, f_scores = score(hyp_input, ref_input)

        # Average F1 within each id's reference group.
        aggreg_f1_scores = []
        start = 0
        for end in same_indices:
            aggreg_f1_scores.append(f_scores[start:end].mean().cpu().item())
            start = end

        return sum(aggreg_f1_scores) / len(aggreg_f1_scores), aggreg_f1_scores

    def method(self):
        return "Bert Score"
--------------------------------------------------------------------------------
/kogito/evaluation/bert_score/score.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import defaultdict
3 | from transformers import BertTokenizer, BertModel
4 |
5 | from kogito.evaluation.bert_score.utils import (
6 | get_idf_dict,
7 | bert_cos_score_idf,
8 | bert_types,
9 | )
10 |
11 |
def score(
    cands,
    refs,
    bert="bert-base-multilingual-cased",
    num_layers=8,
    no_idf=False,
    batch_size=64,
):
    """
    BERTScore metric.
    Args:
        - :param: `cands` (list of str): candidate sentences
        - :param: `refs` (list of str): reference sentences
        - :param: `bert` (str): bert specification
        - :param: `num_layers` (int): the layer of representation to use
        - :param: `no_idf` (bool): do not use idf weighting
        - :param: `batch_size` (int): bert score processing batch size
    Returns:
        (P, R, F1) CPU tensors, one score per candidate/reference pair.
    """
    assert len(cands) == len(refs)
    assert bert in bert_types

    tokenizer = BertTokenizer.from_pretrained(bert)
    model = BertModel.from_pretrained(bert)
    model.eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Truncate the encoder: only the first `num_layers` layers are needed.
    model.encoder.layer = torch.nn.ModuleList(list(model.encoder.layer[:num_layers]))

    if no_idf:
        # Uniform weighting, except [CLS] (101) and [SEP] (102) are zeroed.
        idf_dict = defaultdict(lambda: 1.0)
        idf_dict[101] = 0
        idf_dict[102] = 0
    else:
        idf_dict = get_idf_dict(refs, tokenizer)

    all_preds = bert_cos_score_idf(
        model, refs, cands, tokenizer, idf_dict, device=device, batch_size=batch_size
    )

    # Columns are precision, recall, F1.
    P = all_preds[:, 0].cpu()
    R = all_preds[:, 1].cpu()
    F1 = all_preds[:, 2].cpu()

    return P, R, F1
62 |
--------------------------------------------------------------------------------
/kogito/evaluation/bleu/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/kogito/evaluation/bleu/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/kogito/evaluation/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = "tylin"
2 |
--------------------------------------------------------------------------------
/kogito/evaluation/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | from kogito.evaluation.bleu.bleu_scorer import BleuScorer
12 |
13 |
class Bleu:
    """Corpus BLEU wrapper around BleuScorer."""

    def __init__(self, n=4):
        # Compute BLEU-1 through BLEU-n (default: up to BLEU-4).
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res):
        """Accumulate all (hypothesis, references) pairs and score them.

        Returns:
            Tuple of (corpus-level scores, per-sentence scores).
        """
        assert gts.keys() == res.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for img_id in gts.keys():
            hypo = res[img_id]
            ref = gts[img_id]

            # Sanity check: one hypothesis, at least one reference.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) >= 1

            bleu_scorer += (hypo[0], ref)

        # "closest" reference-length brevity option; alternatives are
        # 'shortest' and 'average'.
        score, scores = bleu_scorer.compute_score(option="closest", verbose=0)

        return score, scores

    def method(self):
        return "Bleu"
48 |
--------------------------------------------------------------------------------
/kogito/evaluation/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = "tylin"
2 |
--------------------------------------------------------------------------------
/kogito/evaluation/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 |
10 | from kogito.evaluation.cider.cider_scorer import CiderScorer
11 |
12 |
class Cider:
    """
    Main Class to compute the CIDEr metric

    """

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # NOTE(review): `test` and `refs` are accepted but never used; kept
        # only for backward compatibility with existing callers.
        # Sum over 1..n-grams.
        self._n = n
        # Standard deviation for the gaussian length penalty.
        self._sigma = sigma

    def compute_score(self, gts, res):
        """
        Main function to compute CIDEr score
        :param hypo_for_image (dict) : dictionary with key and value

            ref_for_image (dict) : dictionary with key and value

        :return: cider (float) : computed CIDEr score for the corpus
        """
        assert gts.keys() == res.keys()

        cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)

        for img_id in gts.keys():
            hypo = res[img_id]
            ref = gts[img_id]

            # Sanity check: one hypothesis, at least one reference.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

            cider_scorer += (hypo[0], ref)

        score, scores = cider_scorer.compute_score()

        return score, scores

    def method(self):
        return "CIDEr"
58 |
--------------------------------------------------------------------------------
/kogito/evaluation/eval.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | from kogito.evaluation.bleu.bleu import Bleu
4 | from kogito.evaluation.meteor.meteor import Meteor
5 | from kogito.evaluation.rouge.rouge import Rouge
6 | from kogito.evaluation.cider.cider import Cider
7 | from kogito.evaluation.bert_score.bert_score import BertScore
8 | from kogito.core.knowledge import Knowledge
9 |
10 |
# Maps a metric name to (scorer instance, result label(s)).
# "bleu" maps to a list of labels — one per n-gram order — while the other
# scorers produce a single labelled score.
METRIC_MAP = {
    "bleu": (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
    "meteor": (Meteor(), "METEOR"),
    "rouge": (Rouge(), "ROUGE_L"),
    "cider": (Cider(), "CIDEr"),
    "bert-score": (BertScore(), "Bert Score"),
}
18 |
19 |
class Evaluator:
    """Runs a configured set of scorers over ground-truth/result dicts."""

    def __init__(self, gts, res, metrics):
        self.gts = gts
        self.res = res
        # Resolve metric names to (scorer, label(s)) pairs up front.
        self.scorers = [METRIC_MAP[metric] for metric in metrics]

    def evaluate(self):
        """Return a dict mapping each metric label to its computed score."""
        score_dict = {}

        for scorer, method in self.scorers:
            score, _ = scorer.compute_score(self.gts, self.res)
            if isinstance(method, list):
                # Multi-label scorers (BLEU) yield one score per label;
                # these are stored stringified.
                for sc, label in zip(score, method):
                    score_dict[label] = str(sc)
            else:
                score_dict[method] = score

        return score_dict
38 |
39 |
def topk_eval(data: List[Tuple[Knowledge, List[str]]], metrics, k=1):
    """Evaluate the top-k generated tails of each knowledge tuple.

    Args:
        data: Pairs of (knowledge with generated tails, reference strings).
        metrics: Metric names understood by METRIC_MAP.
        k: How many leading tails per knowledge to score. Defaults to 1.

    Returns:
        Dict of metric label -> score.
    """
    topk_gts = {}
    topk_res = {}

    # Each (sample, tail-rank) pair gets its own id so scorers see
    # one hypothesis per id.
    for i, (kg, reference) in enumerate(data):
        for j, generation in enumerate(kg.tails[:k]):
            key = f"{i}_{j}"
            topk_gts[key] = reference
            topk_res[key] = [generation]

    return Evaluator(topk_gts, topk_res, metrics).evaluate()
54 |
--------------------------------------------------------------------------------
/kogito/evaluation/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = "tylin"
2 |
--------------------------------------------------------------------------------
/kogito/evaluation/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 |
7 | from nltk.translate.meteor_score import meteor_score
8 | from nltk.tokenize import word_tokenize
9 |
10 |
class Meteor:
    """METEOR metric computed via NLTK's meteor_score."""

    def __init__(self):
        pass

    def compute_score(self, gts, res):
        """Return (mean METEOR over all ids, per-id score list)."""
        assert gts.keys() == res.keys()

        scores = []
        for img_id in gts.keys():
            # Exactly one hypothesis per id.
            assert len(res[img_id]) == 1
            references = [word_tokenize(ref) for ref in gts[img_id]]
            hypothesis = word_tokenize(res[img_id][0])
            # Rounded to 4 decimals, matching the original implementation.
            scores.append(round(meteor_score(references, hypothesis), 4))

        return sum(scores) / len(scores), scores

    def method(self):
        return "METEOR"
34 |
--------------------------------------------------------------------------------
/kogito/evaluation/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = "vrama91"
2 |
--------------------------------------------------------------------------------
/kogito/evaluation/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
# Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 |
12 |
def my_lcs(string, sub):
    """
    Calculates longest common subsequence for a pair of tokenized strings
    :param string : list of str : tokens from a string split using whitespace
    :param sub : list of str : shorter string, also split using whitespace
    :returns: length (list of int): length of the longest common subsequence between the two strings

    Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
    """
    # Make sure `sub` is the shorter sequence; the DP below assumes it.
    if len(string) < len(sub):
        sub, string = string, sub

    # lengths[i][j] == LCS length of string[:i] and sub[:j]
    lengths = [[0] * (len(sub) + 1) for _ in range(len(string) + 1)]

    for j in range(1, len(sub) + 1):
        for i in range(1, len(string) + 1):
            if string[i - 1] == sub[j - 1]:
                lengths[i][j] = lengths[i - 1][j - 1] + 1
            else:
                lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])

    return lengths[-1][-1]
35 |
36 |
class Rouge:
    """
    Computes the ROUGE-L score for candidate sentences against the
    MS COCO reference sets (Lin and Hovey, 2004).
    """

    def __init__(self):
        # vrama91: updated the value below based on discussion with Hovey
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute the ROUGE-L score of a single candidate vs. its references.

        :param candidate: list containing exactly one candidate sentence (str)
        :param refs: non-empty list of reference sentences (str)
        :returns: float : ROUGE-L score of the candidate against the references
        """
        assert len(candidate) == 1
        assert len(refs) > 0

        # Whitespace tokenization of the candidate sentence.
        cand_tokens = candidate[0].split(" ")

        precisions, recalls = [], []
        for ref in refs:
            ref_tokens = ref.split(" ")
            # LCS length drives both precision and recall.
            lcs_len = my_lcs(ref_tokens, cand_tokens)
            precisions.append(lcs_len / float(len(cand_tokens)))
            recalls.append(lcs_len / float(len(ref_tokens)))

        best_prec = max(precisions)
        best_rec = max(recalls)

        # No token overlap with any reference -> score is zero.
        if best_prec == 0 or best_rec == 0:
            return 0.0
        # Beta-weighted F-measure of the best precision/recall.
        beta_sq = self.beta**2
        return ((1 + beta_sq) * best_prec * best_rec) / float(
            best_rec + beta_sq * best_prec
        )

    def compute_score(self, gts, res):
        """
        Average ROUGE-L over a dataset of candidate/reference sentence sets.

        :param gts: dict : "image name" -> list of reference sentences
        :param res: dict : "image name" -> single-element list with candidate
        :returns: (mean ROUGE-L score, numpy array of per-image scores)
        """
        assert gts.keys() == res.keys()

        per_image = []
        for img_id in gts.keys():
            hypo = res[img_id]
            ref = gts[img_id]

            per_image.append(self.calc_score(hypo, ref))

            # Sanity check on input shapes.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

        scores = np.array(per_image)
        return np.mean(scores), scores

    def method(self):
        """Return the name of this metric ("Rouge")."""
        return "Rouge"
112 |
--------------------------------------------------------------------------------
/kogito/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/models/__init__.py
--------------------------------------------------------------------------------
/kogito/models/bart/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/models/bart/__init__.py
--------------------------------------------------------------------------------
/kogito/models/bart/config.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from dataclasses import dataclass, asdict
3 |
# Seq2seq task modes accepted by the BART setup (see the `task` config field).
BART_TASKS = ["summarization", "translation"]
# Candidate values for the `fp16_opt_level` config field.
# NOTE(review): presumably NVIDIA Apex AMP optimization levels
# ("O0" pure FP32 ... "O3" pure FP16) — confirm against the trainer code.
FP16_OPT_LEVELS = ["O0", "O1", "O2", "O3"]
6 |
7 |
@dataclass
class COMETBARTConfig:
    """Configuration for the COMET-BART knowledge model.

    Defaults describe the standard training setup; several fields
    (``gradient_clip_val``, ``accumulate_grad_batches``, ``gpus``,
    ``tpu_cores``, ...) mirror pytorch-lightning Trainer options.
    """

    # Annotations below corrected to Optional[...] where the default is None.
    output_dir: Optional[str] = None
    fp16: bool = False
    fp16_opt_level: str = "O2"  # expected to be one of FP16_OPT_LEVELS
    tpu_cores: Optional[int] = None
    gradient_clip_val: float = 1.0
    accumulate_grad_batches: int = 1
    seed: int = 42
    max_source_length: int = 48
    max_target_length: int = 24
    val_max_target_length: int = 24
    test_max_target_length: int = 24
    freeze_encoder: bool = False
    freeze_embeds: bool = False
    sortish_sampler: bool = True
    # NOTE(review): n_train/n_test default to -1 — presumably "use all
    # examples"; confirm against the data-loading code.
    n_train: int = -1
    n_val: int = 500
    n_test: int = -1
    task: str = "summarization"  # expected to be one of BART_TASKS
    src_lang: str = ""
    tgt_lang: str = ""
    atomic: bool = True
    pretrained_model: str = "facebook/bart-large"
    pretrained_config: Optional[str] = None
    pretrained_tokenizer: Optional[str] = None
    cache_dir: str = ""
    learning_rate: float = 5e-5
    weight_decay: float = 0.0
    adam_epsilon: float = 1e-8
    warmup_steps: int = 0
    num_workers: int = 2
    max_epochs: int = 3
    train_batch_size: int = 32
    eval_batch_size: int = 32
    gpus: int = 1
    decoder_start_token_id: Optional[int] = None

    # NOTE(review): defining a method named ``__dict__`` shadows the standard
    # instance-dict descriptor — ``cfg.__dict__`` yields this bound method,
    # not the attribute dict, so callers must *call* it: ``cfg.__dict__()``.
    # Kept as-is to avoid breaking existing callers.
    def __dict__(self):
        """Return the config as a plain dict of field name -> value."""
        return asdict(self)
--------------------------------------------------------------------------------
/kogito/models/gpt2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/models/gpt2/__init__.py
--------------------------------------------------------------------------------
/kogito/models/gpt2/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_lightning as pl
3 |
4 |
class GPT2Finetuner(pl.LightningModule):
    """Lightning module that fine-tunes a GPT-2 style language model.

    The causal-LM objective is used: the input ids double as the labels,
    so the model's own loss output is optimized directly.
    """

    def __init__(self, model, learning_rate=1e-5) -> None:
        """
        Args:
            model: HuggingFace language model exposing
                ``model(input_ids=..., attention_mask=..., labels=...)``.
            learning_rate (float, optional): Adam learning rate. Defaults to 1e-5.
        """
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, input_ids, mask):
        # Labels are the inputs themselves (causal LM objective).
        return self.model(input_ids=input_ids, attention_mask=mask, labels=input_ids)

    def _lm_loss(self, batch):
        """Run the model on one batch and return its LM loss (first output).

        Shared by training_step and validation_step, which previously
        duplicated this logic verbatim.
        """
        ids = batch["source_ids"]
        mask = batch["source_mask"]
        outputs = self.model(input_ids=ids, attention_mask=mask, labels=ids)
        return outputs[0]

    def training_step(self, batch, batch_idx):
        loss = self._lm_loss(batch)
        self.log("train_loss", loss, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._lm_loss(batch)
        self.log("val_loss", loss, on_epoch=True)
        return loss

    def configure_optimizers(self):
        # self.parameters() covers the wrapped model's parameters as well.
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
35 |
--------------------------------------------------------------------------------
/kogito/models/gpt2/zeroshot.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import cuda
4 | from transformers import GPT2Tokenizer, GPT2LMHeadModel
5 |
6 | from kogito.core.model import KnowledgeModel
7 | from kogito.core.knowledge import KnowledgeGraph
8 |
9 | device = "cuda" if cuda.is_available() else "cpu"
10 |
11 |
class GPT2Zeroshot(KnowledgeModel):
    """Zeroshot knowledge model based on GPT-2"""

    def __init__(self, gpt2_model: str = "gpt2") -> None:
        """Initialize GPT-2 model

        Args:
            gpt2_model (str, optional): HuggingFace model name for gpt2. Defaults to "gpt2".
        """
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model)
        self.model = GPT2LMHeadModel.from_pretrained(gpt2_model)
        self.model.to(device)

    def train(self):
        """Zeroshot models cannot be fine-tuned; always raises ValueError."""
        raise ValueError("GPT-2 Zeroshot model is not trainable")

    def save_pretrained(self, save_path):
        """Save model and tokenizer weights under ``save_path``."""
        self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

    @classmethod
    def from_pretrained(cls, model_name_or_path: str = "gpt2"):
        """Construct a model from a HuggingFace model name or local path."""
        return cls(model_name_or_path)

    def generate(
        self, input_graph: KnowledgeGraph, seed: int = 42, **kwargs
    ) -> KnowledgeGraph:
        """Generate inferences from GPT2 model

        Args:
            input_graph (KnowledgeGraph): Input dataset
            seed (int, optional): Random seed. Defaults to 42.
            kwargs: Additional arguments to pass to the model.generate() function.
                "max_length" is interpreted as the number of NEW tokens to
                generate beyond the prompt.

        Returns:
            KnowledgeGraph: Completed knowledge graph
        """
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.backends.cudnn.deterministic = True

        # Fill in sampling defaults without overriding caller-provided values.
        kwargs.setdefault("top_k", 1)
        kwargs.setdefault("top_p", 0.9)
        kwargs.setdefault("num_return_sequences", 3)
        kwargs.setdefault("num_beams", 3)
        kwargs.setdefault("temperature", 0.7)
        kwargs.setdefault("repetition_penalty", 1.2)
        kwargs.setdefault("do_sample", True)

        # BUGFIX: "max_length" must be *removed* from kwargs before the
        # generate() call below. The original code passed max_length= both
        # explicitly and inside **kwargs, which raises
        # "TypeError: generate() got multiple values for keyword argument
        # 'max_length'".
        max_new_tokens = kwargs.pop("max_length", 32)

        outputs = []
        for input_kg in input_graph:
            prompt = input_kg.to_prompt()
            input_ids = self.tokenizer.encode(
                prompt, add_special_tokens=False, return_tensors="pt"
            )
            input_length = input_ids.size(1)
            generations = self.model.generate(
                input_ids=input_ids.to(device),
                max_length=input_length + max_new_tokens,
                eos_token_id=198,  # GPT-2 id for "\n": stop at end of line
                **kwargs
            )

            # Sampled/beam outputs can carry an extra leading dimension.
            if len(generations.shape) > 2:
                generations.squeeze_()

            text_generations = []
            for gen in generations:
                gen = gen.tolist()
                # Decode only the continuation, not the echoed prompt.
                text = self.tokenizer.decode(
                    gen[input_length:], clean_up_tokenization_spaces=True
                )
                text_generations.append(text.strip())

            # Attach the generations as tails of a copy of the input tuple.
            output_kg = input_kg.copy()
            output_kg.tails = text_generations
            outputs.append(output_kg)

        return KnowledgeGraph(outputs)
104 |
--------------------------------------------------------------------------------
/kogito/models/gpt3/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/epfl-nlp/kogito/22e63eebf058f99459ef492c9f73d7d4c78cf537/kogito/models/gpt3/__init__.py
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "kogito"
3 | version = "0.6.3"
4 | description = "A Python NLP Commonsense Knowledge Inference Toolkit"
5 | authors = ["Mete Ismayil ", "Antoine Bosselut "]
6 | license = "Apache License 2.0"
7 | readme = "README.md"
8 | homepage = "https://github.com/epfl-nlp/kogito"
9 | repository = "https://github.com/epfl-nlp/kogito"
10 | documentation = "https://github.com/epfl-nlp/kogito"
11 | keywords = [
12 | "natural language processing",
13 | "nlp",
14 | "natural language understanding",
15 | "commonsense reasoning",
16 | "commonsense inference",
17 | "knowledge inference"
18 | ]
19 | classifiers = []
20 |
21 | [tool.poetry.dependencies]
22 | python = ">=3.8,<3.11"
23 | sacrebleu = "^2.0.0"
24 | rouge-score = "^0.0.4"
25 | pytorch-lightning = "~1.5.10"
26 | pandas = "^1.3.5"
27 | spacy = "^3.2.3"
28 | inflect = "^5.3.0"
29 | transformers = "^4.15.0"
30 | wandb = "^0.12.9"
31 | torch = "^1.10.1"
32 | openai = "^0.18.1"
33 | bert-score = "^0.3.11"
34 | sentencepiece = "^0.1.97"
35 | grpcio = "~1.51.1"
36 |
37 | [tool.poetry.dev-dependencies]
38 | pytest = "*"
39 | flake8 = "*"
40 | black = "~22.12.0"
41 | mypy = "*"
42 | sphinx = "*"
43 | insegel = "*"
44 |
45 | [build-system]
46 | requires = ["poetry-core>=1.0.0"]
47 | build-backend = "poetry.core.masonry.api"
48 |
--------------------------------------------------------------------------------