├── .gitignore ├── LICENSE ├── README.md ├── demo.ipynb ├── encoding.py ├── environment.yml ├── generation.py ├── modeling.py ├── requirements.txt ├── transformer_base │   ├── __init__.py │   ├── run_clm.py │   └── run_summarization.py ├── utils.py └── wrapper.py /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Deep Cognition and Language Research (DeCLaRe) Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## RelationPrompt: Leveraging Prompts to Generate Synthetic Data for Zero-Shot Relation Triplet Extraction 2 | 3 | [![PWC](https://img.shields.io/badge/PapersWithCode-Benchmark-%232cafb1)](https://paperswithcode.com/paper/relationprompt-leveraging-prompts-to-generate) 4 | [![Colab](https://img.shields.io/badge/Colab-Code%20Demo-%23fe9f00)](https://colab.research.google.com/drive/18lrKD30kxEUolQ61o5nzUJM0rvWgpbFK?usp=sharing) 5 | [![Jupyter](https://img.shields.io/badge/Jupyter-Notebook%20Demo-important)](https://github.com/declare-lab/RelationPrompt/blob/main/demo.ipynb) 6 | 7 | This repository implements our ACL Findings 2022 research paper [RelationPrompt: Leveraging Prompts to Generate Synthetic Data for Zero-Shot Relation Triplet Extraction](https://aclanthology.org/2022.findings-acl.5/). 8 | The goal of Zero-Shot Relation Triplet Extraction (ZeroRTE) is to extract relation triplets of the format `(head entity, tail entity, relation)`, despite not having annotated data for the test relation labels.
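For a concrete sense of the task, each data sample pairs a tokenized sentence with one or more relation triplets, where the head and tail entities are spans of token indices. The sketch below uses a sample from the demo notebook; the field names mirror the `Dataset` attributes (`tokens`, `triplets`, `head`, `tail`, `label`) and are illustrative of the JSONL schema rather than a normative spec:
```
{"tokens": ["In", "the", "Ulster", "Cycle", "of", "Irish", "mythology", ",", "Lugaid", "mac", "Con", "Roí", "was", "the", "son", "of", "Cú", "Roí", "mac", "Dáire", "."],
 "triplets": [{"head": [2, 3], "tail": [16, 17], "label": "characters"}]}
```
Here `head` `[2, 3]` picks out "Ulster Cycle" and `tail` `[16, 17]` picks out "Cú Roí"; at test time, the extractor must produce such triplets for relation labels that never appear in its training data.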
9 | 10 | ![diagram](https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/diagram.png) 11 | 12 | ### Installation 13 | 14 | - Python 3.7 15 | - If your GPU uses CUDA 11, first install the specific PyTorch: `pip install torch==1.10.0 --extra-index-url https://download.pytorch.org/whl/cu113` 16 | - Install requirements: `pip install -r requirements.txt` or `conda env create --file environment.yml` 17 | - Download and extract the [datasets here](https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/zero_rte_data.zip) to `outputs/data/splits/zero_rte` 18 | - [FewRel Pretrained Model](https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_fewrel_unseen_10_seed_0.tar) (unseen=10, seed=0) 19 | - [Wiki-ZSL Pretrained Model](https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_wiki_unseen_10_seed_0.tar) (unseen=10, seed=0) 20 | 21 | 22 | ### Data Exploration | [![Colab](https://img.shields.io/badge/Colab-Code%20Demo-%23fe9f00)](https://colab.research.google.com/drive/18lrKD30kxEUolQ61o5nzUJM0rvWgpbFK#scrollTo=vw3NlKDddMIP&line=2&uniqifier=1) 23 | 24 | ``` 25 | from wrapper import Dataset 26 | 27 | data = Dataset.load(path) 28 | for s in data.sents: 29 | print(s.tokens) 30 | for t in s.triplets: 31 | print(t.head, t.tail, t.label) 32 | ``` 33 | 34 | ### Generate with Pretrained Model | [![Colab](https://img.shields.io/badge/Colab-Code%20Demo-%23fe9f00)](https://colab.research.google.com/drive/18lrKD30kxEUolQ61o5nzUJM0rvWgpbFK#scrollTo=tUFis82oGUAS&line=1&uniqifier=1) 35 | 36 | ``` 37 | from wrapper import Generator 38 | 39 | model = Generator(load_dir="gpt2", save_dir="outputs/wrapper/fewrel/unseen_10_seed_0/generator") 40 | model.generate(labels=["location", "religion"], path_out="synthetic.jsonl") 41 | ``` 42 | 43 | ### Extract with Pretrained Model | [![Colab](https://img.shields.io/badge/Colab-Code%20Demo-%23fe9f00)](https://colab.research.google.com/drive/18lrKD30kxEUolQ61o5nzUJM0rvWgpbFK#scrollTo=eGxP3vVmID9W&line=1&uniqifier=1) 44 | 45 | ``` 46 | from wrapper import Extractor 47 | 48 | model = Extractor(load_dir="facebook/bart-base", save_dir="outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final") 49 | model.predict(path_in=path_test, path_out="pred.jsonl") 50 | ``` 51 | 52 | ### Model Training | [![Colab](https://img.shields.io/badge/Colab-Code%20Demo-%23fe9f00)](https://colab.research.google.com/drive/18lrKD30kxEUolQ61o5nzUJM0rvWgpbFK#scrollTo=qi5PAW5ocjfj&line=1&uniqifier=1) 53 | 54 | Train the Generator and Extractor models: 55 | ``` 56 | from pathlib import Path 57 | from wrapper import Generator, Extractor 58 | 59 | generator = Generator( 60 | load_dir="gpt2", 61 | save_dir=str(Path(save_dir) / "generator"), 62 | ) 63 | extractor = Extractor( 64 | load_dir="facebook/bart-base", 65 | save_dir=str(Path(save_dir) / "extractor"), 66 | ) 67 | generator.fit(path_train, path_dev) 68 | extractor.fit(path_train, path_dev) 69 | ``` 70 | 71 | Generate synthetic data with relation triplets for test labels: 72 | ``` 73 | generator.generate(labels_test, path_out=path_synthetic) 74 | ``` 75 | 76 | Train the final Extractor model using the synthetic data and predict on test sentences: 77 | ``` 78 | extractor_final = Extractor( 79 | load_dir=str(Path(save_dir) / "extractor" / "model"), 80 | save_dir=str(Path(save_dir) / "extractor_final"), 81 | ) 82 | extractor_final.fit(path_synthetic, path_dev) 83 | extractor_final.predict(path_in=path_test, path_out=path_pred) 84 | ``` 85 | 86 | ### Experiment Scripts 87 | 88 | Run 
training in [wrapper.py](https://github.com/declare-lab/RelationPrompt/blob/783f33c301813368a5a6e3bdbbe50c47df7647bf/wrapper.py#L370) (you can change "fewrel" to "wiki", unseen to 5/10/15, and seed to 0/1/2/3/4): 89 | ``` 90 | python wrapper.py main \ 91 | --path_train outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/train.jsonl \ 92 | --path_dev outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/dev.jsonl \ 93 | --path_test outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/test.jsonl \ 94 | --save_dir outputs/wrapper/fewrel/unseen_10_seed_0 95 | ``` 96 | 97 | Run evaluation (Single-triplet setting): 98 | ``` 99 | python wrapper.py run_eval \ 100 | --path_model outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final \ 101 | --path_test outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/test.jsonl \ 102 | --mode single 103 | ``` 104 | 105 | Run evaluation (Multi-triplet setting): 106 | ``` 107 | python wrapper.py run_eval \ 108 | --path_model outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final \ 109 | --path_test outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/test.jsonl \ 110 | --mode multi 111 | ``` 112 | 113 | ### Research Citation 114 | If the code is useful for your research project, we would appreciate it if you cite the following [paper](https://aclanthology.org/2022.findings-acl.5/): 115 | ``` 116 | @inproceedings{chia-etal-2022-relationprompt, 117 | title = "{R}elation{P}rompt: Leveraging Prompts to Generate Synthetic Data for Zero-Shot Relation Triplet Extraction", 118 | author = "Chia, Yew Ken and 119 | Bing, Lidong and 120 | Poria, Soujanya and 121 | Si, Luo", 122 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2022", 123 | month = may, 124 | year = "2022", 125 | address = "Dublin, Ireland", 126 | publisher = "Association for Computational Linguistics", 127 | url = "https://aclanthology.org/2022.findings-acl.5", 128 | doi = "10.18653/v1/2022.findings-acl.5", 129 | pages = "45--57", 130 | abstract = "Despite the importance of relation extraction in building and representing knowledge, less research is focused on generalizing to unseen relations types. We introduce the task setting of Zero-Shot Relation Triplet Extraction (ZeroRTE) to encourage further research in low-resource relation extraction methods. Given an input sentence, each extracted triplet consists of the head entity, relation label, and tail entity where the relation label is not seen at the training stage. To solve ZeroRTE, we propose to synthesize relation examples by prompting language models to generate structured texts. Concretely, we unify language model prompts and structured text approaches to design a structured prompt template for generating synthetic relation samples when conditioning on relation label prompts (RelationPrompt). To overcome the limitation for extracting multiple relation triplets in a sentence, we design a novel Triplet Search Decoding method. Experiments on FewRel and Wiki-ZSL datasets show the efficacy of RelationPrompt for the ZeroRTE task and zero-shot relation classification.
Our code and data are available at github.com/declare-lab/RelationPrompt.", 131 | } 132 | ``` 133 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "source": [ 6 | "### RelationPrompt: Leveraging Prompts to Generate Synthetic Data for Zero-Shot Relation Triplet Extraction\n", 7 | "\n", 8 | "GitHub: https://github.com/declare-lab/RelationPrompt" 9 | ], 10 | "metadata": { 11 | "id": "qm5jvHp3vpKT" 12 | } 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": { 18 | "id": "VP-z5ENrR7S3", 19 | "colab": { 20 | "base_uri": "https://localhost:8080/" 21 | }, 22 | "outputId": "a7f24d70-77b4-4ede-b7a8-0e3e2a38c10e" 23 | }, 24 | "outputs": [ 25 | { 26 | "output_type": "stream", 27 | "name": "stdout", 28 | "text": [ 29 | "fatal: destination path 'RelationPrompt' already exists and is not an empty directory.\n", 30 | "HEAD is now at 8ce3656 Upgrade torch version 1.9.0 -> 1.10.0\n", 31 | "File ‘model_fewrel_unseen_10_seed_0.tar’ already there; not retrieving.\n", 32 | "\n" 33 | ] 34 | } 35 | ], 36 | "source": [ 37 | "!git clone https://github.com/declare-lab/RelationPrompt.git\n", 38 | "!cd RelationPrompt && git checkout 8ce3656\n", 39 | "!cp -a RelationPrompt/* .\n", 40 | "!wget -q -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/zero_rte_data.zip\n", 41 | "!wget -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_fewrel_unseen_10_seed_0.tar\n", 42 | "!tar -xf model_fewrel_unseen_10_seed_0.tar\n", 43 | "# !wget -nc https://github.com/declare-lab/RelationPrompt/releases/download/v1.0.0/model_wiki_unseen_10_seed_0.tar\n", 44 | "!unzip -nq zero_rte_data.zip\n", 45 | "!pip install -q -r requirements.txt" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "source": [ 51 | "#@title Data Parameters\n", 52 | "data_name = \"fewrel\" #@param [\"fewrel\", \"wiki\"]\n", 53 | "num_unseen_labels = 10 #@param [5,10,15]\n", 54 | "random_seed = 0 #@param [0,1,2,3,4]\n", 55 | "data_limit = 5000 #@param {type:\"number\"}\n", 56 | "data_dir = f\"outputs/data/splits/zero_rte/{data_name}/unseen_{num_unseen_labels}_seed_{random_seed}\"\n", 57 | "print(dict(data_dir=data_dir))" 58 | ], 59 | "metadata": { 60 | "colab": { 61 | "base_uri": "https://localhost:8080/" 62 | }, 63 | "id": "Gj8sD5YNdgLJ", 64 | "outputId": "08ead310-7f39-46d7-d62d-dbe8602c7dc1" 65 | }, 66 | "execution_count": null, 67 | "outputs": [ 68 | { 69 | "output_type": "stream", 70 | "name": "stdout", 71 | "text": [ 72 | "{'data_dir': 'outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0'}\n" 73 | ] 74 | } 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "source": [ 80 | "# Data Setup\n", 81 | "import json\n", 82 | "import random\n", 83 | "from pathlib import Path\n", 84 | "from wrapper import Generator, Extractor, Dataset\n", 85 | "\n", 86 | "def truncate_data(path:str, limit:int, path_out:str):\n", 87 | " # Use a subset of data for quick demo on Colab\n", 88 | " data = Dataset.load(path)\n", 89 | " random.seed(0)\n", 90 | " random.shuffle(data.sents)\n", 91 | " data.sents = data.sents[:limit]\n", 92 | " data.save(path_out)\n", 93 | "\n", 94 | "path_train = \"train.jsonl\"\n", 95 | "path_dev = \"dev.jsonl\"\n", 96 | "path_test = \"test.jsonl\"\n", 97 | "truncate_data(f\"{data_dir}/train.jsonl\", limit=data_limit, path_out=path_train)\n", 98 | "truncate_data(f\"{data_dir}/dev.jsonl\", 
limit=data_limit // 10, path_out=path_dev)\n", 99 | "truncate_data(f\"{data_dir}/test.jsonl\", limit=data_limit // 10, path_out=path_test)" 100 | ], 101 | "metadata": { 102 | "id": "iO--Mb9nHgGG" 103 | }, 104 | "execution_count": null, 105 | "outputs": [] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "source": [ 110 | "# Data Exploration\n", 111 | "\n", 112 | "def explore_data(path: str):\n", 113 | " data = Dataset.load(path)\n", 114 | " print(\"labels:\", data.get_labels())\n", 115 | " print()\n", 116 | " for s in random.sample(data.sents, k=3):\n", 117 | " print(\"tokens:\", s.tokens)\n", 118 | " for t in s.triplets:\n", 119 | " print(\"head:\", t.head)\n", 120 | " print(\"tail:\", t.tail)\n", 121 | " print(\"relation:\", t.label)\n", 122 | " print()\n", 123 | "\n", 124 | "explore_data(path_train)" 125 | ], 126 | "metadata": { 127 | "colab": { 128 | "base_uri": "https://localhost:8080/" 129 | }, 130 | "id": "vw3NlKDddMIP", 131 | "outputId": "d100f0aa-58fa-454a-bedf-8166ab7150ed" 132 | }, 133 | "execution_count": null, 134 | "outputs": [ 135 | { 136 | "output_type": "stream", 137 | "name": "stdout", 138 | "text": [ 139 | "labels: ['after a work by', 'applies to jurisdiction', 'architect', 'characters', 'child', 'constellation', 'contains administrative territorial entity', 'country', 'country of citizenship', 'country of origin', 'crosses', 'developer', 'director', 'distributed by', 'father', 'field of work', 'followed by', 'follows', 'genre', 'has part', 'head of government', 'headquarters location', 'heritage designation', 'instance of', 'instrument', 'language of work or name', 'league', 'licensed to broadcast to', 'located in or next to body of water', 'located in the administrative territorial entity', 'located on terrain feature', 'location of formation', 'manufacturer', 'member of', 'military branch', 'military rank', 'mother', 'mountain range', 'mouth of the watercourse', 'movement', 'notable work', 'occupant', 'occupation', 'operator', 'original language of film or TV show', 'part of', 'participant', 'participating team', 'performer', 'place served by transport hub', 'publisher', 'record label', 'residence', 'said to be the same as', 'screenwriter', 'sibling', 'sport', 'sports season of league or competition', 'spouse', 'subsidiary', 'successful candidate', 'tributary', 'voice type', 'winner', 'work location']\n", 140 | "\n", 141 | "tokens: ['In', 'the', 'Ulster', 'Cycle', 'of', 'Irish', 'mythology', ',', 'Lugaid', 'mac', 'Con', 'Roí', 'was', 'the', 'son', 'of', 'Cú', 'Roí', 'mac', 'Dáire', '.']\n", 142 | "head: [2, 3]\n", 143 | "tail: [16, 17]\n", 144 | "relation: characters\n", 145 | "\n", 146 | "tokens: ['Wanandi', 'was', 'a', 'leading', 'student', 'activist', 'during', 'the', '1965', '-', '66', 'in', 'Indonesia', 'when', ',', 'over', 'time', ',', 'president', 'Sukarno', 'was', 'removed', 'from', 'power', 'and', 'Soeharto', 'became', 'the', 'second', 'president', 'of', 'Indonesia', '.']\n", 147 | "head: [25]\n", 148 | "tail: [12]\n", 149 | "relation: country of citizenship\n", 150 | "\n", 151 | "tokens: ['The', 'Temple', 'of', 'Proserpina', 'or', 'Temple', 'of', 'ProserpineSome', 'theories', 'suggest', 'that', 'the', 'temple', 'was', 'a', 'Greek', 'Temple', 'dedicated', 'to', 'Persephone', ',', 'the', 'Greek', 'equivalent', 'to', 'the', 'Roman', 'Goddess', 'Proserpina', '.']\n", 152 | "head: [19]\n", 153 | "tail: [3]\n", 154 | "relation: said to be the same as\n", 155 | "\n" 156 | ] 157 | } 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "source": [ 163 | "# Use 
Pretrained Model for Generation\n", 164 | "model = Generator(load_dir=\"gpt2\", save_dir=\"outputs/wrapper/fewrel/unseen_10_seed_0/generator\")\n", 165 | "model.generate(labels=[\"location\", \"religion\"], path_out=\"synthetic.jsonl\")\n", 166 | "explore_data(path=\"synthetic.jsonl\")" 167 | ], 168 | "metadata": { 169 | "colab": { 170 | "base_uri": "https://localhost:8080/" 171 | }, 172 | "id": "tUFis82oGUAS", 173 | "outputId": "23e65d65-fdce-4f38-bc9a-ddf8f9cb6fc6" 174 | }, 175 | "execution_count": null, 176 | "outputs": [ 177 | { 178 | "output_type": "stream", 179 | "name": "stdout", 180 | "text": [ 181 | "labels: ['location', 'religion']\n", 182 | "\n", 183 | "tokens: ['In', '2007', ',', 'he', 'joined', 'a', 'group', 'of', 'artists', 'known', 'as', 'the', 'Moth', 'Boys', ',', 'an', 'annual', 'neo', '-', 'pop', 'quartet', 'that', 'plays', 'in', 'several', 'venues', 'around', 'the', 'country', 'in', 'Las', 'Vegas', '.']\n", 184 | "head: [12, 13]\n", 185 | "tail: [30, 31]\n", 186 | "relation: location\n", 187 | "\n", 188 | "tokens: ['There', 'is', 'a', 'section', 'of', 'the', 'town', 'under', '\"', 'the', 'Graziano', '\"', 'River', ',', 'a', 'channel', 'flowing', 'the', 'river', 'in', 'southwestern', 'Italy', 'from', 'the', 'island', 'of', 'Sardinia', 'to', 'Italy', '.']\n", 189 | "head: [10]\n", 190 | "tail: [26]\n", 191 | "relation: location\n", 192 | "\n", 193 | "tokens: ['In', 'August', '2012', ',', 'the', 'station', 'opened', 'on', 'its', 'regular', 'schedule', 'between', 'Minto', 'Plaza', 'in', 'Osaka', 'and', 'Keito', 'Station', 'in', 'the', 'city', 'of', 'Nara', 'in', 'Japan', '.']\n", 194 | "head: [17, 18]\n", 195 | "tail: [15]\n", 196 | "relation: location\n", 197 | "\n" 198 | ] 199 | } 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "source": [ 205 | "# Use Pretrained Model for Extraction\n", 206 | "model = Extractor(load_dir=\"facebook/bart-base\", save_dir=\"outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final\")\n", 207 | "model.predict(path_in=path_test, path_out=\"pred.jsonl\")\n", 208 | "explore_data(path=\"pred.jsonl\")" 209 | ], 210 | "metadata": { 211 | "colab": { 212 | "base_uri": "https://localhost:8080/" 213 | }, 214 | "id": "eGxP3vVmID9W", 215 | "outputId": "c9ba9b59-faf5-4fec-c54e-2d4878176a2c" 216 | }, 217 | "execution_count": null, 218 | "outputs": [ 219 | { 220 | "output_type": "stream", 221 | "name": "stdout", 222 | "text": [ 223 | "{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='facebook/bart-base', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=64, grad_accumulation=2, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n" 224 | ] 225 | }, 226 | { 227 | "output_type": "stream", 228 | "name": "stderr", 229 | "text": [ 230 | "100%|██████████| 8/8 [00:07<00:00, 1.11it/s]" 231 | ] 232 | }, 233 | { 234 | "output_type": "stream", 235 | "name": "stdout", 236 | "text": [ 237 | "labels: ['', 'competition class', 'location', 'member of political party', 'nominated for', 'operating system', 'original broadcaster', 'owned by', 'position played on team / speciality', 'religion']\n", 238 | "\n", 239 | "tokens: ['The', 'South', 'Bank', 'Show', 'is', 'a', 'television', 'arts', 'magazine', 'show', 'that', 'was', 'produced', 'by', 'ITV', 'between', '1978', 
'and', '2010', ',', 'and', 'by', 'Sky', 'Arts', 'from', '2012', '.']\n", 240 | "head: [1, 2, 3]\n", 241 | "tail: [14]\n", 242 | "relation: original broadcaster\n", 243 | "\n", 244 | "tokens: ['Then', 'Senator', 'Neptali', 'Gonzales', ',', 'whom', 'Maceda', 'helped', ',', 'was', 'installed', 'as', 'Senate', 'President', 'from', '1992', '-', '1993', 'and', '1995', '-', '1996', 'succeeded', 'him', '.']\n", 245 | "head: [2, 3]\n", 246 | "tail: [12, 13]\n", 247 | "relation: position played on team / speciality\n", 248 | "\n", 249 | "tokens: ['In', '1908', 'she', 'won', 'the', 'singles', 'title', 'at', 'the', 'Welsh', 'Championships', 'in', 'Newport', 'and', 'successfully', 'defended', 'it', 'in', '1909', '.', 'she', 'also', 'won', 'the', 'Scottish', 'Championships', 'singles', 'title', 'twice', '1908', 'to', '1909', '.']\n", 250 | "head: [9, 10]\n", 251 | "tail: [5]\n", 252 | "relation: competition class\n", 253 | "\n" 254 | ] 255 | }, 256 | { 257 | "output_type": "stream", 258 | "name": "stderr", 259 | "text": [ 260 | "\n" 261 | ] 262 | } 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "source": [ 268 | "# Full Training\n", 269 | "save_dir = f\"outputs/wrapper/{data_name}/unseen_{num_unseen_labels}_seed_{random_seed}\"\n", 270 | "print(dict(save_dir=save_dir))\n", 271 | "model_kwargs = dict(batch_size=32, grad_accumulation=4) # For lower memory on Colab\n", 272 | "\n", 273 | "generator = Generator(\n", 274 | " load_dir=\"gpt2\",\n", 275 | " save_dir=str(Path(save_dir) / \"generator\"),\n", 276 | " model_kwargs=model_kwargs,\n", 277 | ")\n", 278 | "extractor = Extractor(\n", 279 | " load_dir=\"facebook/bart-base\",\n", 280 | " save_dir=str(Path(save_dir) / \"extractor\"),\n", 281 | " model_kwargs=model_kwargs,\n", 282 | ")\n", 283 | "\n", 284 | "generator.fit(path_train, path_dev)\n", 285 | "extractor.fit(path_train, path_dev)\n", 286 | "path_synthetic = str(Path(save_dir) / \"synthetic.jsonl\")\n", 287 | "labels_dev = Dataset.load(path_dev).get_labels()\n", 288 | "labels_test = Dataset.load(path_test).get_labels()\n", 289 | "generator.generate(labels_dev + labels_test, path_out=path_synthetic)\n", 290 | "\n", 291 | "extractor_final = Extractor(\n", 292 | " load_dir=str(Path(save_dir) / \"extractor\" / \"model\"),\n", 293 | " save_dir=str(Path(save_dir) / \"extractor_final\"),\n", 294 | " model_kwargs=model_kwargs,\n", 295 | ")\n", 296 | "extractor_final.fit(path_synthetic, path_dev)\n", 297 | "\n", 298 | "path_pred = str(Path(save_dir) / \"pred.jsonl\")\n", 299 | "extractor_final.predict(path_in=path_test, path_out=path_pred)\n", 300 | "results = extractor_final.score(path_pred, path_test)\n", 301 | "print(json.dumps(results, indent=2))" 302 | ], 303 | "metadata": { 304 | "id": "qi5PAW5ocjfj", 305 | "colab": { 306 | "base_uri": "https://localhost:8080/" 307 | }, 308 | "outputId": "633b3678-319d-4f1d-9296-b3e47de80c47" 309 | }, 310 | "execution_count": null, 311 | "outputs": [ 312 | { 313 | "output_type": "stream", 314 | "name": "stdout", 315 | "text": [ 316 | "{'save_dir': 'outputs/wrapper/fewrel/unseen_10_seed_0'}\n", 317 | "{'select_model': RelationGenerator(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/generator/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/generator/data', model_name='gpt2', do_pretrain=False, encoder_name='generate', pipe_name='text-generation', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, block_size=128)}\n", 318 | 
"{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/data', model_name='facebook/bart-base', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n", 319 | "{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n", 320 | "{'select_model': NewRelationExtractor(model_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/model', data_dir='outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final/data', model_name='outputs/wrapper/fewrel/unseen_10_seed_0/extractor/model', do_pretrain=False, encoder_name='extract', pipe_name='summarization', batch_size=32, grad_accumulation=4, random_seed=42, warmup_ratio=0.2, lr_pretrain=0.0003, lr_finetune=3e-05, epochs_pretrain=3, epochs_finetune=5, train_fp16=True, max_source_length=128, max_target_length=128)}\n" 321 | ] 322 | }, 323 | { 324 | "output_type": "stream", 325 | "name": "stderr", 326 | "text": [ 327 | "100%|██████████| 16/16 [00:09<00:00, 1.72it/s]" 328 | ] 329 | }, 330 | { 331 | "output_type": "stream", 332 | "name": "stdout", 333 | "text": [ 334 | "{\n", 335 | " \"path_pred\": \"outputs/wrapper/fewrel/unseen_10_seed_0/pred.jsonl\",\n", 336 | " \"path_gold\": \"test.jsonl\",\n", 337 | " \"precision\": 0.328,\n", 338 | " \"recall\": 0.3215686274509804,\n", 339 | " \"score\": 0.32475247524752476\n", 340 | "}\n" 341 | ] 342 | }, 343 | { 344 | "output_type": "stream", 345 | "name": "stderr", 346 | "text": [ 347 | "\n" 348 | ] 349 | } 350 | ] 351 | }, 352 | { 353 | "cell_type": "code", 354 | "source": [ 355 | "" 356 | ], 357 | "metadata": { 358 | "id": "-63QPp0xuDOA" 359 | }, 360 | "execution_count": null, 361 | "outputs": [] 362 | } 363 | ], 364 | "metadata": { 365 | "accelerator": "GPU", 366 | "colab": { 367 | "name": "RelationPrompt Demo.ipynb", 368 | "provenance": [], 369 | "collapsed_sections": [] 370 | }, 371 | "kernelspec": { 372 | "display_name": "Python 3", 373 | "name": "python3" 374 | }, 375 | "language_info": { 376 | "name": "python" 377 | } 378 | }, 379 | "nbformat": 4, 380 | "nbformat_minor": 0 381 | } -------------------------------------------------------------------------------- /encoding.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, List, Tuple 3 | 4 | from fire import Fire 5 | from pydantic import BaseModel 6 | from tqdm import tqdm 7 | from transformers import AutoTokenizer 8 | 9 | from transformer_base import run_summarization 10 | from utils import RelationData, RelationSentence 11 | 12 | 13 | class Encoder(BaseModel): 14 | def encode_x(self, x: str) -> str: 15 | raise NotImplementedError 16 | 17 | def encode(self, sent: RelationSentence) -> Tuple[str, str]: 18 | raise NotImplementedError 19 | 20 | def decode(self, 
x: str, y: str) -> RelationSentence: 21 | raise NotImplementedError 22 | 23 | def decode_x(self, x: str) -> str: 24 | raise NotImplementedError 25 | 26 | def safe_decode(self, x: str, y: str) -> RelationSentence: 27 | text = self.decode_x(x) 28 | try: 29 | s = self.decode(x=x, y=y) 30 | except Exception as e: 31 | s = RelationSentence( 32 | tokens=text.split(), head=[], tail=[], label="", error=str(e), raw=y 33 | ) 34 | return s 35 | 36 | def encode_to_line(self, sent: RelationSentence) -> str: 37 | raise NotImplementedError 38 | 39 | def decode_from_line(self, line: str) -> RelationSentence: 40 | raise NotImplementedError 41 | 42 | def parse_line(self, line: str) -> Tuple[str, str]: 43 | raise NotImplementedError 44 | 45 | 46 | class GenerateEncoder(Encoder): 47 | def encode_x(self, r: str) -> str: 48 | return f"Relation : {r} ." 49 | 50 | def decode_x(self, text: str) -> str: 51 | return text.split("Relation : ")[-1][:-2] 52 | 53 | def encode_triplet(self, sent: RelationSentence) -> str: 54 | s, r, o = sent.as_tuple() 55 | return f"Context : {sent.text} Head Entity : {s} , Tail Entity : {o} ." 56 | 57 | def decode_triplet(self, text: str, label: str) -> RelationSentence: 58 | front, back = text.split(" Head Entity : ") 59 | _, context = front.split("Context : ") 60 | head, back = back.split(" , Tail Entity : ") 61 | tail = back[:-2] 62 | return RelationSentence.from_spans(context, head, tail, label) 63 | 64 | def encode_y(self, sent: RelationSentence) -> str: 65 | return self.encode_x(sent.label) + " " + self.encode_triplet(sent) 66 | 67 | def decode_y(self, text: str, label: str) -> RelationSentence: 68 | del label 69 | front, back = text.split(" . Context : ") 70 | label = self.decode_x(front + " .") 71 | return self.decode_triplet("Context : " + back, label) 72 | 73 | def decode(self, x: str, y: str) -> RelationSentence: 74 | r = self.decode_x(x) 75 | sent = self.decode_y(y, r) 76 | return sent 77 | 78 | def encode(self, sent: RelationSentence) -> Tuple[str, str]: 79 | x = self.encode_x(sent.label) 80 | y = self.encode_y(sent) 81 | return x, y 82 | 83 | def decode_from_line(self, line: str) -> RelationSentence: 84 | x, y = self.parse_line(line) 85 | return self.decode(x, y) 86 | 87 | def encode_to_line(self, sent: RelationSentence) -> str: 88 | x, y = self.encode(sent) 89 | return y + "\n" 90 | 91 | def parse_line(self, line: str) -> Tuple[str, str]: 92 | return "", line.strip() 93 | 94 | 95 | class ExtractEncoder(Encoder): 96 | def encode_x(self, text: str) -> str: 97 | return f"Context : {text}" 98 | 99 | def decode_x(self, x: str) -> str: 100 | return x.split("Context : ")[-1] 101 | 102 | def encode_y(self, sent: RelationSentence) -> str: 103 | s, r, o = sent.as_tuple() 104 | return f"Head Entity : {s} , Tail Entity : {o} , Relation : {r} ." 
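    # Illustrative input/target pair (hypothetical sentence, not taken from the datasets):
    # for the sentence "SpaceX was founded by Elon Musk ." with head "SpaceX",
    # tail "Elon Musk" and relation "founded by", the extractor is trained on
    #   encode_x -> "Context : SpaceX was founded by Elon Musk ."
    #   encode_y -> "Head Entity : SpaceX , Tail Entity : Elon Musk , Relation : founded by ."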
105 | 106 | def decode_y(self, x: str, y: str) -> RelationSentence: 107 | context = self.decode_x(x) 108 | front, label = y.split(" , Relation : ") 109 | label = label[:-2] 110 | front, tail = front.split(" , Tail Entity : ") 111 | _, head = front.split("Head Entity : ") 112 | return RelationSentence.from_spans(context, head, tail, label) 113 | 114 | def encode_entity_prompt(self, head: str, tail: str) -> str: 115 | return f"Head Entity : {head} , Tail Entity : {tail} , Relation :" 116 | 117 | def encode(self, sent: RelationSentence) -> Tuple[str, str]: 118 | x = self.encode_x(sent.text) 119 | y = self.encode_y(sent) 120 | return x, y 121 | 122 | def decode(self, x: str, y: str) -> RelationSentence: 123 | return self.decode_y(x, y) 124 | 125 | def encode_to_line(self, sent: RelationSentence) -> str: 126 | x, y = self.encode(sent) 127 | return run_summarization.encode_to_line(x, y) 128 | 129 | def decode_from_line(self, line: str) -> RelationSentence: 130 | x, y = self.parse_line(line) 131 | return self.decode(x, y) 132 | 133 | def parse_line(self, line: str) -> Tuple[str, str]: 134 | return run_summarization.decode_from_line(line) 135 | 136 | 137 | def test_encoders( 138 | paths: List[str] = [ 139 | "outputs/data/zsl/wiki/unseen_5_seed_0/train.jsonl", 140 | "outputs/data/zsl/fewrel/unseen_5_seed_0/train.jsonl", 141 | ], 142 | print_limit: int = 4, 143 | encoder_names: List[str] = ["generate", "extract"], 144 | limit: int = 1000, 145 | ): 146 | encoders = {k: select_encoder(k) for k in encoder_names} 147 | 148 | for p in paths: 149 | data = RelationData.load(Path(p)) 150 | _, data = data.train_test_split(min(limit, len(data.sents)), random_seed=0) 151 | 152 | for name, e in tqdm(list(encoders.items())): 153 | num_fail = 0 154 | print(dict(name=name, p=p)) 155 | for s in data.sents: 156 | encoded = e.encode_to_line(s) 157 | x, y = e.parse_line(encoded) 158 | decoded: RelationSentence = e.safe_decode(x, y) 159 | 160 | if decoded.as_tuple() != s.as_tuple(): 161 | if num_fail < print_limit: 162 | print(dict(gold=s.as_tuple(), text=s.text)) 163 | print(dict(pred=decoded.as_tuple(), text=decoded.text)) 164 | print(dict(x=x, y=y, e=decoded.error)) 165 | print() 166 | num_fail += 1 167 | 168 | print(dict(success_rate=1 - (num_fail / len(data.sents)))) 169 | print("#" * 80) 170 | 171 | 172 | def select_encoder(name: str) -> Encoder: 173 | mapping: Dict[str, Encoder] = dict( 174 | extract=ExtractEncoder(), 175 | generate=GenerateEncoder(), 176 | ) 177 | encoder = mapping[name] 178 | return encoder 179 | 180 | 181 | def test_entity_prompts( 182 | path: str = "outputs/data/zsl/wiki/unseen_10_seed_0/test.jsonl", limit: int = 100 183 | ): 184 | def tokenize(text: str, tok) -> List[str]: 185 | return tok.convert_ids_to_tokens(tok(text, add_special_tokens=False).input_ids) 186 | 187 | data = RelationData.load(Path(path)) 188 | e = ExtractEncoder() 189 | tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base") 190 | print(tokenizer) 191 | for i, s in enumerate(tqdm(data.sents[:limit])): 192 | head, label, tail = s.as_tuple() 193 | x, y = e.encode(s) 194 | prompt = e.encode_entity_prompt(head, tail) 195 | tokens_y = tokenize(y, tokenizer) 196 | tokens_prompt = tokenize(prompt, tokenizer) 197 | assert tokens_y[: len(tokens_prompt)] == tokens_prompt 198 | if i < 3: 199 | print(tokens_y) 200 | 201 | 202 | if __name__ == "__main__": 203 | Fire() 204 | -------------------------------------------------------------------------------- /environment.yml: 
-------------------------------------------------------------------------------- 1 | name: relationprompt 2 | channels: 3 | - nvidia 4 | - pytorch 5 | - conda-forge 6 | dependencies: 7 | - python=3.7 8 | - cudatoolkit=10.2 9 | - pytorch=1.10.0 10 | - transformers=4.7.0 11 | - datasets=1.11.0 12 | - pandas=1.2.4 13 | - pydantic=1.8.1 14 | - fastavro=1.4.0 15 | - fire=0.4.0 16 | - nltk=3.6.6 17 | - lxml=4.6.3 18 | - editdistance=0.5.3 19 | - seqeval=1.2.2 20 | prefix: ~/miniconda3/envs/relationprompt 21 | -------------------------------------------------------------------------------- /generation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | 3 | import torch 4 | from fire import Fire 5 | from torch import Tensor 6 | from transformers import PreTrainedModel, PreTrainedTokenizerFast 7 | 8 | from encoding import ExtractEncoder 9 | from utils import DynamicModel, RelationSentence, find_sublist_index 10 | 11 | 12 | class TextGenerator(DynamicModel): 13 | model: PreTrainedModel 14 | tokenizer: PreTrainedTokenizerFast 15 | scores: Optional[List[Tensor]] = None 16 | max_length: int 17 | 18 | def tokenize(self, texts: List[str], **kwargs): 19 | return self.tokenizer( 20 | texts, 21 | padding=True, 22 | truncation=True, 23 | max_length=self.max_length, 24 | return_tensors="pt", 25 | **kwargs, 26 | ).to(self.model.device) 27 | 28 | def run( 29 | self, 30 | texts: List[str], 31 | do_sample=True, 32 | top_k=50, 33 | temperature=1.0, 34 | num_return: int = 4, 35 | prompt: Optional[str] = None, 36 | prompt_ids: Optional[List[int]] = None, 37 | multi_prompt_ids: Optional[List[List[int]]] = None, 38 | decoder_input_ids: Optional[Tensor] = None, 39 | save_scores: bool = False, 40 | **kwargs, 41 | ) -> List[str]: 42 | # https://huggingface.co/transformers/v4.7.0/main_classes/model.html#generation 43 | tok = self.tokenizer 44 | eos, bos = tok.eos_token_id, tok.bos_token_id 45 | 46 | if prompt is not None: 47 | prompt_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids 48 | if prompt_ids is not None: 49 | prompt_ids = [eos, bos] + prompt_ids 50 | decoder_input_ids = torch.tensor([prompt_ids]) 51 | if multi_prompt_ids is not None: 52 | assert len(texts) == len(multi_prompt_ids) 53 | multi_prompt_ids = [[eos, bos] + lst for lst in multi_prompt_ids] 54 | decoder_input_ids = torch.tensor(multi_prompt_ids) 55 | if decoder_input_ids is not None: 56 | kwargs.update(decoder_input_ids=decoder_input_ids.to(self.model.device)) 57 | 58 | outputs = self.model.generate( 59 | **self.tokenize(texts), 60 | do_sample=do_sample, 61 | top_k=top_k, 62 | temperature=temperature, 63 | num_return_sequences=num_return, 64 | return_dict_in_generate=True, 65 | output_scores=save_scores, 66 | max_length=self.max_length, 67 | **kwargs, 68 | ) 69 | 70 | self.scores = None 71 | if save_scores: 72 | self.scores = [_ for _ in torch.stack(outputs.scores, 1).cpu()] 73 | return self.decode(outputs.sequences) 74 | 75 | def decode(self, outputs) -> List[str]: 76 | tok = self.tokenizer 77 | texts = tok.batch_decode( 78 | outputs, skip_special_tokens=False, clean_up_tokenization_spaces=False 79 | ) 80 | 81 | # Manually remove in case we have custom special tokens 82 | special_tokens = [tok.eos_token, tok.bos_token, tok.pad_token] 83 | for i, t in enumerate(texts): 84 | for token in special_tokens: 85 | t = t.replace(token, "") 86 | texts[i] = t 87 | return texts 88 | 89 | 90 | class LabelConstraint: 91 | def __init__( 92 | self, 93 | labels: 
List[str], 94 | tokenizer: PreTrainedTokenizerFast, 95 | prefix: str = " Relation :", 96 | ): 97 | self.prefix: List[int] = tokenizer(prefix, add_special_tokens=False).input_ids 98 | self.label_map: Dict[int, str] = { 99 | tokenizer(" " + x, add_special_tokens=False).input_ids[0]: x for x in labels 100 | } 101 | self.tokenizer = tokenizer 102 | 103 | def run(self, triplet: RelationSentence, scores: Tensor) -> RelationSentence: 104 | triplet = triplet.copy(deep=True) 105 | assert scores.ndim == 2 106 | token_ids = scores.argmax(dim=-1).int().tolist() 107 | i = find_sublist_index(token_ids, self.prefix) 108 | if i == -1: 109 | return triplet 110 | 111 | position = i + len(self.prefix) 112 | best = "" 113 | best_score = -1e9 114 | for j, label in self.label_map.items(): 115 | score = scores[position, j].item() 116 | if score > best_score: 117 | best = label 118 | best_score = score 119 | 120 | if triplet.label in self.label_map.values(): 121 | assert best == triplet.label 122 | 123 | assert len(best) > 0 124 | triplet.label = best 125 | triplet.score = best_score 126 | return triplet 127 | 128 | 129 | class TripletSearchDecoder(DynamicModel): 130 | gen: TextGenerator 131 | constraint: LabelConstraint 132 | encoder: ExtractEncoder 133 | top_k: int = 4 134 | 135 | def generate(self, text: str, **kwargs) -> Tuple[str, Tensor]: 136 | outputs = self.gen.run( 137 | [text], 138 | do_sample=False, 139 | num_return=1, 140 | num_beams=1, 141 | save_scores=True, 142 | **kwargs, 143 | ) 144 | 145 | assert len(outputs) == 1 146 | assert self.gen.scores is not None 147 | scores = torch.log_softmax(self.gen.scores[0], dim=-1) 148 | assert scores.ndim == 2 149 | return outputs[0], scores 150 | 151 | def find_prefix_end(self, token_ids: List[int], prefix: str) -> int: 152 | prefix_ids = self.gen.tokenizer(prefix, add_special_tokens=False).input_ids 153 | i = find_sublist_index(token_ids, prefix_ids) 154 | position = i + len(prefix_ids) 155 | return position 156 | 157 | def branch( 158 | self, text: str, prefix: str, prompt: Optional[str] = None, **kwargs 159 | ) -> List[Tuple[str, float]]: 160 | _, scores = self.generate(text, prompt=prompt, **kwargs) 161 | token_ids = scores.argmax(dim=-1).int().tolist() 162 | i = self.find_prefix_end(token_ids, prefix) 163 | 164 | pairs = [] 165 | for j in torch.argsort(scores[i])[-self.top_k :]: 166 | p = (prompt or "") + self.gen.decode([token_ids[:i] + [j]])[0] 167 | pairs.append((p, scores[i, j].item())) 168 | 169 | return pairs 170 | 171 | def run(self, text: str) -> List[RelationSentence]: 172 | x = self.encoder.encode_x(text) 173 | outputs = [] 174 | 175 | for prompt_a, score_a in self.branch(x, prefix="Head Entity :"): 176 | for prompt_b, score_b in self.branch( 177 | x, prefix=" Tail Entity :", prompt=prompt_a 178 | ): 179 | output, scores = self.generate(x, prompt=prompt_b) 180 | token_ids = scores.argmax(dim=-1).int().tolist() 181 | i = self.find_prefix_end(token_ids, prefix=" Relation :") 182 | score_c = max(scores[i].tolist()) 183 | s = self.encoder.safe_decode(x=x, y=output) 184 | s = self.constraint.run(s, scores) 185 | # score_c = s.score # From LabelConstraint 186 | s.score = (score_a + score_b + score_c) / 3 187 | outputs.append(s) 188 | 189 | return outputs 190 | 191 | 192 | if __name__ == "__main__": 193 | Fire() 194 | -------------------------------------------------------------------------------- /modeling.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import
List, Optional, Tuple 3 | 4 | import torch 5 | from fire import Fire 6 | from tqdm import tqdm 7 | from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer, 8 | IntervalStrategy, Pipeline, TrainingArguments, 9 | pipeline, set_seed) 10 | 11 | from encoding import select_encoder 12 | from generation import TextGenerator 13 | from transformer_base import run_clm, run_summarization 14 | from utils import DynamicModel, RelationData, RelationSentence 15 | 16 | 17 | class RelationModel(DynamicModel): 18 | model_dir: str 19 | data_dir: str 20 | model_name: str 21 | do_pretrain: bool 22 | encoder_name: str 23 | pipe_name: str 24 | batch_size: int = 64 25 | grad_accumulation: int = 2 26 | random_seed: int = 42 27 | warmup_ratio: float = 0.2 28 | lr_pretrain: float = 3e-4 29 | lr_finetune: float = 3e-5 30 | epochs_pretrain: int = 3 31 | epochs_finetune: int = 5 32 | train_fp16: bool = True 33 | 34 | def fit(self, path_train: str, path_dev: Optional[str] = None): 35 | raise NotImplementedError 36 | 37 | def run(self, *args, **kwargs): 38 | raise NotImplementedError 39 | 40 | def decode(self, *args, **kwargs): 41 | raise NotImplementedError 42 | 43 | def get_lr(self) -> float: 44 | return self.lr_pretrain if self.do_pretrain else self.lr_finetune 45 | 46 | def get_epochs(self) -> int: 47 | return self.epochs_pretrain if self.do_pretrain else self.epochs_finetune 48 | 49 | def make_pipe(self, **kwargs) -> Pipeline: 50 | pipe = pipeline( 51 | self.pipe_name, 52 | model=self.model_dir, 53 | tokenizer=self.model_name, 54 | device=0 if torch.cuda.is_available() else -1, 55 | **kwargs, 56 | ) 57 | return pipe 58 | 59 | def get_encoder(self): 60 | return select_encoder(self.encoder_name) 61 | 62 | def get_train_args(self, do_eval: bool) -> TrainingArguments: 63 | return TrainingArguments( 64 | seed=self.random_seed, 65 | do_train=True, 66 | do_eval=do_eval or None, # False still becomes True after parsing 67 | overwrite_output_dir=True, 68 | per_device_train_batch_size=self.batch_size, 69 | gradient_accumulation_steps=self.grad_accumulation, 70 | warmup_ratio=self.warmup_ratio, 71 | output_dir=self.model_dir, 72 | save_strategy=IntervalStrategy.EPOCH, 73 | evaluation_strategy=IntervalStrategy.EPOCH 74 | if do_eval 75 | else IntervalStrategy.NO, 76 | learning_rate=self.get_lr(), 77 | num_train_epochs=self.get_epochs(), 78 | load_best_model_at_end=True, 79 | fp16=self.train_fp16, 80 | ) 81 | 82 | 83 | class RelationGenerator(RelationModel): 84 | model_name: str = "gpt2" 85 | block_size: int = 128 86 | encoder_name: str = "gpt_new_generate" 87 | pipe_name: str = "text-generation" 88 | 89 | def fit(self, path_train: str, path_dev: Optional[str] = None): 90 | data_args = run_clm.DataTrainingArguments( 91 | concat_texts=False, 92 | train_file=path_train, 93 | validation_file=path_dev, 94 | overwrite_cache=True, 95 | block_size=self.block_size, 96 | ) 97 | train_args = self.get_train_args(do_eval=path_dev is not None) 98 | model_args = run_clm.ModelArguments(model_name_or_path=self.model_name) 99 | run_clm.main( 100 | model_args=model_args, training_args=train_args, data_args=data_args 101 | ) 102 | 103 | def generate( 104 | self, relation: str, num: int, pipe: Pipeline 105 | ) -> Tuple[List[RelationSentence], List[str]]: 106 | set_seed(self.random_seed) 107 | encoder = self.get_encoder() 108 | prompt = encoder.encode_x(relation) 109 | sents, raw = [], [] 110 | errors = set() 111 | 112 | while len(sents) < num: 113 | outputs = pipe( 114 | [prompt], 115 | num_return_sequences=self.batch_size, 116 | 
max_length=self.block_size, 117 | ) 118 | for o in outputs: 119 | raw.append(o["generated_text"] + "\n") 120 | x, y = encoder.parse_line(raw[-1]) 121 | try: 122 | s = encoder.decode(x=prompt, y=y) 123 | if s.is_valid(): 124 | sents.append(s) 125 | except Exception as e: 126 | errors.add(str(e)) 127 | 128 | print(dict(target=num, success=len(sents), raw=len(raw))) 129 | 130 | assert len(sents) >= num 131 | print(dict(prompt=prompt, success_rate=len(sents) / len(raw), errors=errors)) 132 | return sents[:num], raw 133 | 134 | def run( 135 | self, 136 | labels: List[str], 137 | path_out: Path, 138 | num_samples_per_relation: int, 139 | device: torch.device = torch.device("cuda"), 140 | ) -> RelationData: 141 | pipe = self.make_pipe() 142 | sents_all, raw_all = [], [] 143 | for relation in tqdm(labels): 144 | sents, raw = self.generate(relation, num_samples_per_relation, pipe=pipe) 145 | sents_all.extend(sents) 146 | raw_all.extend(raw) 147 | 148 | with open(path_out, "w") as f: 149 | f.write("".join(raw_all)) 150 | 151 | data = RelationData(sents=sents_all) 152 | return data 153 | 154 | def decode(self, *args, **kwargs): 155 | pass 156 | 157 | 158 | class NewRelationGenerator(RelationModel): 159 | model_name: str = "facebook/bart-base" 160 | max_source_length: int = 128 161 | max_target_length: int = 128 162 | encoder_name: str = "new_generate" 163 | pipe_name: str = "summarization" 164 | 165 | def fit(self, path_train: str, path_dev: Optional[str] = None): 166 | kwargs = {} 167 | 168 | data_args = run_summarization.DataTrainingArguments( 169 | train_file=path_train, 170 | validation_file=path_dev, 171 | overwrite_cache=True, 172 | max_target_length=self.max_target_length, 173 | max_source_length=self.max_source_length, 174 | **kwargs, 175 | ) 176 | train_args = self.get_train_args(do_eval=path_dev is not None) 177 | kwargs = { 178 | k: v for k, v in train_args.to_dict().items() if not k.startswith("_") 179 | } 180 | train_args = run_summarization.Seq2SeqTrainingArguments(**kwargs) 181 | model_args = run_summarization.ModelArguments( 182 | model_name_or_path=self.model_name 183 | ) 184 | run_summarization.main( 185 | model_args=model_args, training_args=train_args, data_args=data_args 186 | ) 187 | 188 | def load_generator(self, device: torch.device) -> TextGenerator: 189 | gen = TextGenerator( 190 | model=AutoModelForSeq2SeqLM.from_pretrained(self.model_dir), 191 | tokenizer=AutoTokenizer.from_pretrained(self.model_dir), 192 | max_length=self.max_target_length, 193 | ) 194 | gen.model = gen.model.to(device) 195 | return gen 196 | 197 | def generate( 198 | self, relation: str, num: int, gen: TextGenerator 199 | ) -> Tuple[List[RelationSentence], List[str]]: 200 | set_seed(self.random_seed) 201 | encoder = self.get_encoder() 202 | prompt = encoder.encode_x(relation) 203 | sents, raw = [], [] 204 | errors = set() 205 | 206 | while len(sents) < num: 207 | outputs = gen.run([prompt], num_return=self.batch_size) 208 | for o in outputs: 209 | raw.append(run_summarization.encode_to_line(x=prompt, y=o)) 210 | try: 211 | s = encoder.decode(x=prompt, y=o) 212 | if s.is_valid(): 213 | sents.append(s) 214 | except Exception as e: 215 | errors.add(str(e)) 216 | 217 | print(dict(target=num, success=len(sents), raw=len(raw))) 218 | 219 | assert len(sents) >= num 220 | print(dict(prompt=prompt, success_rate=len(sents) / len(raw), errors=errors)) 221 | return sents[:num], raw 222 | 223 | def run( 224 | self, 225 | labels: List[str], 226 | path_out: Path, 227 | num_samples_per_relation: int, 228 | device: 
torch.device = torch.device("cuda"), 229 | ) -> RelationData: 230 | gen = self.load_generator(device=device) 231 | sents_all, raw_all = [], [] 232 | for relation in tqdm(labels): 233 | sents, raw = self.generate(relation, num_samples_per_relation, gen=gen) 234 | sents_all.extend(sents) 235 | raw_all.extend(raw) 236 | 237 | with open(path_out, "w") as f: 238 | f.write("".join(raw_all)) 239 | 240 | data = RelationData(sents=sents_all) 241 | return data 242 | 243 | def decode(self, *args, **kwargs): 244 | pass 245 | 246 | 247 | class NewRelationExtractor(NewRelationGenerator): 248 | encoder_name: str = "new_extract" 249 | 250 | @staticmethod 251 | def gen_texts(texts: List[str], gen: TextGenerator, **kwargs): 252 | return gen.run(texts, do_sample=False, num_return=1, **kwargs) 253 | 254 | def run( 255 | self, 256 | texts: List[str], 257 | path_out: Path, 258 | batch_size: int = 512, 259 | device: torch.device = torch.device("cuda"), 260 | ): 261 | set_seed(self.random_seed) 262 | encoder = self.get_encoder() 263 | prompts = [encoder.encode_x(t) for t in texts] 264 | gen = self.load_generator(device=device) 265 | preds = [] 266 | 267 | for i in tqdm(range(0, len(texts), batch_size), desc="RelationExtractor.run"): 268 | batch = prompts[i : i + batch_size] 269 | outputs = self.gen_texts(batch, gen) 270 | preds.extend(outputs) 271 | 272 | path_out.parent.mkdir(exist_ok=True, parents=True) 273 | with open(path_out, "w") as f: 274 | for x, y in zip(prompts, preds): 275 | f.write(run_summarization.encode_to_line(x=x, y=y)) 276 | 277 | def decode(self, path: Path) -> RelationData: 278 | encoder = self.get_encoder() 279 | with open(path) as f: 280 | sents = [encoder.safe_decode(*encoder.parse_line(line)) for line in f] 281 | 282 | success_rate = len([s for s in sents if s.is_valid()]) / len(sents) 283 | print(dict(success_rate=success_rate)) 284 | data = RelationData(sents=sents) 285 | return data 286 | 287 | 288 | def select_model(name: str, **kwargs) -> RelationModel: 289 | mapping = dict( 290 | generate=RelationGenerator(**kwargs), 291 | new_generate=NewRelationGenerator(**kwargs), 292 | new_extract=NewRelationExtractor(**kwargs), 293 | ) 294 | model = mapping[name] 295 | print(dict(select_model=model)) 296 | return model 297 | 298 | 299 | if __name__ == "__main__": 300 | Fire() 301 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10.0 2 | transformers==4.7.0 3 | datasets==1.11.0 4 | pandas==1.2.4 5 | pydantic==1.8.2 6 | fastavro==1.4.0 7 | fire==0.4.0 8 | nltk==3.6.6 9 | lxml==4.6.5 10 | editdistance==0.5.3 11 | seqeval==1.2.2 12 | -------------------------------------------------------------------------------- /transformer_base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/declare-lab/RelationPrompt/ab415648e05d2feb8ba1e94d7726f2032273d14e/transformer_base/__init__.py -------------------------------------------------------------------------------- /transformer_base/run_clm.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2020 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset. 18 | Here is the full list of checkpoints on the hub that can be fine-tuned by this script: 19 | https://huggingface.co/models?filter=causal-lm 20 | Adapted from: https://github.com/huggingface/transformers/blob/v4.7.0/examples/pytorch/language-modeling/run_clm.py 21 | """ 22 | # You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments. 23 | 24 | import logging 25 | import math 26 | import os 27 | import sys 28 | from dataclasses import dataclass, field 29 | from typing import Optional 30 | 31 | import transformers 32 | from datasets import load_dataset 33 | from transformers import (CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, 34 | AutoConfig, AutoModelForCausalLM, AutoTokenizer, 35 | HfArgumentParser, Trainer, TrainingArguments, 36 | default_data_collator, set_seed) 37 | from transformers.testing_utils import CaptureLogger 38 | from transformers.trainer_utils import get_last_checkpoint 39 | from transformers.utils import check_min_version 40 | 41 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 42 | check_min_version("4.7.0") 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | 47 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 48 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 49 | 50 | 51 | @dataclass 52 | class ModelArguments: 53 | """ 54 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch. 55 | """ 56 | 57 | model_name_or_path: Optional[str] = field( 58 | default=None, 59 | metadata={ 60 | "help": "The model checkpoint for weights initialization." 61 | "Don't set if you want to train a model from scratch." 62 | }, 63 | ) 64 | model_type: Optional[str] = field( 65 | default=None, 66 | metadata={ 67 | "help": "If training from scratch, pass a model type from the list: " 68 | + ", ".join(MODEL_TYPES) 69 | }, 70 | ) 71 | config_overrides: Optional[str] = field( 72 | default=None, 73 | metadata={ 74 | "help": "Override some existing default config settings when a model is trained from scratch. Example: " 75 | "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" 76 | }, 77 | ) 78 | config_name: Optional[str] = field( 79 | default=None, 80 | metadata={ 81 | "help": "Pretrained config name or path if not the same as model_name" 82 | }, 83 | ) 84 | tokenizer_name: Optional[str] = field( 85 | default=None, 86 | metadata={ 87 | "help": "Pretrained tokenizer name or path if not the same as model_name" 88 | }, 89 | ) 90 | cache_dir: Optional[str] = field( 91 | default=None, 92 | metadata={ 93 | "help": "Where do you want to store the pretrained models downloaded from huggingface.co" 94 | }, 95 | ) 96 | use_fast_tokenizer: bool = field( 97 | default=True, 98 | metadata={ 99 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 
100 | }, 101 | ) 102 | model_revision: str = field( 103 | default="main", 104 | metadata={ 105 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 106 | }, 107 | ) 108 | use_auth_token: bool = field( 109 | default=False, 110 | metadata={ 111 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 112 | "with private models)." 113 | }, 114 | ) 115 | 116 | def __post_init__(self): 117 | if self.config_overrides is not None and ( 118 | self.config_name is not None or self.model_name_or_path is not None 119 | ): 120 | raise ValueError( 121 | "--config_overrides can't be used in combination with --config_name or --model_name_or_path" 122 | ) 123 | 124 | 125 | @dataclass 126 | class DataTrainingArguments: 127 | """ 128 | Arguments pertaining to what data we are going to input our model for training and eval. 129 | """ 130 | 131 | dataset_name: Optional[str] = field( 132 | default=None, 133 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, 134 | ) 135 | dataset_config_name: Optional[str] = field( 136 | default=None, 137 | metadata={ 138 | "help": "The configuration name of the dataset to use (via the datasets library)." 139 | }, 140 | ) 141 | train_file: Optional[str] = field( 142 | default=None, metadata={"help": "The input training data file (a text file)."} 143 | ) 144 | validation_file: Optional[str] = field( 145 | default=None, 146 | metadata={ 147 | "help": "An optional input evaluation data file to evaluate the perplexity on (a text file)." 148 | }, 149 | ) 150 | max_train_samples: Optional[int] = field( 151 | default=None, 152 | metadata={ 153 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 154 | "value if set." 155 | }, 156 | ) 157 | max_eval_samples: Optional[int] = field( 158 | default=None, 159 | metadata={ 160 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 161 | "value if set." 162 | }, 163 | ) 164 | 165 | block_size: Optional[int] = field( 166 | default=None, 167 | metadata={ 168 | "help": "Optional input sequence length after tokenization. " 169 | "The training dataset will be truncated in block of this size for training. " 170 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 171 | }, 172 | ) 173 | overwrite_cache: bool = field( 174 | default=False, 175 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 176 | ) 177 | validation_split_percentage: Optional[int] = field( 178 | default=5, 179 | metadata={ 180 | "help": "The percentage of the train set used as validation set in case there's no validation split" 181 | }, 182 | ) 183 | preprocessing_num_workers: Optional[int] = field( 184 | default=None, 185 | metadata={"help": "The number of processes to use for the preprocessing."}, 186 | ) 187 | concat_texts: bool = field( 188 | default=True, 189 | metadata={"help": "Concatenate all lines from dataset and draw chunks"}, 190 | ) 191 | 192 | tokenizer_kwargs: Optional[dict] = field( 193 | default=None, 194 | metadata={"help": "Extra keyword arguments to initialize tokenizer"}, 195 | ) 196 | 197 | def __post_init__(self): 198 | if ( 199 | self.dataset_name is None 200 | and self.train_file is None 201 | and self.validation_file is None 202 | ): 203 | raise ValueError( 204 | "Need either a dataset name or a training/validation file." 
205 | )
206 | else:
207 | if self.train_file is not None:
208 | extension = self.train_file.split(".")[-1]
209 | assert extension in [
210 | "csv",
211 | "json",
212 | "txt",
213 | ], "`train_file` should be a csv, a json or a txt file."
214 | if self.validation_file is not None:
215 | extension = self.validation_file.split(".")[-1]
216 | assert extension in [
217 | "csv",
218 | "json",
219 | "txt",
220 | ], "`validation_file` should be a csv, a json or a txt file."
221 | 
222 | 
223 | def main(
224 | model_args: ModelArguments = None,
225 | data_args: DataTrainingArguments = None,
226 | training_args: TrainingArguments = None,
227 | ):
228 | # See all possible arguments in src/transformers/training_args.py
229 | # or by passing the --help flag to this script.
230 | # We now keep distinct sets of args, for a cleaner separation of concerns.
231 | 
232 | if model_args is None or data_args is None or training_args is None:
233 | print("Using HfArgumentParser")
234 | parser = HfArgumentParser(
235 | (ModelArguments, DataTrainingArguments, TrainingArguments)
236 | )
237 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
238 | # If we pass only one argument to the script and it's the path to a json file,
239 | # let's parse it to get our arguments.
240 | model_args, data_args, training_args = parser.parse_json_file(
241 | json_file=os.path.abspath(sys.argv[1])
242 | )
243 | else:
244 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
245 | 
246 | # Setup logging
247 | logging.basicConfig(
248 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
249 | datefmt="%m/%d/%Y %H:%M:%S",
250 | handlers=[logging.StreamHandler(sys.stdout)],
251 | )
252 | logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
253 | 
254 | # Log on each process the small summary:
255 | logger.warning(
256 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
257 | + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
258 | )
259 | # Set the verbosity to info of the Transformers logger (on main process only):
260 | if training_args.should_log:
261 | transformers.utils.logging.set_verbosity_info()
262 | transformers.utils.logging.enable_default_handler()
263 | transformers.utils.logging.enable_explicit_format()
264 | logger.info(f"Training/evaluation parameters {training_args}")
265 | 
266 | # Detecting last checkpoint.
267 | last_checkpoint = None
268 | if (
269 | os.path.isdir(training_args.output_dir)
270 | and training_args.do_train
271 | and not training_args.overwrite_output_dir
272 | ):
273 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
274 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
275 | raise ValueError(
276 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
277 | "Use --overwrite_output_dir to overcome."
278 | )
279 | elif (
280 | last_checkpoint is not None and training_args.resume_from_checkpoint is None
281 | ):
282 | logger.info(
283 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
284 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
285 | )
286 | 
287 | # Set seed before initializing model.
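# (set_seed covers the `random`, `numpy` and `torch` generators in one call, so runs
# with the same --seed and data are reproducible up to CUDA nondeterminism.)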
288 | set_seed(training_args.seed)
289 | 
290 | # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
291 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
292 | # (the dataset will be downloaded automatically from the datasets Hub).
293 | #
294 | # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
295 | # 'text' is found. You can easily tweak this behavior (see below).
296 | #
297 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
298 | # download the dataset.
299 | if data_args.dataset_name is not None:
300 | # Downloading and loading a dataset from the hub.
301 | datasets = load_dataset(
302 | data_args.dataset_name,
303 | data_args.dataset_config_name,
304 | cache_dir=model_args.cache_dir,
305 | )
306 | if "validation" not in datasets.keys():
307 | datasets["validation"] = load_dataset(
308 | data_args.dataset_name,
309 | data_args.dataset_config_name,
310 | split=f"train[:{data_args.validation_split_percentage}%]",
311 | cache_dir=model_args.cache_dir,
312 | )
313 | datasets["train"] = load_dataset(
314 | data_args.dataset_name,
315 | data_args.dataset_config_name,
316 | split=f"train[{data_args.validation_split_percentage}%:]",
317 | cache_dir=model_args.cache_dir,
318 | )
319 | else:
320 | data_files = {}
321 | if data_args.train_file is not None:
322 | data_files["train"] = data_args.train_file
323 | if data_args.validation_file is not None:
324 | data_files["validation"] = data_args.validation_file
325 | extension = (
326 | data_args.train_file.split(".")[-1]
327 | if data_args.train_file is not None
328 | else data_args.validation_file.split(".")[-1]
329 | )
330 | if extension == "txt":
331 | extension = "text"
332 | datasets = load_dataset(
333 | extension, data_files=data_files, cache_dir=model_args.cache_dir
334 | )
335 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
336 | # https://huggingface.co/docs/datasets/loading_datasets.html.
337 | 
338 | # Load pretrained model and tokenizer
339 | #
340 | # Distributed training:
341 | # The .from_pretrained methods guarantee that only one local process can concurrently
342 | # download model & vocab.
343 | 344 | config_kwargs = { 345 | "cache_dir": model_args.cache_dir, 346 | "revision": model_args.model_revision, 347 | "use_auth_token": True if model_args.use_auth_token else None, 348 | } 349 | if model_args.config_name: 350 | config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs) 351 | elif model_args.model_name_or_path: 352 | config = AutoConfig.from_pretrained( 353 | model_args.model_name_or_path, **config_kwargs 354 | ) 355 | else: 356 | config = CONFIG_MAPPING[model_args.model_type]() 357 | logger.warning("You are instantiating a new config instance from scratch.") 358 | if model_args.config_overrides is not None: 359 | logger.info(f"Overriding config: {model_args.config_overrides}") 360 | config.update_from_string(model_args.config_overrides) 361 | 362 | tokenizer_kwargs = { 363 | "cache_dir": model_args.cache_dir, 364 | "use_fast": model_args.use_fast_tokenizer, 365 | "revision": model_args.model_revision, 366 | "use_auth_token": True if model_args.use_auth_token else None, 367 | } 368 | if data_args.tokenizer_kwargs: 369 | tokenizer_kwargs.update(**data_args.tokenizer_kwargs) 370 | if model_args.tokenizer_name: 371 | tokenizer = AutoTokenizer.from_pretrained( 372 | model_args.tokenizer_name, **tokenizer_kwargs 373 | ) 374 | elif model_args.model_name_or_path: 375 | tokenizer = AutoTokenizer.from_pretrained( 376 | model_args.model_name_or_path, **tokenizer_kwargs 377 | ) 378 | else: 379 | raise ValueError( 380 | "You are instantiating a new tokenizer from scratch. This is not supported by this script." 381 | "You can do it from another script, save it, and load it from here, using --tokenizer_name." 382 | ) 383 | print(tokenizer) 384 | 385 | if model_args.model_name_or_path: 386 | model = AutoModelForCausalLM.from_pretrained( 387 | model_args.model_name_or_path, 388 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 389 | config=config, 390 | cache_dir=model_args.cache_dir, 391 | revision=model_args.model_revision, 392 | use_auth_token=True if model_args.use_auth_token else None, 393 | ) 394 | else: 395 | model = AutoModelForCausalLM.from_config(config) 396 | n_params = sum( 397 | dict((p.data_ptr(), p.numel()) for p in model.parameters()).values() 398 | ) 399 | logger.info( 400 | f"Training new model from scratch - Total size={n_params/2**20:.2f}M params" 401 | ) 402 | 403 | model.resize_token_embeddings(len(tokenizer)) 404 | 405 | # Preprocessing the datasets. 406 | # First we tokenize all the texts. 
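# For instance (hypothetical token ids): with block_size=4, the texts ["a b c", "d e"]
# tokenize to [[1, 2, 3], [4, 5]]; group_texts further below concatenates and chunks
# them into input_ids [[1, 2, 3, 4]] with labels copied from input_ids, dropping the
# remainder [5].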
407 | if training_args.do_train:
408 | column_names = datasets["train"].column_names
409 | else:
410 | column_names = datasets["validation"].column_names
411 | text_column_name = "text" if "text" in column_names else column_names[0]
412 | 
413 | # since this will be pickled to avoid _LazyModule error in Hasher, force logger loading before tokenize_function
414 | tok_logger = transformers.utils.logging.get_logger(
415 | "transformers.tokenization_utils_base"
416 | )
417 | if not data_args.concat_texts:
418 | tokenizer.pad_token = tokenizer.eos_token
419 | 
420 | def tokenize_function(examples):
421 | with CaptureLogger(tok_logger) as cl:
422 | output = tokenizer(examples[text_column_name])  # overridden below when concat_texts is False
423 | 
424 | if not data_args.concat_texts:
425 | output = tokenizer(
426 | examples[text_column_name],
427 | truncation=True,
428 | padding="max_length",
429 | max_length=data_args.block_size,
430 | )
431 | 
432 | # clm input could be much much longer than block_size
433 | if "Token indices sequence length is longer than the" in cl.out:
434 | tok_logger.warning(
435 | "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
436 | )
437 | return output
438 | 
439 | tokenized_datasets = datasets.map(
440 | tokenize_function,
441 | batched=True,
442 | num_proc=data_args.preprocessing_num_workers,
443 | remove_columns=column_names,
444 | load_from_cache_file=not data_args.overwrite_cache,
445 | )
446 | 
447 | if data_args.block_size is None:
448 | block_size = tokenizer.model_max_length
449 | if block_size > 1024:
450 | logger.warning(
451 | f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
452 | "Picking 1024 instead. You can change that default value by passing --block_size xxx."
453 | )
454 | block_size = 1024
455 | else:
456 | if data_args.block_size > tokenizer.model_max_length:
457 | logger.warning(
458 | f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
459 | f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
460 | )
461 | block_size = min(data_args.block_size, tokenizer.model_max_length)
462 | 
463 | # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
464 | def group_texts(examples):
465 | if not data_args.concat_texts:
466 | result = examples
467 | result["labels"] = result["input_ids"].copy()
468 | return result
469 | 
470 | # Concatenate all texts.
471 | concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
472 | total_length = len(concatenated_examples[list(examples.keys())[0]])
473 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
474 | # customize this part to your needs.
475 | total_length = (total_length // block_size) * block_size
476 | # Split by chunks of max_len.
477 | result = {
478 | k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
479 | for k, t in concatenated_examples.items()
480 | }
481 | result["labels"] = result["input_ids"].copy()
482 | return result
483 | 
484 | # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
485 | # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
486 | # to preprocess.
487 | #
488 | # To speed up this part, we use multiprocessing.
See the documentation of the map method for more information: 489 | # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map 490 | 491 | lm_datasets = tokenized_datasets.map( 492 | group_texts, 493 | batched=True, 494 | num_proc=data_args.preprocessing_num_workers, 495 | load_from_cache_file=not data_args.overwrite_cache, 496 | ) 497 | 498 | if training_args.do_train: 499 | if "train" not in tokenized_datasets: 500 | raise ValueError("--do_train requires a train dataset") 501 | train_dataset = lm_datasets["train"] 502 | if data_args.max_train_samples is not None: 503 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 504 | 505 | if training_args.do_eval: 506 | if "validation" not in tokenized_datasets: 507 | raise ValueError("--do_eval requires a validation dataset") 508 | eval_dataset = lm_datasets["validation"] 509 | if data_args.max_eval_samples is not None: 510 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 511 | 512 | # Initialize our Trainer 513 | trainer = Trainer( 514 | model=model, 515 | args=training_args, 516 | train_dataset=train_dataset if training_args.do_train else None, 517 | eval_dataset=eval_dataset if training_args.do_eval else None, 518 | tokenizer=tokenizer, 519 | # Data collator will default to DataCollatorWithPadding, so we change it. 520 | data_collator=default_data_collator, 521 | ) 522 | 523 | # Training 524 | if training_args.do_train: 525 | checkpoint = None 526 | if training_args.resume_from_checkpoint is not None: 527 | checkpoint = training_args.resume_from_checkpoint 528 | elif last_checkpoint is not None: 529 | checkpoint = last_checkpoint 530 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 531 | trainer.save_model() # Saves the tokenizer too for easy upload 532 | 533 | metrics = train_result.metrics 534 | 535 | max_train_samples = ( 536 | data_args.max_train_samples 537 | if data_args.max_train_samples is not None 538 | else len(train_dataset) 539 | ) 540 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 541 | 542 | trainer.log_metrics("train", metrics) 543 | trainer.save_metrics("train", metrics) 544 | trainer.save_state() 545 | 546 | # Evaluation 547 | if training_args.do_eval: 548 | logger.info("*** Evaluate ***") 549 | 550 | metrics = trainer.evaluate() 551 | 552 | max_eval_samples = ( 553 | data_args.max_eval_samples 554 | if data_args.max_eval_samples is not None 555 | else len(eval_dataset) 556 | ) 557 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 558 | try: 559 | perplexity = math.exp(metrics["eval_loss"]) 560 | except OverflowError: 561 | perplexity = float("inf") 562 | metrics["perplexity"] = perplexity 563 | 564 | trainer.log_metrics("eval", metrics) 565 | trainer.save_metrics("eval", metrics) 566 | 567 | if training_args.push_to_hub: 568 | kwargs = { 569 | "finetuned_from": model_args.model_name_or_path, 570 | "tasks": "text-generation", 571 | } 572 | if data_args.dataset_name is not None: 573 | kwargs["dataset_tags"] = data_args.dataset_name 574 | if data_args.dataset_config_name is not None: 575 | kwargs["dataset_args"] = data_args.dataset_config_name 576 | kwargs[ 577 | "dataset" 578 | ] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 579 | else: 580 | kwargs["dataset"] = data_args.dataset_name 581 | 582 | trainer.push_to_hub(**kwargs) 583 | 584 | 585 | def _mp_fn(index): 586 | # For xla_spawn (TPUs) 587 | main() 588 | 589 | 590 | if __name__ == "__main__": 591 | main() 592 | 
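Because `main()` only falls back to `HfArgumentParser` when the three argument dataclasses are not supplied, the script can also be driven programmatically (presumably how the rest of this repository reuses it). A minimal sketch, where `train.txt`, `valid.txt` and the output path are hypothetical:

```
from transformers import TrainingArguments

from transformer_base.run_clm import DataTrainingArguments, ModelArguments, main

model_args = ModelArguments(model_name_or_path="gpt2")
data_args = DataTrainingArguments(train_file="train.txt", validation_file="valid.txt")
training_args = TrainingArguments(
    output_dir="outputs/clm_demo",  # hypothetical output path
    do_train=True,
    do_eval=True,
    overwrite_output_dir=True,
)
main(model_args=model_args, data_args=data_args, training_args=training_args)
```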
-------------------------------------------------------------------------------- /transformer_base/run_summarization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | Adapted from: https://github.com/huggingface/transformers/blob/v4.7.0/examples/pytorch/summarization/run_summarization.py 19 | """ 20 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 21 | 22 | import json 23 | import logging 24 | import os 25 | import sys 26 | from dataclasses import dataclass, field 27 | from typing import Optional, Tuple 28 | 29 | import transformers 30 | from datasets import load_dataset 31 | from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer, 32 | DataCollatorForSeq2Seq, HfArgumentParser, 33 | Seq2SeqTrainer, Seq2SeqTrainingArguments, set_seed) 34 | from transformers.trainer_utils import get_last_checkpoint 35 | from transformers.utils import check_min_version 36 | 37 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 38 | check_min_version("4.7.0") 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | 43 | @dataclass 44 | class ModelArguments: 45 | """ 46 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 47 | """ 48 | 49 | model_name_or_path: str = field( 50 | metadata={ 51 | "help": "Path to pretrained model or model identifier from huggingface.co/models" 52 | } 53 | ) 54 | config_name: Optional[str] = field( 55 | default=None, 56 | metadata={ 57 | "help": "Pretrained config name or path if not the same as model_name" 58 | }, 59 | ) 60 | tokenizer_name: Optional[str] = field( 61 | default=None, 62 | metadata={ 63 | "help": "Pretrained tokenizer name or path if not the same as model_name" 64 | }, 65 | ) 66 | cache_dir: Optional[str] = field( 67 | default=None, 68 | metadata={ 69 | "help": "Where to store the pretrained models downloaded from huggingface.co" 70 | }, 71 | ) 72 | use_fast_tokenizer: bool = field( 73 | default=True, 74 | metadata={ 75 | "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." 76 | }, 77 | ) 78 | model_revision: str = field( 79 | default="main", 80 | metadata={ 81 | "help": "The specific model version to use (can be a branch name, tag name or commit id)." 82 | }, 83 | ) 84 | use_auth_token: bool = field( 85 | default=False, 86 | metadata={ 87 | "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " 88 | "with private models)." 89 | }, 90 | ) 91 | 92 | 93 | @dataclass 94 | class DataTrainingArguments: 95 | """ 96 | Arguments pertaining to what data we are going to input our model for training and eval. 
97 | """ 98 | 99 | dataset_name: Optional[str] = field( 100 | default=None, 101 | metadata={"help": "The name of the dataset to use (via the datasets library)."}, 102 | ) 103 | dataset_config_name: Optional[str] = field( 104 | default=None, 105 | metadata={ 106 | "help": "The configuration name of the dataset to use (via the datasets library)." 107 | }, 108 | ) 109 | text_column: Optional[str] = field( 110 | default=None, 111 | metadata={ 112 | "help": "The name of the column in the datasets containing the full texts (for summarization)." 113 | }, 114 | ) 115 | summary_column: Optional[str] = field( 116 | default=None, 117 | metadata={ 118 | "help": "The name of the column in the datasets containing the summaries (for summarization)." 119 | }, 120 | ) 121 | train_file: Optional[str] = field( 122 | default=None, 123 | metadata={"help": "The input training data file (a jsonlines or csv file)."}, 124 | ) 125 | validation_file: Optional[str] = field( 126 | default=None, 127 | metadata={ 128 | "help": "An optional input evaluation data file to evaluate the metrics (rouge) on " 129 | "(a jsonlines or csv file)." 130 | }, 131 | ) 132 | test_file: Optional[str] = field( 133 | default=None, 134 | metadata={ 135 | "help": "An optional input test data file to evaluate the metrics (rouge) on " 136 | "(a jsonlines or csv file)." 137 | }, 138 | ) 139 | overwrite_cache: bool = field( 140 | default=False, 141 | metadata={"help": "Overwrite the cached training and evaluation sets"}, 142 | ) 143 | preprocessing_num_workers: Optional[int] = field( 144 | default=None, 145 | metadata={"help": "The number of processes to use for the preprocessing."}, 146 | ) 147 | max_source_length: Optional[int] = field( 148 | default=1024, 149 | metadata={ 150 | "help": "The maximum total input sequence length after tokenization. Sequences longer " 151 | "than this will be truncated, sequences shorter will be padded." 152 | }, 153 | ) 154 | max_target_length: Optional[int] = field( 155 | default=128, 156 | metadata={ 157 | "help": "The maximum total sequence length for target text after tokenization. Sequences longer " 158 | "than this will be truncated, sequences shorter will be padded." 159 | }, 160 | ) 161 | val_max_target_length: Optional[int] = field( 162 | default=None, 163 | metadata={ 164 | "help": "The maximum total sequence length for validation target text after tokenization. Sequences longer " 165 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 166 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 167 | "during ``evaluate`` and ``predict``." 168 | }, 169 | ) 170 | pad_to_max_length: bool = field( 171 | default=False, 172 | metadata={ 173 | "help": "Whether to pad all samples to model maximum sentence length. " 174 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 175 | "efficient on GPU but very bad for TPU." 176 | }, 177 | ) 178 | max_train_samples: Optional[int] = field( 179 | default=None, 180 | metadata={ 181 | "help": "For debugging purposes or quicker training, truncate the number of training examples to this " 182 | "value if set." 183 | }, 184 | ) 185 | max_eval_samples: Optional[int] = field( 186 | default=None, 187 | metadata={ 188 | "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 189 | "value if set." 
190 | }, 191 | ) 192 | max_predict_samples: Optional[int] = field( 193 | default=None, 194 | metadata={ 195 | "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this " 196 | "value if set." 197 | }, 198 | ) 199 | num_beams: Optional[int] = field( 200 | default=None, 201 | metadata={ 202 | "help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 203 | "which is used during ``evaluate`` and ``predict``." 204 | }, 205 | ) 206 | ignore_pad_token_for_loss: bool = field( 207 | default=True, 208 | metadata={ 209 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 210 | }, 211 | ) 212 | source_prefix: Optional[str] = field( 213 | default=None, 214 | metadata={ 215 | "help": "A prefix to add before every source text (useful for T5 models)." 216 | }, 217 | ) 218 | 219 | tokenizer_kwargs: Optional[dict] = field( 220 | default=None, 221 | metadata={"help": "Extra keyword arguments to initialize tokenizer"}, 222 | ) 223 | 224 | def __post_init__(self): 225 | if ( 226 | self.dataset_name is None 227 | and self.train_file is None 228 | and self.validation_file is None 229 | ): 230 | raise ValueError( 231 | "Need either a dataset name or a training/validation file." 232 | ) 233 | else: 234 | if self.train_file is not None: 235 | extension = self.train_file.split(".")[-1] 236 | assert extension in [ 237 | "csv", 238 | "json", 239 | ], "`train_file` should be a csv or a json file." 240 | if self.validation_file is not None: 241 | extension = self.validation_file.split(".")[-1] 242 | assert extension in [ 243 | "csv", 244 | "json", 245 | ], "`validation_file` should be a csv or a json file." 246 | if self.val_max_target_length is None: 247 | self.val_max_target_length = self.max_target_length 248 | 249 | 250 | summarization_name_mapping = { 251 | "amazon_reviews_multi": ("review_body", "review_title"), 252 | "big_patent": ("description", "abstract"), 253 | "cnn_dailymail": ("article", "highlights"), 254 | "orange_sum": ("text", "summary"), 255 | "pn_summary": ("article", "summary"), 256 | "psc": ("extract_text", "summary_text"), 257 | "samsum": ("dialogue", "summary"), 258 | "thaisum": ("body", "summary"), 259 | "xglue": ("news_body", "news_title"), 260 | "xsum": ("document", "summary"), 261 | "wiki_summary": ("article", "highlights"), 262 | } 263 | 264 | 265 | def main( 266 | model_args: ModelArguments = None, 267 | data_args: DataTrainingArguments = None, 268 | training_args: Seq2SeqTrainingArguments = None, 269 | ): 270 | # See all possible arguments in src/transformers/training_args.py 271 | # or by passing the --help flag to this script. 272 | # We now keep distinct sets of args, for a cleaner separation of concerns. 273 | 274 | if model_args is None or data_args is None or training_args is None: 275 | print("Using HfArgumentParser") 276 | parser = HfArgumentParser( 277 | (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments) 278 | ) 279 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 280 | # If we pass only one argument to the script and it's the path to a json file, 281 | # let's parse it to get our arguments. 
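# For example, a hypothetical args.json maps dataclass field names to values:
#   {"model_name_or_path": "facebook/bart-base", "train_file": "train.json",
#    "output_dir": "outputs/summ", "do_train": true}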
282 | model_args, data_args, training_args = parser.parse_json_file(
283 | json_file=os.path.abspath(sys.argv[1])
284 | )
285 | else:
286 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
287 | 
288 | # Setup logging
289 | logging.basicConfig(
290 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
291 | datefmt="%m/%d/%Y %H:%M:%S",
292 | handlers=[logging.StreamHandler(sys.stdout)],
293 | )
294 | logger.setLevel(logging.INFO if training_args.should_log else logging.WARN)
295 | 
296 | # Log on each process the small summary:
297 | logger.warning(
298 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
299 | + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
300 | )
301 | # Set the verbosity to info of the Transformers logger (on main process only):
302 | if training_args.should_log:
303 | transformers.utils.logging.set_verbosity_info()
304 | logger.info(f"Training/evaluation parameters {training_args}")
305 | 
306 | if data_args.source_prefix is None and model_args.model_name_or_path in [
307 | "t5-small",
308 | "t5-base",
309 | "t5-large",
310 | "t5-3b",
311 | "t5-11b",
312 | ]:
313 | logger.warning(
314 | "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
315 | "`--source_prefix 'summarize: ' `"
316 | )
317 | 
318 | # Detecting last checkpoint.
319 | last_checkpoint = None
320 | if (
321 | os.path.isdir(training_args.output_dir)
322 | and training_args.do_train
323 | and not training_args.overwrite_output_dir
324 | ):
325 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
326 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
327 | raise ValueError(
328 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
329 | "Use --overwrite_output_dir to overcome."
330 | )
331 | elif (
332 | last_checkpoint is not None and training_args.resume_from_checkpoint is None
333 | ):
334 | logger.info(
335 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
336 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
337 | )
338 | 
339 | # Set seed before initializing model.
340 | set_seed(training_args.seed)
341 | 
342 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
343 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
344 | # (the dataset will be downloaded automatically from the datasets Hub).
345 | #
346 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
347 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
348 | #
349 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
350 | # download the dataset.
351 | if data_args.dataset_name is not None:
352 | # Downloading and loading a dataset from the hub.
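# (Some hub datasets also need --dataset_config_name, e.g. cnn_dailymail expects a
# version config such as "3.0.0".)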
353 | datasets = load_dataset( 354 | data_args.dataset_name, 355 | data_args.dataset_config_name, 356 | cache_dir=model_args.cache_dir, 357 | ) 358 | else: 359 | data_files = {} 360 | if data_args.train_file is not None: 361 | data_files["train"] = data_args.train_file 362 | extension = data_args.train_file.split(".")[-1] 363 | if data_args.validation_file is not None: 364 | data_files["validation"] = data_args.validation_file 365 | extension = data_args.validation_file.split(".")[-1] 366 | if data_args.test_file is not None: 367 | data_files["test"] = data_args.test_file 368 | extension = data_args.test_file.split(".")[-1] 369 | datasets = load_dataset( 370 | extension, data_files=data_files, cache_dir=model_args.cache_dir 371 | ) 372 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 373 | # https://huggingface.co/docs/datasets/loading_datasets.html. 374 | 375 | # Load pretrained model and tokenizer 376 | # 377 | # Distributed training: 378 | # The .from_pretrained methods guarantee that only one local process can concurrently 379 | # download model & vocab. 380 | config = AutoConfig.from_pretrained( 381 | model_args.config_name 382 | if model_args.config_name 383 | else model_args.model_name_or_path, 384 | cache_dir=model_args.cache_dir, 385 | revision=model_args.model_revision, 386 | use_auth_token=True if model_args.use_auth_token else None, 387 | ) 388 | 389 | tokenizer_kwargs = data_args.tokenizer_kwargs or {} 390 | print(dict(tokenizer_kwargs=tokenizer_kwargs)) 391 | tokenizer = AutoTokenizer.from_pretrained( 392 | model_args.tokenizer_name 393 | if model_args.tokenizer_name 394 | else model_args.model_name_or_path, 395 | cache_dir=model_args.cache_dir, 396 | use_fast=model_args.use_fast_tokenizer, 397 | revision=model_args.model_revision, 398 | use_auth_token=True if model_args.use_auth_token else None, 399 | **tokenizer_kwargs, 400 | ) 401 | print(dict(tokenizer=tokenizer)) 402 | 403 | model = AutoModelForSeq2SeqLM.from_pretrained( 404 | model_args.model_name_or_path, 405 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 406 | config=config, 407 | cache_dir=model_args.cache_dir, 408 | revision=model_args.model_revision, 409 | use_auth_token=True if model_args.use_auth_token else None, 410 | ) 411 | 412 | model.resize_token_embeddings(len(tokenizer)) 413 | 414 | if model.config.decoder_start_token_id is None: 415 | raise ValueError( 416 | "Make sure that `config.decoder_start_token_id` is correctly defined" 417 | ) 418 | 419 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 420 | 421 | # Preprocessing the datasets. 422 | # We need to tokenize inputs and targets. 423 | if training_args.do_train: 424 | column_names = datasets["train"].column_names 425 | elif training_args.do_eval: 426 | column_names = datasets["validation"].column_names 427 | elif training_args.do_predict: 428 | column_names = datasets["test"].column_names 429 | else: 430 | logger.info( 431 | "There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`." 432 | ) 433 | return 434 | 435 | # Get the column names for input/target. 
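# For example, --dataset_name xsum resolves to ("document", "summary") via the
# summarization_name_mapping above; --text_column and --summary_column override this.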
436 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
437 | if data_args.text_column is None:
438 | text_column = (
439 | dataset_columns[0] if dataset_columns is not None else column_names[0]
440 | )
441 | else:
442 | text_column = data_args.text_column
443 | if text_column not in column_names:
444 | raise ValueError(
445 | f"`--text_column` value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
446 | )
447 | if data_args.summary_column is None:
448 | summary_column = (
449 | dataset_columns[1] if dataset_columns is not None else column_names[1]
450 | )
451 | else:
452 | summary_column = data_args.summary_column
453 | if summary_column not in column_names:
454 | raise ValueError(
455 | f"`--summary_column` value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
456 | )
457 | 
458 | # Temporarily set max_target_length for training.
459 | max_target_length = data_args.max_target_length
460 | padding = "max_length" if data_args.pad_to_max_length else False
461 | 
462 | if training_args.label_smoothing_factor > 0 and not hasattr(
463 | model, "prepare_decoder_input_ids_from_labels"
464 | ):
465 | logger.warning(
466 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
467 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
468 | )
469 | 
470 | def preprocess_function(examples):
471 | inputs = examples[text_column]
472 | targets = examples[summary_column]
473 | inputs = [prefix + inp for inp in inputs]
474 | model_inputs = tokenizer(
475 | inputs,
476 | max_length=data_args.max_source_length,
477 | padding=padding,
478 | truncation=True,
479 | )
480 | 
481 | # Setup the tokenizer for targets
482 | with tokenizer.as_target_tokenizer():
483 | labels = tokenizer(
484 | targets, max_length=max_target_length, padding=padding, truncation=True
485 | )
486 | 
487 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
488 | # padding in the loss.
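# Worked example (assuming pad_token_id == 0): label ids [101, 42, 0, 0] become
# [101, 42, -100, -100]; -100 is the default ignore_index of PyTorch's
# CrossEntropyLoss, so padded positions contribute no loss.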
489 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 490 | labels["input_ids"] = [ 491 | [(l if l != tokenizer.pad_token_id else -100) for l in label] 492 | for label in labels["input_ids"] 493 | ] 494 | 495 | model_inputs["labels"] = labels["input_ids"] 496 | return model_inputs 497 | 498 | if training_args.do_train: 499 | if "train" not in datasets: 500 | raise ValueError("--do_train requires a train dataset") 501 | train_dataset = datasets["train"] 502 | if data_args.max_train_samples is not None: 503 | train_dataset = train_dataset.select(range(data_args.max_train_samples)) 504 | train_dataset = train_dataset.map( 505 | preprocess_function, 506 | batched=True, 507 | num_proc=data_args.preprocessing_num_workers, 508 | remove_columns=column_names, 509 | load_from_cache_file=not data_args.overwrite_cache, 510 | ) 511 | 512 | if training_args.do_eval: 513 | max_target_length = data_args.val_max_target_length 514 | if "validation" not in datasets: 515 | raise ValueError("--do_eval requires a validation dataset") 516 | eval_dataset = datasets["validation"] 517 | if data_args.max_eval_samples is not None: 518 | eval_dataset = eval_dataset.select(range(data_args.max_eval_samples)) 519 | eval_dataset = eval_dataset.map( 520 | preprocess_function, 521 | batched=True, 522 | num_proc=data_args.preprocessing_num_workers, 523 | remove_columns=column_names, 524 | load_from_cache_file=not data_args.overwrite_cache, 525 | ) 526 | 527 | if training_args.do_predict: 528 | max_target_length = data_args.val_max_target_length 529 | if "test" not in datasets: 530 | raise ValueError("--do_predict requires a test dataset") 531 | predict_dataset = datasets["test"] 532 | if data_args.max_predict_samples is not None: 533 | predict_dataset = predict_dataset.select( 534 | range(data_args.max_predict_samples) 535 | ) 536 | predict_dataset = predict_dataset.map( 537 | preprocess_function, 538 | batched=True, 539 | num_proc=data_args.preprocessing_num_workers, 540 | remove_columns=column_names, 541 | load_from_cache_file=not data_args.overwrite_cache, 542 | ) 543 | 544 | # Data collator 545 | label_pad_token_id = ( 546 | -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 547 | ) 548 | data_collator = DataCollatorForSeq2Seq( 549 | tokenizer, 550 | model=model, 551 | label_pad_token_id=label_pad_token_id, 552 | pad_to_multiple_of=8 if training_args.fp16 else None, 553 | ) 554 | 555 | # Initialize our Trainer 556 | trainer = Seq2SeqTrainer( 557 | model=model, 558 | args=training_args, 559 | train_dataset=train_dataset if training_args.do_train else None, 560 | eval_dataset=eval_dataset if training_args.do_eval else None, 561 | tokenizer=tokenizer, 562 | data_collator=data_collator, 563 | ) 564 | 565 | # Training 566 | if training_args.do_train: 567 | checkpoint = None 568 | if training_args.resume_from_checkpoint is not None: 569 | checkpoint = training_args.resume_from_checkpoint 570 | elif last_checkpoint is not None: 571 | checkpoint = last_checkpoint 572 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 573 | trainer.save_model() # Saves the tokenizer too for easy upload 574 | 575 | metrics = train_result.metrics 576 | max_train_samples = ( 577 | data_args.max_train_samples 578 | if data_args.max_train_samples is not None 579 | else len(train_dataset) 580 | ) 581 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 582 | 583 | trainer.log_metrics("train", metrics) 584 | trainer.save_metrics("train", metrics) 585 | trainer.save_state() 586 | 
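# (No compute_metrics is passed to the Seq2SeqTrainer above, so evaluate() below
# reports eval_loss rather than ROUGE; generated text is only written out when
# --predict_with_generate is set, further down.)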
587 | # Evaluation 588 | results = {} 589 | if training_args.do_eval: 590 | logger.info("*** Evaluate ***") 591 | 592 | metrics = trainer.evaluate( 593 | max_length=data_args.val_max_target_length, 594 | num_beams=data_args.num_beams, 595 | metric_key_prefix="eval", 596 | ) 597 | max_eval_samples = ( 598 | data_args.max_eval_samples 599 | if data_args.max_eval_samples is not None 600 | else len(eval_dataset) 601 | ) 602 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 603 | 604 | trainer.log_metrics("eval", metrics) 605 | trainer.save_metrics("eval", metrics) 606 | 607 | if training_args.do_predict: 608 | logger.info("*** Predict ***") 609 | 610 | predict_results = trainer.predict( 611 | predict_dataset, 612 | metric_key_prefix="predict", 613 | max_length=data_args.val_max_target_length, 614 | num_beams=data_args.num_beams, 615 | ) 616 | metrics = predict_results.metrics 617 | max_predict_samples = ( 618 | data_args.max_predict_samples 619 | if data_args.max_predict_samples is not None 620 | else len(predict_dataset) 621 | ) 622 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 623 | 624 | trainer.log_metrics("predict", metrics) 625 | trainer.save_metrics("predict", metrics) 626 | 627 | if trainer.is_world_process_zero(): 628 | if training_args.predict_with_generate: 629 | predictions = tokenizer.batch_decode( 630 | predict_results.predictions, 631 | skip_special_tokens=True, 632 | clean_up_tokenization_spaces=True, 633 | ) 634 | predictions = [pred.strip() for pred in predictions] 635 | output_prediction_file = os.path.join( 636 | training_args.output_dir, "generated_predictions.txt" 637 | ) 638 | with open(output_prediction_file, "w") as writer: 639 | writer.write("\n".join(predictions)) 640 | 641 | if training_args.push_to_hub: 642 | kwargs = { 643 | "finetuned_from": model_args.model_name_or_path, 644 | "tasks": "summarization", 645 | } 646 | if data_args.dataset_name is not None: 647 | kwargs["dataset_tags"] = data_args.dataset_name 648 | if data_args.dataset_config_name is not None: 649 | kwargs["dataset_args"] = data_args.dataset_config_name 650 | kwargs[ 651 | "dataset" 652 | ] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 653 | else: 654 | kwargs["dataset"] = data_args.dataset_name 655 | 656 | trainer.push_to_hub(**kwargs) 657 | 658 | return results 659 | 660 | 661 | def _mp_fn(index): 662 | # For xla_spawn (TPUs) 663 | main() 664 | 665 | 666 | def encode_to_line(x: str, y: str) -> str: 667 | # Refer to original transformers readme 668 | text = json.dumps(dict(text=x, summary=y)) + "\n" 669 | assert decode_from_line(text) == (x, y) 670 | return text 671 | 672 | 673 | def decode_from_line(text: str) -> Tuple[str, str]: 674 | d = json.loads(text) 675 | return d["text"], d["summary"] 676 | 677 | 678 | if __name__ == "__main__": 679 | main() 680 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os 4 | import shutil 5 | import time 6 | from collections import Counter 7 | from pathlib import Path 8 | from typing import Dict, List, Optional, Set, Tuple, Union 9 | 10 | import numpy as np 11 | import pandas as pd 12 | from fire import Fire 13 | from pydantic import BaseModel 14 | from pydantic.main import Extra 15 | from tqdm import tqdm 16 | 17 | Span = Tuple[int, int] 18 | BasicValue = Union[str, int, bool, float] 19 | 20 | 21 | def train_test_split(*args, 
**kwargs) -> list: 22 | raise NotImplementedError 23 | 24 | 25 | def find_sublist_index(items: list, query: list): 26 | length = len(query) 27 | for i in range(len(items) - length + 1): 28 | if items[i : i + length] == query: 29 | return i 30 | return -1 31 | 32 | 33 | def test_find_sublist_query(): 34 | items = [1, 6, 3, 5, 7] 35 | print(dict(items=items)) 36 | for query in [[6], [7], [6, 3], [3, 5, 7], [7, 5]]: 37 | print(dict(query=query, i=find_sublist_index(items, query))) 38 | 39 | 40 | def find_sublist_indices(items: list, query: list) -> List[int]: 41 | i = find_sublist_index(items, query) 42 | if i == -1: 43 | return [] 44 | return list(range(i, i + len(query))) 45 | 46 | 47 | def test_find_sublist_indices(): 48 | items = [1, 6, 3, 5, 7] 49 | assert find_sublist_indices(items, [6, 3, 5]) == [1, 2, 3] 50 | print(dict(test_find_sublist_indices=True)) 51 | 52 | 53 | class WikiProperty(BaseModel): 54 | """ 55 | # https://query.wikidata.org 56 | # All properties with descriptions and aliases and types 57 | 58 | SELECT ?p ?pType ?pLabel ?pDescription ?pAltLabel WHERE { 59 | ?p wikibase:propertyType ?pType . 60 | SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". } 61 | } 62 | ORDER BY ASC(xsd:integer(STRAFTER(STR(?p), 'P'))) 63 | """ 64 | 65 | p: str 66 | pType: str 67 | pLabel: str 68 | pDescription: str 69 | pAltLabel: str 70 | 71 | @property 72 | def id(self) -> str: 73 | return self.p.split("/")[-1] 74 | 75 | @property 76 | def aliases(self) -> List[str]: 77 | names = [n.strip() for n in self.pAltLabel.split(",")] 78 | return sorted(set(names)) 79 | 80 | 81 | def load_wiki_relation_map(path: str) -> Dict[str, WikiProperty]: 82 | df = pd.read_csv(path) 83 | props = [WikiProperty(**r) for r in df.to_dict(orient="records")] 84 | return {p.id: p for p in props} 85 | 86 | 87 | def load_label_to_properties( 88 | path: str, use_alias: bool = True 89 | ) -> Dict[str, WikiProperty]: 90 | relation_map = load_wiki_relation_map(path) 91 | mapping = {} 92 | for p in relation_map.values(): 93 | if not p.pLabel in mapping.keys(): 94 | mapping[p.pLabel] = p 95 | if use_alias: 96 | for p in relation_map.values(): 97 | for a in p.aliases: 98 | if a not in mapping.keys(): 99 | mapping[a] = p 100 | return mapping 101 | 102 | 103 | def test_load_wiki(): 104 | relation_map = load_wiki_relation_map("data/wiki_properties.csv") 105 | for k, v in list(relation_map.items())[:3]: 106 | print(dict(k=k, v=v, aliases=v.aliases)) 107 | 108 | 109 | class DynamicModel(BaseModel): 110 | class Config: 111 | arbitrary_types_allowed = True 112 | validate_assignment = True 113 | 114 | 115 | class StrictModel(BaseModel): 116 | class Config: 117 | extra = Extra.forbid 118 | frozen = True 119 | validate_assignment = True 120 | 121 | 122 | def compute_macro_PRF( 123 | predicted_idx: np.ndarray, gold_idx: np.ndarray, i=-1, empty_label=None 124 | ) -> Tuple[float, float, float]: 125 | # https://github.com/dinobby/ZS-BERT/blob/master/model/evaluation.py 126 | """ 127 | This evaluation function follows work from Sorokin and Gurevych(https://www.aclweb.org/anthology/D17-1188.pdf) 128 | code borrowed from the following link: 129 | https://github.com/UKPLab/emnlp2017-relation-extraction/blob/master/relation_extraction/evaluation/metrics.py 130 | """ 131 | if i == -1: 132 | i = len(predicted_idx) 133 | 134 | complete_rel_set = set(gold_idx) - {empty_label} 135 | avg_prec = 0.0 136 | avg_rec = 0.0 137 | 138 | for r in complete_rel_set: 139 | r_indices = predicted_idx[:i] == r 140 | tp = 
len((predicted_idx[:i][r_indices] == gold_idx[:i][r_indices]).nonzero()[0]) 141 | tp_fp = len(r_indices.nonzero()[0]) 142 | tp_fn = len((gold_idx == r).nonzero()[0]) 143 | prec = (tp / tp_fp) if tp_fp > 0 else 0 144 | rec = tp / tp_fn 145 | avg_prec += prec 146 | avg_rec += rec 147 | f1 = 0.0 148 | avg_prec = avg_prec / len(set(predicted_idx[:i])) 149 | avg_rec = avg_rec / len(complete_rel_set) 150 | if (avg_rec + avg_prec) > 0: 151 | f1 = 2.0 * avg_prec * avg_rec / (avg_prec + avg_rec) 152 | 153 | return avg_prec, avg_rec, f1 154 | 155 | 156 | def test_compute_prf(): 157 | a = np.array([0, 0, 0, 0, 0]) 158 | b = np.array([0, 0, 1, 1, 0]) 159 | print(compute_macro_PRF(a, b)) 160 | 161 | 162 | def glob_rmtree(folder: str, pattern: str, verbose=True): 163 | for path in Path(folder).glob(pattern): 164 | shutil.rmtree(path) 165 | if verbose: 166 | print(dict(rmtree=path)) 167 | 168 | 169 | def test_glob_rmtree(): 170 | folder = "tmp/test_glob_rmtree" 171 | Path(folder).mkdir(exist_ok=False, parents=True) 172 | glob_rmtree("tmp", "test_glob*") 173 | 174 | 175 | def hash_text(x: str) -> str: 176 | return hashlib.md5(x.encode()).hexdigest() 177 | 178 | 179 | def check_overlap(a: Span, b: Span) -> bool: 180 | # Assumes end in (start, end) is exclusive like python slicing 181 | return ( 182 | a[0] <= b[0] < a[1] 183 | or a[0] <= b[1] - 1 < a[1] 184 | or b[0] <= a[0] < b[1] 185 | or b[0] <= a[1] - 1 < b[1] 186 | ) 187 | 188 | 189 | class RelationSentence(BaseModel): 190 | tokens: List[str] 191 | head: List[int] 192 | tail: List[int] 193 | label: str 194 | head_id: str = "" 195 | tail_id: str = "" 196 | label_id: str = "" 197 | error: str = "" 198 | raw: str = "" 199 | score: float = 0.0 200 | zerorc_included: bool = True 201 | 202 | def as_tuple(self) -> Tuple[str, str, str]: 203 | head = " ".join([self.tokens[i] for i in self.head]) 204 | tail = " ".join([self.tokens[i] for i in self.tail]) 205 | return head, self.label, tail 206 | 207 | def as_line(self) -> str: 208 | return self.json() + "\n" 209 | 210 | def is_valid(self) -> bool: 211 | for x in [self.tokens, self.head, self.tail, self.label]: 212 | if len(x) == 0: 213 | return False 214 | for x in [self.head, self.tail]: 215 | if -1 in x: 216 | return False 217 | return True 218 | 219 | @property 220 | def text(self) -> str: 221 | return " ".join(self.tokens) 222 | 223 | @classmethod 224 | def from_spans(cls, text: str, head: str, tail: str, label: str, strict=True): 225 | tokens = text.split() 226 | sent = cls( 227 | tokens=tokens, 228 | head=find_span(head, tokens), 229 | tail=find_span(tail, tokens), 230 | label=label, 231 | ) 232 | if strict: 233 | assert sent.is_valid(), (head, label, tail, text) 234 | return sent 235 | 236 | def as_marked_text(self) -> str: 237 | tokens = list(self.tokens) 238 | for i, template in [ 239 | (self.head[0], "[H {}"), 240 | (self.head[-1], "{} ]"), 241 | (self.tail[0], "[T {}"), 242 | (self.tail[-1], "{} ]"), 243 | ]: 244 | tokens[i] = template.format(tokens[i]) 245 | return " ".join(tokens) 246 | 247 | 248 | def align_span_to_tokens(span: str, tokens: List[str]) -> Tuple[int, int]: 249 | # Eg align("John R. 
Allen, Jr.", ['John', 'R.', 'Allen', ',', 'Jr.']) 250 | char_word_map = {} 251 | num_chars = 0 252 | for i, w in enumerate(tokens): 253 | for _ in w: 254 | char_word_map[num_chars] = i 255 | num_chars += 1 256 | char_word_map[num_chars] = len(tokens) 257 | 258 | query = span.replace(" ", "") 259 | text = "".join(tokens) 260 | assert query in text 261 | i = text.find(query) 262 | start = char_word_map[i] 263 | end = char_word_map[i + len(query) - 1] 264 | assert 0 <= start <= end 265 | return start, end + 1 266 | 267 | 268 | def test_align_span( 269 | span: str = "John R. Allen, Jr.", 270 | tokens=("The", "John", "R.", "Allen", ",", "Jr.", "is", "here"), 271 | ): 272 | start, end = align_span_to_tokens(span, tokens) 273 | print(dict(start=start, end=end, span=tokens[start:end])) 274 | 275 | 276 | def find_span(span: str, tokens: List[str]) -> List[int]: 277 | if span == "": 278 | return [] 279 | start = find_sublist_index(tokens, span.split()) 280 | if start >= 0: 281 | return [start + i for i in range(len(span.split()))] 282 | else: 283 | start, end = align_span_to_tokens(span, tokens) 284 | return list(range(start, end)) 285 | 286 | 287 | def test_find_span( 288 | span: str = "Hohenzollern", 289 | text: str = "Princess of Hohenzollern-Sigmaringen ( born 26 March 1949", 290 | ): 291 | tokens = text.split() 292 | indices = find_span(span, tokens) 293 | print(dict(test_find_span=[tokens[i] for i in indices])) 294 | 295 | 296 | class QualifierSentence(RelationSentence): 297 | qualifier: str = "" 298 | qualifier_id: str 299 | value: List[int] 300 | value_type: str 301 | 302 | def as_tuple(self) -> Tuple[str, str, str, str, str]: 303 | head = " ".join([self.tokens[i] for i in self.head]) 304 | tail = " ".join([self.tokens[i] for i in self.tail]) 305 | value = " ".join([self.tokens[i] for i in self.value]) 306 | return head, self.label, tail, self.qualifier, value 307 | 308 | 309 | class RelationData(BaseModel): 310 | sents: List[RelationSentence] 311 | 312 | @classmethod 313 | def load(cls, path: Path): 314 | with open(path) as f: 315 | lines = f.readlines() 316 | sents = [ 317 | RelationSentence(**json.loads(x)) 318 | for x in tqdm(lines, desc="RelationData.load") 319 | ] 320 | return cls(sents=sents) 321 | 322 | def save(self, path: Path): 323 | path.parent.mkdir(exist_ok=True, parents=True) 324 | with open(path, "w") as f: 325 | f.write("".join([s.as_line() for s in self.sents])) 326 | 327 | @property 328 | def unique_labels(self) -> List[str]: 329 | return sorted(set([s.label for s in self.sents])) 330 | 331 | def train_test_split( 332 | self, test_size: Union[int, float], random_seed: int, by_label: bool = False 333 | ): 334 | if by_label: 335 | labels_train, labels_test = train_test_split( 336 | self.unique_labels, test_size=test_size, random_state=random_seed 337 | ) 338 | train = [s for s in self.sents if s.label in labels_train] 339 | test = [s for s in self.sents if s.label in labels_test] 340 | else: 341 | groups = self.to_sentence_groups() 342 | keys_train, keys_test = train_test_split( 343 | sorted(groups.keys()), test_size=test_size, random_state=random_seed 344 | ) 345 | train = [s for k in keys_train for s in groups[k]] 346 | test = [s for k in keys_test for s in groups[k]] 347 | 348 | # Enforce no sentence overlap 349 | texts_test = set([s.text for s in test]) 350 | train = [s for s in train if s.text not in texts_test] 351 | 352 | data_train = RelationData(sents=train) 353 | data_test = RelationData(sents=test) 354 | if by_label: 355 | assert len(data_test.unique_labels) == 
test_size 356 | assert not set(data_train.unique_labels).intersection( 357 | data_test.unique_labels 358 | ) 359 | 360 | info = dict( 361 | sents_train=len(data_train.sents), 362 | sents_test=len(data_test.sents), 363 | labels_train=len(data_train.unique_labels), 364 | labels_test=len(data_test.unique_labels), 365 | ) 366 | print(json.dumps(info, indent=2)) 367 | return data_train, data_test 368 | 369 | def to_sentence_groups(self) -> Dict[str, List[RelationSentence]]: 370 | groups = {} 371 | for s in self.sents: 372 | groups.setdefault(s.text, []).append(s) 373 | return groups 374 | 375 | def to_label_groups(self) -> Dict[str, List[RelationSentence]]: 376 | groups = {} 377 | for s in self.sents: 378 | groups.setdefault(s.label, []).append(s) 379 | return groups 380 | 381 | def filter_group_sizes(self, min_size: int = 0, max_size: int = 999): 382 | groups = self.to_sentence_groups() 383 | sents = [ 384 | s 385 | for k, lst in groups.items() 386 | for s in lst 387 | if min_size <= len(lst) <= max_size 388 | ] 389 | return RelationData(sents=sents) 390 | 391 | def filter_errors(self): 392 | def check_valid_span(span: List[int]) -> bool: 393 | start = sorted(span)[0] 394 | end = sorted(span)[-1] + 1 395 | return span == list(range(start, end)) 396 | 397 | sents = [] 398 | for s in self.sents: 399 | if s.is_valid(): 400 | if check_valid_span(s.head) and check_valid_span(s.tail): 401 | sents.append(s) 402 | 403 | print(dict(filter_errors_success=len(sents) / len(self.sents))) 404 | return RelationData(sents=sents) 405 | 406 | def analyze(self, header: Optional[str] = None): 407 | labels = self.unique_labels 408 | groups = self.to_sentence_groups() 409 | spans = [] 410 | words = [] 411 | for s in self.sents: 412 | head, label, tail = s.as_tuple() 413 | spans.append(head) 414 | spans.append(tail) 415 | words.extend(s.tokens) 416 | info = dict( 417 | header=header, 418 | sents=len(self.sents), 419 | labels=str([len(labels), labels]), 420 | unique_texts=len(groups.keys()), 421 | unique_spans=len(set(spans)), 422 | unique_words=len(set(words)), 423 | group_sizes=str(Counter([len(lst) for lst in groups.values()])), 424 | ) 425 | print(json.dumps(info, indent=2)) 426 | return info 427 | 428 | 429 | def wiki_uri_to_id(uri: str) -> str: 430 | i = uri.split("/")[-1] 431 | if i[0] in "QP" and i[1:].isdigit(): 432 | return i 433 | else: 434 | return "" 435 | 436 | 437 | def split_common_prefix(texts: List[str]) -> Tuple[str, List[str]]: 438 | end = 0 439 | i_max = min(map(len, texts)) 440 | for i in range(i_max): 441 | if len(set([t[i] for t in texts])) > 1: 442 | break 443 | end += 1 444 | 445 | prefix = texts[0][:end] 446 | texts = [t[end:] for t in texts] 447 | return prefix, texts 448 | 449 | 450 | def delete_checkpoints( 451 | folder: str = ".", pattern="**/checkpoint*", delete: bool = True 452 | ): 453 | for p in Path(folder).glob(pattern): 454 | if (p.parent / "config.json").exists(): 455 | print(p) 456 | if delete: 457 | if p.is_dir(): 458 | shutil.rmtree(p) 459 | elif p.is_file(): 460 | os.remove(p) 461 | else: 462 | raise ValueError("Unknown Type") 463 | 464 | 465 | class Timer(BaseModel): 466 | name: str 467 | start: float = 0 468 | 469 | def __enter__(self): 470 | self.start = time.time() 471 | 472 | def __exit__(self, exc_type, exc_val, exc_tb): 473 | duration = round(time.time() - self.start, 3) 474 | print(dict(name=self.name, duration=duration)) 475 | 476 | 477 | def test_timer(interval: int = 2): 478 | with Timer(name="test_timer"): 479 | time.sleep(interval) 480 | 481 | 482 | def 
sorted_glob(folder: str, pattern: str) -> List[Path]: 483 | # Best practice to be deterministic and avoid weird behavior 484 | return sorted(Path(folder).glob(pattern)) 485 | 486 | 487 | def test_sorted_glob(): 488 | for path in sorted_glob("outputs/data/zsl/wiki", "*/test.jsonl"): 489 | print(path) 490 | 491 | 492 | def mark_wiki_entity(edge): 493 | e1 = edge["left"] 494 | e2 = edge["right"] 495 | return e1, e2 496 | 497 | 498 | def mark_fewrel_entity(edge): 499 | e1 = edge["h"][2][0] 500 | e2 = edge["t"][2][0] 501 | return e1, e2 502 | 503 | 504 | class WikiDataset: 505 | def __init__(self, mode, data, pid2vec, property2idx): 506 | assert mode in ["train", "dev", "test"] 507 | self.mode = mode 508 | self.data = data 509 | self.pid2vec = pid2vec 510 | self.property2idx = property2idx 511 | self.len = len(self.data) 512 | 513 | def load_edges( 514 | self, i: int, label_ids: Optional[Set[str]] = None 515 | ) -> List[RelationSentence]: 516 | g = self.data[i] 517 | tokens = g["tokens"] 518 | sents = [] 519 | for j in range(len(g["edgeSet"])): 520 | property_id = g["edgeSet"][j]["kbID"] 521 | edge = g["edgeSet"][j] 522 | head, tail = mark_wiki_entity(edge) 523 | if label_ids and property_id not in label_ids: 524 | continue 525 | s = RelationSentence( 526 | tokens=tokens, head=head, tail=tail, label="", label_id=property_id 527 | ) 528 | sents.append(s) 529 | return sents 530 | 531 | def __getitem__(self, item: int) -> RelationSentence: 532 | # The ZS-BERT setting is throw away all except first edge 533 | return self.load_edges(item)[0] 534 | 535 | def __len__(self): 536 | return self.len 537 | 538 | 539 | if __name__ == "__main__": 540 | """ 541 | python new_utils.py analyze_relation_data --path data/relations/trex/100000.jsonl 542 | """ 543 | test_find_sublist_query() 544 | test_load_wiki() 545 | test_compute_prf() 546 | test_glob_rmtree() 547 | test_find_sublist_indices() 548 | Fire() 549 | -------------------------------------------------------------------------------- /wrapper.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | from collections import Counter 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import torch 8 | from fire import Fire 9 | from pydantic.main import BaseModel 10 | from tqdm import tqdm 11 | 12 | from generation import LabelConstraint, TripletSearchDecoder 13 | from modeling import (NewRelationExtractor, RelationGenerator, RelationModel, 14 | select_model) 15 | from utils import (RelationSentence, WikiDataset, delete_checkpoints, 16 | load_wiki_relation_map, mark_fewrel_entity) 17 | 18 | 19 | def safe_divide(a: float, b: float) -> float: 20 | if a == 0 or b == 0: 21 | return 0 22 | return a / b 23 | 24 | 25 | class Sentence(BaseModel): 26 | triplets: List[RelationSentence] 27 | 28 | @property 29 | def tokens(self) -> List[str]: 30 | return self.triplets[0].tokens 31 | 32 | @property 33 | def text(self) -> str: 34 | return " ".join(self.tokens) 35 | 36 | def assert_valid(self): 37 | assert len(self.tokens) > 0 38 | for t in self.triplets: 39 | assert t.text == self.text 40 | assert len(t.head) > 0 41 | assert len(t.tail) > 0 42 | assert len(t.label) > 0 43 | 44 | 45 | class Dataset(BaseModel): 46 | sents: List[Sentence] 47 | 48 | def get_labels(self) -> List[str]: 49 | return sorted(set(t.label for s in self.sents for t in s.triplets)) 50 | 51 | @classmethod 52 | def load(cls, path: str): 53 | with open(path) as f: 54 | sents = [Sentence(**json.loads(line)) for line in f] 55 | 
45 | class Dataset(BaseModel):
46 |     sents: List[Sentence]
47 | 
48 |     def get_labels(self) -> List[str]:
49 |         return sorted(set(t.label for s in self.sents for t in s.triplets))
50 | 
51 |     @classmethod
52 |     def load(cls, path: str):
53 |         with open(path) as f:
54 |             sents = [Sentence(**json.loads(line)) for line in f]
55 |         return cls(sents=sents)
56 | 
57 |     def save(self, path: str):
58 |         Path(path).parent.mkdir(exist_ok=True, parents=True)
59 |         with open(path, "w") as f:
60 |             for s in self.sents:
61 |                 f.write(s.json() + "\n")
62 | 
63 |     @classmethod
64 |     def load_fewrel(cls, path: str, path_properties: str = "data/wiki_properties.csv"):
65 |         relation_map = load_wiki_relation_map(path_properties)
66 |         groups = {}
67 | 
68 |         with open(path) as f:
69 |             for i, lst in tqdm(json.load(f).items()):
70 |                 for raw in lst:
71 |                     head, tail = mark_fewrel_entity(raw)
72 |                     t = RelationSentence(
73 |                         tokens=raw["tokens"],
74 |                         head=head,
75 |                         tail=tail,
76 |                         label=relation_map[i].pLabel,
77 |                         label_id=i,
78 |                     )
79 |                     groups.setdefault(t.text, []).append(t)
80 | 
81 |         sents = [Sentence(triplets=lst) for lst in groups.values()]
82 |         return cls(sents=sents)
83 | 
84 |     @classmethod
85 |     def load_wiki(cls, path: str, path_properties: str = "data/wiki_properties.csv"):
86 |         relation_map = load_wiki_relation_map(path_properties)
87 |         sents = []
88 |         with open(path) as f:
89 |             ds = WikiDataset(
90 |                 mode="train", data=json.load(f), pid2vec=None, property2idx=None
91 |             )
92 |             for i in tqdm(range(len(ds))):
93 |                 triplets = ds.load_edges(i)
94 |                 triplets = [t for t in triplets if t.label_id in relation_map.keys()]
95 |                 for t in triplets:
96 |                     t.label = relation_map[t.label_id].pLabel
97 |                 if triplets:
98 |                     # ZS-BERT only includes the first triplet of each sentence
99 |                     for t in triplets:
100 |                         t.zerorc_included = False
101 |                     triplets[0].zerorc_included = True
102 | 
103 |                     s = Sentence(triplets=triplets)
104 |                     sents.append(s)
105 | 
106 |         data = cls(sents=sents)
107 |         counter = Counter(t.label for s in data.sents for t in s.triplets)
108 |         threshold = sorted(counter.values())[-113]  # Keep the 113 most frequent labels, per ZS-BERT data stats
109 |         labels = [k for k, v in counter.items() if v >= threshold]
110 |         data = data.filter_labels(labels)
111 |         return data
112 | 
113 |     def filter_labels(self, labels: List[str]):
114 |         label_set = set(labels)
115 |         sents = []
116 |         for s in self.sents:
117 |             triplets = [t for t in s.triplets if t.label in label_set]
118 |             if triplets:
119 |                 s = s.copy(deep=True)
120 |                 s.triplets = triplets
121 |                 sents.append(s)
122 |         return Dataset(sents=sents)
123 | 
124 |     def train_test_split(self, test_size: int, random_seed: int, by_label: bool):
125 |         random.seed(random_seed)
126 | 
127 |         if by_label:
128 |             labels = self.get_labels()
129 |             labels_test = random.sample(labels, k=test_size)
130 |             labels_train = sorted(set(labels) - set(labels_test))
131 |             sents_train = self.filter_labels(labels_train).sents
132 |             sents_test = self.filter_labels(labels_test).sents
133 |         else:
134 |             sents_train = [s for s in self.sents]
135 |             sents_test = random.sample(self.sents, k=test_size)
136 | 
137 |         banned = set(s.text for s in sents_test)  # Prevent sentence overlap between splits
138 |         sents_train = [s for s in sents_train if s.text not in banned]
139 |         assert len(self.sents) == len(sents_train) + len(sents_test)
140 |         return Dataset(sents=sents_train), Dataset(sents=sents_test)
141 | 
142 |     def analyze(self):
143 |         info = dict(
144 |             sents=len(self.sents),
145 |             unique_texts=len(set(s.triplets[0].text for s in self.sents)),
146 |             lengths=str(Counter(len(s.triplets) for s in self.sents)),
147 |             labels=len(self.get_labels()),
148 |         )
149 |         print(json.dumps(info, indent=2))
150 | 
151 | 
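With `by_label=True`, `train_test_split` samples whole relation labels for the test side, so the train and test label sets are disjoint, which is what makes the splits zero-shot. A minimal usage sketch; the path is a placeholder:

```
data = Dataset.load("outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/train.jsonl")
train, test = data.train_test_split(test_size=5, random_seed=0, by_label=True)
assert not set(train.get_labels()) & set(test.get_labels())
```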
152 | def write_data_splits(
153 |     path_in: str,
154 |     mode: str,
155 |     folder_out: str = "outputs/data/splits/zero_rte",
156 |     num_dev_labels: int = 5,
157 |     num_test_labels: List[int] = [5, 10, 15],
158 |     seeds: List[int] = [0, 1, 2, 3, 4],
159 | ):
160 |     for n in num_test_labels:
161 |         for s in seeds:
162 |             if mode == "fewrel":
163 |                 data = Dataset.load_fewrel(path_in)
164 |             elif mode == "wiki":
165 |                 data = Dataset.load_wiki(path_in)
166 |             else:
167 |                 raise ValueError(f"Unknown mode: {mode}")
168 | 
169 |             train, test = data.train_test_split(
170 |                 test_size=n, random_seed=s, by_label=True
171 |             )
172 |             train, dev = train.train_test_split(
173 |                 test_size=num_dev_labels, random_seed=s, by_label=True
174 |             )
175 |             del data
176 | 
177 |             for key, data in dict(train=train, dev=dev, test=test).items():
178 |                 name = f"unseen_{n}_seed_{s}"
179 |                 path = Path(folder_out) / Path(path_in).stem / name / f"{key}.jsonl"
180 |                 data.save(str(path))
181 |                 print(dict(key=key, labels=len(data.get_labels()), path=path))
182 | 
183 | 
184 | class Generator(BaseModel):
185 |     load_dir: str
186 |     save_dir: str
187 |     num_gen_per_label: int = 250
188 |     model_name: str = "generate"
189 |     encoder_name: str = "generate"
190 |     model_kwargs: dict = {}
191 | 
192 |     def get_model(self) -> RelationModel:
193 |         model = select_model(
194 |             name=self.model_name,
195 |             encoder_name=self.encoder_name,
196 |             model_dir=str(Path(self.save_dir) / "model"),
197 |             model_name=self.load_dir,
198 |             data_dir=str(Path(self.save_dir) / "data"),
199 |             do_pretrain=False,
200 |             **self.model_kwargs,
201 |         )
202 |         return model
203 | 
204 |     def write_data(self, data: Dataset, name: str) -> str:
205 |         model = self.get_model()
206 |         path_out = Path(model.data_dir) / f"{name}.txt"
207 |         path_out.parent.mkdir(exist_ok=True, parents=True)
208 |         encoder = model.get_encoder()
209 |         lines = [encoder.encode_to_line(t) for s in data.sents for t in s.triplets]
210 |         random.seed(model.random_seed)
211 |         random.shuffle(lines)
212 |         with open(path_out, "w") as f:
213 |             f.write("".join(lines))
214 |         return str(path_out)
215 | 
216 |     def fit(self, path_train: str, path_dev: str):
217 |         model = self.get_model()
218 |         if Path(model.model_dir).exists():
219 |             return
220 | 
221 |         data_train = Dataset.load(path_train)
222 |         data_dev = Dataset.load(path_dev)
223 |         path_train = self.write_data(data_train, "train")
224 |         path_dev = self.write_data(data_dev, "dev")
225 |         model.fit(path_train=path_train, path_dev=path_dev)
226 |         delete_checkpoints(model.model_dir)
227 | 
228 |     def generate(self, labels: List[str], path_out: str):
229 |         if Path(path_out).exists():
230 |             return
231 | 
232 |         model = self.get_model()
233 |         pipe = model.make_pipe()
234 |         groups = {}
235 |         assert isinstance(model, RelationGenerator)
236 |         for relation in tqdm(labels):
237 |             triplets, raw = model.generate(relation, self.num_gen_per_label, pipe=pipe)
238 |             for t in triplets:
239 |                 groups.setdefault(t.text, []).append(t)
240 | 
241 |         sents = [Sentence(triplets=lst) for lst in groups.values()]
242 |         data = Dataset(sents=sents)
243 |         data.save(path_out)
244 | 
245 | 
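Because the module ends with a bare `Fire()`, `write_data_splits` above is also callable from the command line, in the same style as the CLI examples at the end of this file. A plausible invocation, where the raw input path is a placeholder for the original FewRel JSON file:

```
python wrapper.py write_data_splits \
    --path_in outputs/data/fewrel_all.json \
    --mode fewrel
```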
246 | class Extractor(BaseModel):
247 |     load_dir: str
248 |     save_dir: str
249 |     model_name: str = "new_extract"
250 |     encoder_name: str = "extract"
251 |     search_threshold: float = -0.9906
252 |     model_kwargs: dict = {}
253 | 
254 |     def get_model(self) -> RelationModel:
255 |         model = select_model(
256 |             name=self.model_name,
257 |             encoder_name=self.encoder_name,
258 |             model_dir=str(Path(self.save_dir) / "model"),
259 |             model_name=self.load_dir,
260 |             data_dir=str(Path(self.save_dir) / "data"),
261 |             do_pretrain=False,
262 |             **self.model_kwargs,
263 |         )
264 |         return model
265 | 
266 |     def write_data(self, data: Dataset, name: str) -> str:
267 |         model = self.get_model()
268 |         path_out = Path(model.data_dir) / f"{name}.json"
269 |         path_out.parent.mkdir(exist_ok=True, parents=True)
270 |         encoder = model.get_encoder()
271 |         lines = [encoder.encode_to_line(t) for s in data.sents for t in s.triplets]
272 |         random.seed(model.random_seed)
273 |         random.shuffle(lines)
274 |         with open(path_out, "w") as f:
275 |             f.write("".join(lines))
276 |         return str(path_out)
277 | 
278 |     def fit(self, path_train: str, path_dev: str):
279 |         model = self.get_model()
280 |         if Path(model.model_dir).exists():
281 |             return
282 | 
283 |         data_train = Dataset.load(path_train)
284 |         data_dev = Dataset.load(path_dev)
285 |         path_train = self.write_data(data_train, "train")
286 |         path_dev = self.write_data(data_dev, "dev")
287 |         model.fit(path_train=path_train, path_dev=path_dev)
288 |         delete_checkpoints(model.model_dir)
289 | 
290 |     # Single-triplet inference; predict_multi below decodes multiple triplets per sentence
291 |     def predict(self, path_in: str, path_out: str, use_label_constraint: bool = True):
292 |         data = Dataset.load(path_in)
293 |         texts = [s.text for s in data.sents]
294 |         model = self.get_model()
295 |         assert isinstance(model, NewRelationExtractor)
296 |         gen = model.load_generator(torch.device("cuda"))
297 |         encoder = model.get_encoder()
298 |         constraint = LabelConstraint(labels=data.get_labels(), tokenizer=gen.tokenizer)
299 |         sents = []
300 | 
301 |         for i in tqdm(range(0, len(texts), model.batch_size)):
302 |             batch = texts[i : i + model.batch_size]
303 |             x = [encoder.encode_x(t) for t in batch]
304 |             outputs = model.gen_texts(
305 |                 x, gen, num_beams=1, save_scores=use_label_constraint
306 |             )
307 |             assert len(outputs) == len(x)
308 | 
309 |             for j, raw in enumerate(outputs):
310 |                 triplet = encoder.safe_decode(x[j], y=raw)
311 |                 if use_label_constraint:
312 |                     assert gen.scores is not None
313 |                     triplet = constraint.run(triplet, gen.scores[j])
314 |                 sents.append(Sentence(triplets=[triplet]))
315 | 
316 |         Dataset(sents=sents).save(path_out)
317 | 
318 |     def predict_multi(self, path_in: str, path_out: str):
319 |         stem = Path(path_out).stem
320 |         path_raw = path_out.replace(stem, f"{stem}_raw")
321 |         print(dict(predict_multi=locals()))
322 |         data = Dataset.load(path_in)
323 |         model = self.get_model()
324 |         assert isinstance(model, NewRelationExtractor)
325 |         gen = model.load_generator(torch.device("cuda"))
326 |         constraint = LabelConstraint(labels=data.get_labels(), tokenizer=gen.tokenizer)
327 |         searcher = TripletSearchDecoder(
328 |             gen=gen, encoder=model.get_encoder(), constraint=constraint
329 |         )
330 | 
331 |         sents = [
332 |             Sentence(triplets=searcher.run(s.text))
333 |             for s in tqdm(data.sents)
334 |         ]
335 |         Dataset(sents=sents).save(path_raw)
336 |         for s in sents:
337 |             s.triplets = [t for t in s.triplets if t.score > self.search_threshold]
338 |         Dataset(sents=sents).save(path_out)
339 | 
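To make the micro-averaged scoring in `score` below concrete, here is a worked example with hypothetical counts. Suppose the gold file holds 4 triplets, the prediction file holds 5, and 2 predictions match a gold triplet exactly on `(head, tail, label)`:

```
precision = 2 / 5 = 0.4
recall    = 2 / 4 = 0.5
f1        = 2 * 0.4 * 0.5 / (0.4 + 0.5) ≈ 0.444
```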
340 |     @staticmethod
341 |     def score(path_pred: str, path_gold: str) -> dict:
342 |         pred = Dataset.load(path_pred)
343 |         gold = Dataset.load(path_gold)
344 |         assert len(pred.sents) == len(gold.sents)
345 |         num_pred = 0
346 |         num_gold = 0
347 |         num_correct = 0
348 | 
349 |         for i in range(len(gold.sents)):
350 |             num_pred += len(pred.sents[i].triplets)
351 |             num_gold += len(gold.sents[i].triplets)
352 |             for p in pred.sents[i].triplets:
353 |                 for g in gold.sents[i].triplets:
354 |                     if (p.head, p.tail, p.label) == (g.head, g.tail, g.label):
355 |                         num_correct += 1
356 | 
357 |         precision = safe_divide(num_correct, num_pred)
358 |         recall = safe_divide(num_correct, num_gold)
359 | 
360 |         info = dict(
361 |             path_pred=path_pred,
362 |             path_gold=path_gold,
363 |             precision=precision,
364 |             recall=recall,
365 |             score=safe_divide(2 * precision * recall, precision + recall),
366 |         )
367 |         return info
368 | 
369 | 
370 | def main(
371 |     path_train: str,
372 |     path_dev: str,
373 |     path_test: str,
374 |     save_dir: str,
375 | ):
376 |     print(dict(main=locals()))
377 |     generator = Generator(
378 |         load_dir="gpt2",
379 |         save_dir=str(Path(save_dir) / "generator"),
380 |     )
381 |     extractor = Extractor(
382 |         load_dir="facebook/bart-base",
383 |         save_dir=str(Path(save_dir) / "extractor"),
384 |     )
385 | 
386 |     generator.fit(path_train, path_dev)
387 |     extractor.fit(path_train, path_dev)
388 |     path_synthetic = str(Path(save_dir) / "synthetic.jsonl")
389 |     labels_dev = Dataset.load(path_dev).get_labels()
390 |     labels_test = Dataset.load(path_test).get_labels()
391 |     generator.generate(labels_dev + labels_test, path_out=path_synthetic)
392 | 
393 |     extractor_final = Extractor(
394 |         load_dir=str(Path(save_dir) / "extractor" / "model"),
395 |         save_dir=str(Path(save_dir) / "extractor_final"),
396 |     )
397 |     extractor_final.fit(path_synthetic, path_dev)
398 | 
399 |     path_pred = str(Path(save_dir) / "pred.jsonl")
400 |     extractor_final.predict(path_in=path_test, path_out=path_pred)
401 |     results = extractor_final.score(path_pred, path_test)
402 |     print(json.dumps(results, indent=2))
403 |     with open(Path(save_dir) / "results.json", "w") as f:
404 |         json.dump(results, f, indent=2)
405 |     return results
406 | 
407 | 
408 | def main_many(data_dir_pattern: str, save_dir: str, **kwargs):
409 |     mode = Path(save_dir).name
410 |     assert mode in ["fewrel", "wiki"]
411 |     records = []
412 | 
413 |     for path in tqdm(sorted(Path().glob(data_dir_pattern))):
414 |         path_train = path / "train.jsonl"
415 |         path_dev = path / "dev.jsonl"
416 |         path_test = path / "test.jsonl"
417 |         results = main(
418 |             path_train=str(path_train),
419 |             path_dev=str(path_dev),
420 |             path_test=str(path_test),
421 |             save_dir=str(Path(save_dir) / path.name),
422 |             **kwargs,
423 |         )
424 |         records.append(results)
425 | 
426 |     avg_p = sum([r["precision"] for r in records]) / len(records)
427 |     avg_r = sum([r["recall"] for r in records]) / len(records)
428 |     avg_f = safe_divide(2 * avg_p * avg_r, avg_p + avg_r)
429 |     info = dict(avg_p=avg_p, avg_r=avg_r, avg_f=avg_f)
430 |     print(json.dumps(info, indent=2))
431 | 
432 | 
433 | def run_eval(path_model: str, path_test: str, mode: str, limit: int = 0):
434 |     print(dict(run_eval=locals()))
435 |     data = Dataset.load(path_test)
436 |     model = Extractor(load_dir=str(Path(path_model) / "model"), save_dir=path_model)
437 | 
438 |     if mode == "single":
439 |         data.sents = [s for s in data.sents if len(s.triplets) == 1]
440 |     elif mode == "multi":
441 |         data.sents = [s for s in data.sents if len(s.triplets) > 1]
442 |     else:
443 |         raise ValueError(f"mode must be 'single' or 'multi', got {mode!r}")
444 | 
445 |     if limit > 0:
446 |         random.seed(0)
447 |         random.shuffle(data.sents)
448 |         data.sents = data.sents[:limit]
449 | 
450 |     path_in = str(Path(path_model) / f"pred_in_{mode}.jsonl")
451 |     path_out = str(Path(path_model) / f"pred_out_{mode}.jsonl")
452 |     data.save(path_in)
453 | 
454 |     if mode == "single":
455 |         model.predict(path_in, path_out)
456 |     else:
457 |         model.predict_multi(path_in, path_out)
458 | 
459 |     results = model.score(path_pred=path_out, path_gold=path_in)
460 |     path_results = str(Path(path_model) / f"results_{mode}.json")
461 |     results.update(mode=mode, limit=limit, path_results=path_results)
462 |     print(json.dumps(results, indent=2))
463 |     with open(path_results, "w") as f:
464 |         json.dump(results, f, indent=2)
465 | 
466 | 
467 | def run_eval_many(path_model_pattern: str, data_dir: str, **kwargs):
468 |     for path in tqdm(sorted(Path().glob(path_model_pattern))):
469 |         name = path.parts[-2]
470 |         path_test = Path(data_dir) / name / "test.jsonl"
471 |         assert path_test.exists()
472 |         run_eval(path_model=str(path), path_test=str(path_test), **kwargs)
473 | 
474 | 
475 | """
476 | FewRel Dataset
477 | 
478 | python wrapper.py main \
479 |     --path_train outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/train.jsonl \
480 |     --path_dev outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/dev.jsonl \
481 |     --path_test outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/test.jsonl \
482 |     --save_dir outputs/wrapper/fewrel/unseen_10_seed_0
483 | 
484 | python wrapper.py run_eval \
485 |     --path_model outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final \
486 |     --path_test outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/test.jsonl \
487 |     --mode single
488 | 
489 | python wrapper.py run_eval \
490 |     --path_model outputs/wrapper/fewrel/unseen_10_seed_0/extractor_final \
491 |     --path_test outputs/data/splits/zero_rte/fewrel/unseen_10_seed_0/test.jsonl \
492 |     --mode multi
493 | 
494 | Wiki-ZSL Dataset
495 | 
496 | python wrapper.py main \
497 |     --path_train outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/train.jsonl \
498 |     --path_dev outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/dev.jsonl \
499 |     --path_test outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/test.jsonl \
500 |     --save_dir outputs/wrapper/wiki/unseen_10_seed_0
501 | 
502 | python wrapper.py run_eval \
503 |     --path_model outputs/wrapper/wiki/unseen_10_seed_0/extractor_final \
504 |     --path_test outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/test.jsonl \
505 |     --mode single
506 | 
507 | python wrapper.py run_eval \
508 |     --path_model outputs/wrapper/wiki/unseen_10_seed_0/extractor_final \
509 |     --path_test outputs/data/splits/zero_rte/wiki/unseen_10_seed_0/test.jsonl \
510 |     --mode multi
511 | 
512 | """
513 | 
514 | 
515 | if __name__ == "__main__":
516 |     Fire()
517 | 
--------------------------------------------------------------------------------
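For completeness, a plausible way to sweep every split at once with `main_many`, in the same CLI style as the docstring above; the glob pattern is an assumption about the local data layout produced by `write_data_splits`:

```
python wrapper.py main_many \
    --data_dir_pattern "outputs/data/splits/zero_rte/fewrel/unseen_*" \
    --save_dir outputs/wrapper/fewrel
```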