├── argument_relation_transformer ├── __init__.py ├── utils.py ├── coreset.py ├── train.py ├── infer.py ├── dataset.py ├── system.py ├── active.py └── modeling.py ├── infer_demo.sh ├── setup.py ├── train_demo.sh ├── scripts ├── format_conversion.py ├── cdcp_test_ids.txt ├── echr_preproc.py ├── cdcp_train_ids.txt └── dataset.py ├── train_active_demo.sh ├── README.md └── LICENSE /argument_relation_transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /infer_demo.sh: -------------------------------------------------------------------------------- 1 | SEED=42 # only used in training, here to identify the training checkpoint path 2 | DOMAIN="ampere" 3 | EXP_NAME=demo-${DOMAIN}_seed=${SEED} 4 | 5 | python -m argument_relation_transformer.infer \ 6 | --datadir=./data \ 7 | --dataset=${DOMAIN} \ 8 | --eval-set=test \ 9 | --exp-name=demo-${DOMAIN}_seed=${SEED} \ 10 | --ckptdir=./checkpoints/ \ 11 | --batch-size=32 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='argument_relation_transformer', 5 | version='0.1.0', 6 | packages=find_packages(), 7 | include_package_data=True, 8 | install_requires=[ 9 | 'torch==1.6.0', 10 | 'pytorch-lightning==1.5.0', 11 | 'transformers==4.10.3', 12 | 'numpy==1.21.6', 13 | 'scikit-learn==1.0.2', 14 | ] 15 | ) 16 | -------------------------------------------------------------------------------- /train_demo.sh: -------------------------------------------------------------------------------- 1 | # DOMAIN expects one of ['ampere', 'ukp', 'cdcp', 'abst_rct', 'echr'], data need to be downloaded separately 2 | DOMAIN="ampere" 3 | SEED=42 4 | python -m argument_relation_transformer.train \ 5 | --datadir=./data \ 6 | --seed=${SEED} \ 7 | --dataset=${DOMAIN} \ 8 | --ckptdir=./checkpoints \ 9 | --exp-name=demo-${DOMAIN}_seed=${SEED} \ 10 | --warmup-steps=5000 \ 11 | --learning-rate=1e-5 \ 12 | --huggingface-path=./huggingface/ \ 13 | --scheduler-type=constant \ 14 | --max-epochs=15 15 | -------------------------------------------------------------------------------- /scripts/format_conversion.py: -------------------------------------------------------------------------------- 1 | """Script to convert Essays/AbstRCT/ECHR/CDCP to unified format""" 2 | import argparse 3 | 4 | from dataset import ( 5 | UKPDocument, 6 | CDCPDocument, 7 | ECHRDocument, 8 | AbstCRTDocument 9 | ) 10 | 11 | dataset_map = { 12 | 'cdcp' : CDCPDocument, 13 | 'essays' : UKPDocument, 14 | 'abst_rct' : AbstCRTDocument, 15 | 'echr' : ECHRDocument 16 | } 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--dataset", type=str, 22 | choices=['abst_rct', 'essays', 'echr', 'cdcp']) 23 | args = parser.parse_args() 24 | 25 | dataset_class = dataset_map[args.dataset] 26 | dataset_class.make_all_data() 27 | 28 | if __name__=='__main__': 29 | main() 30 | 31 | -------------------------------------------------------------------------------- /scripts/cdcp_test_ids.txt: -------------------------------------------------------------------------------- 1 | 00984 2 | 00452 3 | 01191 4 | 01151 5 | 01295 6 | 01099 7 | 00615 8 | 01182 9 | 01108 10 | 00429 11 | 00524 12 | 00343 13 | 01299 14 | 00412 15 | 00240 16 | 00796 17 | 00492 18 | 00346 19 | 00365 20 | 
00932 21 | 00359 22 | 00821 23 | 00543 24 | 01332 25 | 00584 26 | 00759 27 | 00378 28 | 00393 29 | 00238 30 | 00672 31 | 01094 32 | 00573 33 | 00351 34 | 00363 35 | 01084 36 | 00464 37 | 00869 38 | 00488 39 | 00341 40 | 01418 41 | 00799 42 | 00760 43 | 00536 44 | 00540 45 | 00570 46 | 00310 47 | 00702 48 | 01361 49 | 00581 50 | 00966 51 | 00716 52 | 01044 53 | 01087 54 | 00861 55 | 01093 56 | 01388 57 | 01113 58 | 01398 59 | 01239 60 | 00638 61 | 00199 62 | 00389 63 | 00748 64 | 00835 65 | 00776 66 | 00456 67 | 01344 68 | 00405 69 | 01342 70 | 00525 71 | 01132 72 | 00708 73 | 00358 74 | 00825 75 | 01290 76 | 00655 77 | 00194 78 | 00226 79 | 01405 80 | 01092 81 | 01005 82 | 00806 83 | 00698 84 | 01326 85 | 00486 86 | 01122 87 | 00369 88 | 01135 89 | 00398 90 | 00763 91 | 01178 92 | 01123 93 | 00372 94 | 00236 95 | 01152 96 | 00680 97 | 00986 98 | 00591 99 | 00320 100 | 00637 101 | 00485 102 | 00434 103 | 00730 104 | 00947 105 | 00339 106 | 01318 107 | 00978 108 | 00722 109 | 01030 110 | 00562 111 | 00890 112 | 00487 113 | 00929 114 | 00851 115 | 01241 116 | 01382 117 | 01324 118 | 01207 119 | 00837 120 | 01330 121 | 01016 122 | 00811 123 | 00677 124 | 01159 125 | 00886 126 | 01250 127 | 00565 128 | 00400 129 | 01158 130 | 00387 131 | 01220 132 | 01300 133 | 00750 134 | 00336 135 | 00196 136 | 01213 137 | 00854 138 | 01411 139 | 00810 140 | 00303 141 | 00361 142 | 01339 143 | 01242 144 | 01283 145 | 01019 146 | 00790 147 | 00661 148 | 00860 149 | 01190 150 | 00463 151 | -------------------------------------------------------------------------------- /train_active_demo.sh: -------------------------------------------------------------------------------- 1 | INTERVAL=500 2 | SEED=42 3 | METHOD="max-entropy" 4 | DOMAIN="ampere" 5 | EXP_NAME="${METHOD}-demo_SEED=${SEED}" 6 | CKPTDIR="./checkpoints_al/" 7 | MAX_EPOCHS=10 8 | 9 | active() { 10 | # load model from `model-path` (if needed) 11 | # select INTERVAL unlabeled samples 12 | # save to `/tmp/checkpoints_al/[data]/[exp_name]/[method]_[interval].jsonl`, which is 13 | # the ids of **all** labeled data 14 | # args: (1) current sample size; 15 | python -m argument_relation_transformer.active \ 16 | --dataset=${DOMAIN} \ 17 | --datadir="./data/" \ 18 | --ckptdir=${CKPTDIR} \ 19 | --exp-name=${EXP_NAME} \ 20 | --method=${METHOD} \ 21 | --seed=${SEED} \ 22 | --interval=${INTERVAL} \ 23 | --huggingface-path="./huggingface/" \ 24 | --current-sample-size=$1 25 | } 26 | 27 | train() { 28 | # load data from `{ckptdir}/{exp-name}/{method}_{current-sample-size}.jsonl` 29 | # train the model and save to `{ckptdir}/{exp-name}/model_{current-sample-size}/` 30 | python -m argument_relation_transformer.train \ 31 | --datadir=./data \ 32 | --seed=${SEED} \ 33 | --dataset=${DOMAIN} \ 34 | --ckptdir=${CKPTDIR} \ 35 | --exp-name=${EXP_NAME} \ 36 | --warmup-steps=500 \ 37 | --learning-rate=1e-5 \ 38 | --huggingface-path="./huggingface/" \ 39 | --scheduler-type=constant \ 40 | --max-epochs=${MAX_EPOCHS} \ 41 | --from-al \ 42 | --al-method=${METHOD} \ 43 | --current-sample-size=$1 44 | } 45 | 46 | for p in 0 500 1000 1500 2000 2500 3000 3500 4000 4500; 47 | do 48 | # each step, we have p samples already, and are selecting the next 500 samples 49 | active ${p} 50 | n_p=$(( $p + $INTERVAL )) 51 | train ${n_p} 52 | done 53 | -------------------------------------------------------------------------------- /argument_relation_transformer/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import os 4 
| import random 5 | 6 | import numpy as np 7 | import torch 8 | 9 | DOMAINS = ["ampere", "ukp", "cdcp", "abst_rct", "echr"] 10 | AL_METHODS = [ 11 | "random", 12 | "max-entropy", 13 | "bald", 14 | "disc", 15 | "distance", 16 | "vocab", 17 | "no-disc", 18 | "coreset", 19 | ] 20 | 21 | 22 | def set_seeds(seed): 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | torch.manual_seed(seed) 26 | torch.cuda.manual_seed_all(seed) 27 | 28 | 29 | def find_best_ckpt(ckpt_path): 30 | """Find the most recent path and load the model with best f1""" 31 | ckpt_paths = glob.glob(ckpt_path + "epoch=*") 32 | 33 | if len(ckpt_paths) == 0: 34 | return None, None 35 | 36 | def _get_f1_score(_path): 37 | basename = os.path.basename(_path) 38 | f1_score = basename.split(".ckpt")[0] 39 | f1_score = f1_score.split("_f1=")[1] 40 | return float(f1_score) 41 | 42 | ckpt_sorted = sorted(ckpt_paths, key=lambda x: _get_f1_score(x), reverse=True) 43 | return ckpt_sorted[0], _get_f1_score(ckpt_sorted[0]) 44 | 45 | 46 | def get_epoch_num_from_path(path): 47 | base = os.path.basename(path) 48 | base = base.split("-")[0][6:] 49 | return int(base) 50 | 51 | 52 | def load_latest_ckpt_from_globs(ckpt_list): 53 | ckpt_list = sorted( 54 | ckpt_list, key=lambda x: get_epoch_num_from_path(x), reverse=True 55 | ) 56 | return ckpt_list[0] 57 | 58 | 59 | def load_latest_ckpt(exp_name, task, domain, epoch_id=-1): 60 | 61 | if epoch_id == -1: 62 | all_ckpts = glob.glob(f"checkpoints/{task}/{domain}/{exp_name}/epoch*") 63 | else: 64 | all_ckpts = glob.glob( 65 | f"checkpoints/{task}/{domain}/{exp_name}/epoch={epoch_id}-*" 66 | ) 67 | 68 | ckpt_list = sorted( 69 | all_ckpts, key=lambda x: get_epoch_num_from_path(x), reverse=True 70 | ) 71 | assert len(ckpt_list) > 0, f"no checkpoint found for {exp_name}, epoch={epoch_id}" 72 | return ckpt_list[0], get_epoch_num_from_path(ckpt_list[0]) 73 | 74 | 75 | def move_to_cuda(sample): 76 | def _move_to_cuda(tensor): 77 | return tensor.cuda() 78 | 79 | return apply_to_sample(_move_to_cuda, sample) 80 | 81 | 82 | def apply_to_sample(f, sample): 83 | if len(sample) == 0: 84 | return {} 85 | 86 | def _apply(x): 87 | if torch.is_tensor(x): 88 | return f(x) 89 | elif isinstance(x, dict): 90 | r = {key: _apply(value) for key, value in x.items()} 91 | return r 92 | # return { 93 | # key: _apply(value) 94 | # for key, value in x.items() 95 | # } 96 | elif isinstance(x, list): 97 | return [_apply(x) for x in x] 98 | else: 99 | return x 100 | 101 | return _apply(sample) 102 | 103 | 104 | def load_vocab(domain, min_freq=0): 105 | path = f"vocab/{domain}.txt" 106 | vocab = dict() 107 | for ln in open(path): 108 | word, freq = ln.strip().split("\t") 109 | freq = int(freq) 110 | if freq < min_freq: 111 | break 112 | 113 | vocab[word] = freq 114 | return vocab 115 | -------------------------------------------------------------------------------- /argument_relation_transformer/coreset.py: -------------------------------------------------------------------------------- 1 | """Helper class for the CoreSet active learning strategy.""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | 7 | from argument_relation_transformer.utils import move_to_cuda 8 | 9 | 10 | class CoresetSampler: 11 | def __init__(self, dataloader, model): 12 | self.dataloader = dataloader 13 | self.model = model 14 | self.min_distances = None 15 | self.already_selected = [] 16 | self.pdist = nn.PairwiseDistance(p=2) 17 | 18 | def update_distances(self, cluster_centers, only_new=True, reset_dist=False): 19 | """Update min 
distances given cluster centers. 20 | adapted from https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py 21 | """ 22 | if reset_dist: 23 | self.min_distances = None 24 | if only_new: 25 | cluster_centers = [ 26 | d for d in cluster_centers if d not in self.already_selected 27 | ] 28 | if cluster_centers: 29 | x = self.all_features[0][cluster_centers] 30 | dist = torch.cdist(self.all_features, x).squeeze(0) 31 | 32 | if self.min_distances is None: 33 | self.min_distances, _ = torch.min(dist, axis=1) 34 | self.min_distances = self.min_distances.reshape(-1, 1) 35 | 36 | else: 37 | self.min_distances = torch.min(self.min_distances, dist) 38 | 39 | def select_batch_(self, already_selected): 40 | """Sample greedily to minimize the maximum distance to a cluster center 41 | among all unlabeled datapoints. 42 | """ 43 | features = [] 44 | sample_ix = 0 45 | already_selected_ix = [] 46 | real_ids = [] 47 | for (i, batch) in tqdm( 48 | enumerate(self.dataloader), desc=f"Sampling using coreset" 49 | ): 50 | batch = move_to_cuda(batch) 51 | # batch_size x prop_num x dim 52 | batch_feats = self.model.extract_last_layer(**batch) 53 | cur_labels = batch["labels"] 54 | for (j, _id) in enumerate(batch["ids"]): 55 | valid_prop_ix = 0 56 | for l_ix, l in enumerate(cur_labels[j]): 57 | if l == -1: 58 | continue 59 | 60 | new_id = list(_id) 61 | new_id.append(valid_prop_ix) 62 | new_id = (tuple(new_id), l.item()) 63 | if new_id in already_selected: 64 | already_selected_ix.append(sample_ix) 65 | real_ids.append(new_id) 66 | sample_ix += 1 67 | valid_prop_ix += 1 68 | 69 | cur_feat = batch_feats[j, l_ix].unsqueeze(0) 70 | features.append(cur_feat) 71 | 72 | self.all_features = torch.cat(features, 0).unsqueeze(0) 73 | self.update_distances(already_selected_ix, only_new=False, reset_dist=True) 74 | self.already_selected = already_selected_ix 75 | 76 | selected = [] 77 | for _ in range(500): 78 | ind = torch.argmax(self.min_distances) 79 | ind = ind.item() 80 | assert ind not in self.already_selected 81 | 82 | self.update_distances([ind], only_new=True, reset_dist=False) 83 | selected.append(real_ids[ind]) 84 | 85 | return selected 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Argument Structure Prediction 2 | 3 | Code release for paper `Efficient Argument Structure Extraction with Transfer Learning and Active Learning` 4 | 5 | ```bibtex 6 | @inproceedings{hua-wang-2022-efficient, 7 | title = "Efficient Argument Structure Extraction with Transfer Learning and Active Learning", 8 | author = "Hua, Xinyu and 9 | Wang, Lu", 10 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2022", 11 | month = may, 12 | year = "2022", 13 | address = "Dublin, Ireland", 14 | publisher = "Association for Computational Linguistics", 15 | url = "https://aclanthology.org/2022.findings-acl.36", 16 | pages = "423--437", 17 | } 18 | ``` 19 | 20 | ## Requirements 21 | 22 | The original project is tested under the following environments: 23 | 24 | ``` 25 | python==3.7.12 26 | torch==1.6.0 27 | pytorch_lightning==1.5.0 28 | transformers==4.10.3 29 | numpy==1.21.6 30 | scikit-learn==1.0.2 31 | ``` 32 | 33 | ## Data 34 | 35 | We release the AMPERE++ dataset in this [link](https://zenodo.org/record/6362430#.YjJJUprMIba). 36 | Please download the jsonl files and store under `./data`. 
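Each `{dataset}_{split}.jsonl` file stores one document per line; the loaders in `argument_relation_transformer/dataset.py` and `infer.py` read the fields `doc_id`, `text` (a list of propositions), and `relations` (a list of `head`/`tail` indices into `text` with a `type` of `support` or `attack`). A minimal sketch for inspecting one record (the path assumes the AMPERE++ download above):

```python
import json

# Peek at the first document of the AMPERE++ training split.
with open("./data/ampere_train.jsonl") as f:
    doc = json.loads(next(f))

print(doc["doc_id"], len(doc["text"]), "propositions")
for rel in doc["relations"][:3]:
    # `tail` supports or attacks `head`; both are indices into doc["text"]
    print(rel["tail"], "->", rel["head"], rel["type"])
```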
37 | 38 | 39 | The other four datasets can be downloaded using the links below (they require format conversion; the conversion code can be found in `./scripts/`): 40 | 41 | - Essays (Stab and Gurevych, 2017): [link](https://tudatalib.ulb.tu-darmstadt.de/handle/tudatalib/2422) 42 | - AbstRCT (Mayer et al., 2020): [link](https://gitlab.com/tomaye/abstrct/) 43 | - ECHR (Poudyal et al., 2020): [link](http://www.di.uevora.pt/~pq/echr/) 44 | - CDCP (Park and Cardie, 2018; Niculae et al., 2017): [link](https://facultystaff.richmond.edu/~jpark/data/cdcp_acl17.zip) 45 | 46 | ## Quick Start 47 | 48 | First, install the package: 49 | 50 | ```shell script 51 | pip install -e . 52 | ``` 53 | 54 | To train a standard supervised relation extraction model on AMPERE++: 55 | 56 | ```shell script 57 | SEED=42 58 | DOMAIN="ampere" 59 | 60 | python -m argument_relation_transformer.train \ 61 | --datadir=./data \ 62 | --seed=${SEED} \ 63 | --dataset=${DOMAIN} \ 64 | --ckptdir=./checkpoints \ 65 | --exp-name=demo-${DOMAIN}_seed=${SEED} \ 66 | --warmup-steps=5000 \ 67 | --learning-rate=1e-5 \ 68 | --scheduler-type=constant \ 69 | --max-epochs=15 70 | ``` 71 | 72 | The trained model will be saved at `./checkpoints/demo-ampere_seed=42/`, and the tensorboard metrics will be logged under `./tb_logs/demo-ampere_seed=42/`. The saved checkpoint can then be loaded for evaluation: 73 | 74 | ```shell script 75 | SEED=42 # only used in training, here to identify the training checkpoint path 76 | DOMAIN="ampere" 77 | EXP_NAME=demo-${DOMAIN}_seed=${SEED} 78 | 79 | python -m argument_relation_transformer.infer \ 80 | --datadir=./data \ 81 | --dataset=${DOMAIN} \ 82 | --eval-set=test \ 83 | --exp-name=demo-${DOMAIN}_seed=${SEED} \ 84 | --ckptdir=./checkpoints/ \ 85 | --batch-size=32 86 | ``` 87 | 88 | The prediction results will be saved to `./outputs/demo-ampere_seed=42.jsonl`, and the evaluation metrics will be saved to `./outputs/demo-ampere_seed=42.jsonl.scores`. 89 | 90 | ## Active Learning (AL) 91 | 92 | We simulate pool-based AL, where the entire process consists of 10 iterations. During each iteration, 500 samples are collected based 93 | on a given sampling strategy. 
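The implemented sampling strategies are listed in `AL_METHODS` inside `argument_relation_transformer/utils.py`, which is also the set of values `train.py` accepts for `--al-method`; a quick way to list them:

```python
from argument_relation_transformer.utils import AL_METHODS

# ['random', 'max-entropy', 'bald', 'disc', 'distance', 'vocab', 'no-disc', 'coreset']
print(AL_METHODS)
```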
We use the following script to demonstrate this procedure (`train_active_demo.sh`): 94 | 95 | ```shell script 96 | INTERVAL=500 97 | SEED=42 98 | METHOD="max_entropy" 99 | DOMAIN="ampere" 100 | EXP_NAME="${METHOD}-demo_seed=${SEED}" 101 | CKPTDIR="./checkpoints_al/" 102 | MAX_EPOCHS=10 103 | 104 | active() { 105 | # load model from `model-path` (if needed) 106 | # select INTERVAL unlabeled samples 107 | # save to `/tmp/checkpoints_al/[data]/[exp_name]/[method]_[interval].jsonl`, which is 108 | # the ids of **all** labeled data 109 | # args: (1) current sample size; 110 | python -m argument_relation_transformer.active \ 111 | --dataset=${DOMAIN} \ 112 | --datadir="./data/" \ 113 | --ckptdir=${CKPTDIR} \ 114 | --exp-name=${EXP_NAME} \ 115 | --method=${METHOD} \ 116 | --seed=${SEED} \ 117 | --interval=${INTERVAL} \ 118 | --huggingface-path="./huggingface/" \ 119 | --current-sample-size=$1 120 | } 121 | 122 | train() { 123 | # load data from `{ckptdir}/{exp-name}/{method}_{current-sample-size}.jsonl` 124 | # train the model and save to `{ckptdir}/{exp-name}/model_{current-sample-size}/` 125 | python -m argument_relation_transformer.train \ 126 | --datadir=./data \ 127 | --seed=${SEED} \ 128 | --dataset=${DOMAIN} \ 129 | --ckptdir=${CKPTDIR} \ 130 | --exp-name=${EXP_NAME} \ 131 | --warmup-steps=500 \ 132 | --learning-rate=1e-5 \ 133 | --huggingface-path="./huggingface/" \ 134 | --scheduler-type=constant \ 135 | --max-epochs=${MAX_EPOCHS} \ 136 | --from-al \ 137 | --al-method=${METHOD} \ 138 | --current-sample-size=$1 139 | } 140 | 141 | for p in 0 500 1000 1500 2000 2500 3000 3500 4000 4500; 142 | do 143 | # each step, we have p samples already, and are selecting the next 500 samples 144 | active ${p} 145 | n_p=$(( $p + $INTERVAL)) 146 | train ${n_p} 147 | done 148 | ``` 149 | -------------------------------------------------------------------------------- /argument_relation_transformer/train.py: -------------------------------------------------------------------------------- 1 | """Training the model, either using standard supervised training or active learning""" 2 | import argparse 3 | import datetime 4 | import json 5 | import os 6 | 7 | import pytorch_lightning as pl 8 | import torch 9 | from pytorch_lightning.loggers import TensorBoardLogger 10 | 11 | from argument_relation_transformer.infer import run_evaluation 12 | from argument_relation_transformer.system import ArgumentRelationClassificationSystem 13 | from argument_relation_transformer.utils import ( 14 | AL_METHODS, 15 | DOMAINS, 16 | find_best_ckpt, 17 | get_epoch_num_from_path, 18 | set_seeds, 19 | ) 20 | 21 | 22 | def main(): 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument( 25 | "--datadir", 26 | type=str, 27 | default="data/", 28 | help="Parent level directory containing jsonl format dataset", 29 | ) 30 | parser.add_argument( 31 | "--ckptdir", 32 | type=str, 33 | default="checkpoints/", 34 | help="Directory to save model checkpoints and active learning sample indices", 35 | ) 36 | parser.add_argument( 37 | "--dataset", 38 | type=str, 39 | choices=DOMAINS, 40 | help="Available domains for argument relation prediction task", 41 | ) 42 | parser.add_argument( 43 | "--exp-name", 44 | type=str, 45 | required=True, 46 | help="A string identifier of an experiment run, it is recommended to include `model` hyperparameters and random seed in it", 47 | ) 48 | parser.add_argument("--seed", type=int, default=42) 49 | parser.add_argument( 50 | "--end-to-end", action="store_true", help="Whether to assume heads are given" 51 | 
) 52 | parser.add_argument( 53 | "--huggingface-path", 54 | type=str, 55 | default="./huggingface/", 56 | help="Directory where the pre-trained huggingface transformers are saved", 57 | ) 58 | 59 | ## Hyper-parameters 60 | # optimizer 61 | parser.add_argument("--adam-epsilon", type=float, default=1e-8) 62 | parser.add_argument("--learning-rate", type=float, default=1e-5) 63 | parser.add_argument("--warmup-steps", type=int, default=500) 64 | parser.add_argument("--scheduler-type", type=str, choices=["linear", "constant"]) 65 | 66 | parser.add_argument("--max-epochs", type=int, default=5) 67 | parser.add_argument("--batch-size", type=int, default=16) 68 | parser.add_argument("--window-size", type=int, default=20) 69 | 70 | parser.add_argument( 71 | "--from-al", 72 | action="store_true", 73 | help="If set to True, train over samples selected by AL method.", 74 | ) 75 | parser.add_argument("--al-method", type=str, choices=AL_METHODS) 76 | parser.add_argument( 77 | "--current-sample-size", 78 | type=int, 79 | help="If set --from-al to True, this will be used to identify the selected training set.", 80 | ) 81 | 82 | args = parser.parse_args() 83 | set_seeds(args.seed) 84 | 85 | if args.dataset in ["echr", "cdcp"]: 86 | print(f"dataset {args.dataset} has only two classes, convert task into binary") 87 | args.task = "binary" 88 | else: 89 | args.task = "ternary" 90 | 91 | trainer_args = dict() 92 | if torch.cuda.is_available(): 93 | trainer_args["gpus"] = 1 94 | trainer_args["log_gpu_memory"] = True 95 | print("use GPU training") 96 | else: 97 | trainer_args["gpus"] = 0 98 | print("use CPU training") 99 | 100 | if args.from_al: 101 | checkpoint_path = f"{args.ckptdir}/{args.exp_name}_model-trained-on-{args.current_sample_size}/" 102 | else: 103 | checkpoint_path = f"{args.ckptdir}/{args.exp_name}/" 104 | if not os.path.exists(args.ckptdir): 105 | os.makedirs(args.ckptdir) 106 | 107 | tb_logger = TensorBoardLogger(f"tb_logs/", name=f"{args.exp_name}") 108 | 109 | if args.task == "binary": 110 | ckpt_fname = "{epoch}-{val_loss:.4f}-{val_acc:.4f}-{val_link_f1:.4f}" 111 | monitor = "val_link_f1" 112 | else: 113 | ckpt_fname = "{epoch}-{val_loss:.4f}-{val_acc:.4f}-{val_macro_f1:.4f}" 114 | monitor = "val_macro_f1" 115 | 116 | checkpoint_callback = pl.callbacks.ModelCheckpoint( 117 | dirpath=checkpoint_path, 118 | filename=ckpt_fname, 119 | monitor=monitor, 120 | mode="max", 121 | save_top_k=1, 122 | ) 123 | trainer_args["logger"] = tb_logger 124 | trainer_args["callbacks"] = [checkpoint_callback] 125 | trainer = pl.Trainer.from_argparse_args(args, **trainer_args) 126 | 127 | if args.from_al: 128 | print( 129 | f">> Train model over actively selected samples from {args.al_method}.{args.current_sample_size}.jsonl" 130 | ) 131 | al_selected_data_path = os.path.join( 132 | args.ckptdir, 133 | args.exp_name, 134 | f"{args.al_method}.{args.current_sample_size}.jsonl", 135 | ) 136 | model = ArgumentRelationClassificationSystem(args, al_selected_data_path) 137 | else: 138 | model = ArgumentRelationClassificationSystem(args) 139 | 140 | trainer.fit(model) 141 | 142 | # run test, evaluation, and save results 143 | # predictions will be saved to `outputs/{args.exp_name}.jsonl` 144 | # scores will be saved to `scores/{args.exp_name}.scores` 145 | best_ckpt, _ = find_best_ckpt(ckpt_path=checkpoint_path) 146 | print( 147 | f">> Test on {best_ckpt}, results will be saved to `outputs/{args.exp_name}.jsonl`" 148 | ) 149 | results = trainer.test(ckpt_path=best_ckpt)[0] 150 | results["epoch"] = 
get_epoch_num_from_path(best_ckpt) 151 | run_evaluation( 152 | output_path=os.path.join("outputs", f"{args.exp_name}.jsonl"), 153 | label_path=os.path.join(args.datadir, f"{args.dataset}_test.jsonl"), 154 | exp_name=args.exp_name, 155 | task=model.task, 156 | ) 157 | 158 | 159 | if __name__ == "__main__": 160 | main() 161 | -------------------------------------------------------------------------------- /scripts/echr_preproc.py: -------------------------------------------------------------------------------- 1 | # Extract chunks of local context for doc-level relation prediction 2 | # Based on the original paper https://www.aclweb.org/anthology/2020.argmining-1.8.pdf 3 | # the argumentativeness nature for each sentence is assumed known during 4 | # relation prediction. The system need to predict an undirected pariwise relation 5 | # for all clause pairs that are no more than 5 sentences apart. 6 | # We aim to transform the original json format into local context blocks for 7 | # prediction. 8 | # ------- PROCEDURE --------- 9 | # 1. remove the part outside `THE LAW` 10 | # 2. split the document into sentences (check whether clauses are broken by this) 11 | # 3. store data as list of sentences with the following field 12 | # - sentence_id 13 | # - is_argument 14 | # 4. store relation pairs, e.g. "[[4, 5], [6, 5]]" 15 | 16 | import json 17 | import numpy as np 18 | from nltk.tokenize import sent_tokenize 19 | 20 | class Document: 21 | 22 | def __init__(self, name, text, clauses, arguments): 23 | self.name = name 24 | self.text = text 25 | self.clauses = clauses 26 | self.arguments = arguments 27 | 28 | def remove_boilerplate(self): 29 | """Remove anything outside `THE LAW`. Adjust clause positions.""" 30 | 31 | if 'AS TO THE LAW' in self.text: 32 | start_str = 'AS TO THE LAW' 33 | elif 'THE LAW' in self.text: 34 | start_str = 'THE LAW' 35 | else: 36 | raise ValueError 37 | 38 | real_start_index = self.text.index(start_str) + len(start_str) 39 | self.text = self.text[real_start_index:] 40 | print(f'size reduction from {len(self.text) + real_start_index} to {len(self.text)}') 41 | kept_clauses = [] 42 | self.clause_ids = set() 43 | for cl in self.clauses: 44 | if cl['start'] < real_start_index: 45 | continue 46 | 47 | # shift from removing boilerplate 48 | cl_start_shifted = cl['start'] - real_start_index 49 | cl_end_shifted = cl['end'] - real_start_index 50 | 51 | cl_text = self.text[cl_start_shifted:cl_end_shifted] 52 | cl_text_trimmed = cl_text.strip() 53 | # shift by preceding or trailing space for clause 54 | if len(cl_text_trimmed) < len(cl_text): 55 | cl_shift = cl_text.index(cl_text_trimmed) 56 | cl_start_shifted += cl_shift 57 | cl_end_shifted = cl_start_shifted + len(cl_text_trimmed) 58 | 59 | kept_clauses.append({'_id': cl['_id'], 60 | 'start': cl_start_shifted, 61 | 'end': cl_end_shifted}) 62 | self.clause_ids.add(cl['_id']) 63 | 64 | print(f'{len(kept_clauses)} clauses kept from {len(self.clauses)}') 65 | self.clauses = kept_clauses 66 | 67 | def sentence_split_and_clause_matching(self): 68 | """Split document into sentences. 
First take out clauses, then split 69 | the parts outside clauses.""" 70 | 71 | clause_positions = [] 72 | for i, cl in enumerate(self.clauses): 73 | clause_positions.append((cl['start'], cl['end'], cl['_id'])) 74 | clause_positions_sorted = sorted(clause_positions, key=lambda x: x[0]) 75 | self.ssplit = [] 76 | sent_len_dist = [] 77 | cur_ptr = 0 78 | for (cl_start, cl_end, cl_id) in clause_positions_sorted: 79 | cur_outside = self.text[cur_ptr: cl_start] 80 | cur_split = sent_tokenize(cur_outside) 81 | for s in cur_split: 82 | self.ssplit.append((s, False, None)) 83 | sent_len_dist.append(len(s.split())) 84 | 85 | clause_text = self.text[cl_start: cl_end] 86 | self.ssplit.append((clause_text, True, cl_id)) 87 | sent_len_dist.append(len(clause_text.split())) 88 | cur_ptr = cl_end 89 | print(f'{len(self.ssplit)} sentences found, sentence length: {min(sent_len_dist)} - {max(sent_len_dist)} (mean: {np.mean(sent_len_dist)})') 90 | 91 | 92 | def write_to_disk(self, fout): 93 | cur_doc = { 94 | 'sentences': [item[0] for item in self.ssplit], 95 | 'clause_id': [item[2] for item in self.ssplit], 96 | 'is_argument': [], 97 | 'relations': [], 98 | } 99 | 100 | arg_cl_ids = set() 101 | for arg in self.arguments: 102 | premises = arg['premises'] 103 | conclusion = arg['conclusion'] 104 | if conclusion not in self.clause_ids: 105 | continue 106 | 107 | for p in premises: 108 | if p not in self.clause_ids: 109 | continue 110 | cur_doc['relations'].append((p, conclusion)) 111 | arg_cl_ids.add(conclusion) 112 | arg_cl_ids.add(p) 113 | 114 | for cl_id in cur_doc['clause_id']: 115 | if cl_id in arg_cl_ids: 116 | cur_doc['is_argument'].append(True) 117 | else: 118 | cur_doc['is_argument'].append(False) 119 | 120 | 121 | fout.write(json.dumps(cur_doc) + "\n") 122 | return len(self.clauses), len(arg_cl_ids), len(self.ssplit), len(cur_doc['relations']) 123 | 124 | if __name__=='__main__': 125 | with open('raw/echr_corpus/ECHR_Corpus.json') as jf: 126 | data = json.load(jf) 127 | 128 | num_args, num_cl, num_rel, num_sent = 0, 0, 0, 0 129 | fout = open('raw/echr_corpus/ECHR_sentences.jsonl', 'w') 130 | for item in data: 131 | doc = Document(name=item['name'], text=item['text'], 132 | clauses=item['clauses'], arguments=item['arguments']) 133 | doc.remove_boilerplate() 134 | doc.sentence_split_and_clause_matching() 135 | cur_num_cl, cur_num_args, cur_num_sents, cur_rels = doc.write_to_disk(fout) 136 | num_args += cur_num_args 137 | num_cl += cur_num_cl 138 | num_rel += cur_rels 139 | num_sent += cur_num_sents 140 | fout.close() 141 | print(f'{num_sent} sentences in total, {num_cl} clauses, ' 142 | f'{num_args} are arguments, {num_rel} pairs of relation found') 143 | -------------------------------------------------------------------------------- /scripts/cdcp_train_ids.txt: -------------------------------------------------------------------------------- 1 | 01251 2 | 00710 3 | 00675 4 | 00953 5 | 00480 6 | 00533 7 | 00568 8 | 00695 9 | 01078 10 | 00342 11 | 01365 12 | 00204 13 | 00685 14 | 00956 15 | 00999 16 | 00963 17 | 01390 18 | 00528 19 | 00721 20 | 00575 21 | 01297 22 | 01021 23 | 01141 24 | 01168 25 | 00819 26 | 00739 27 | 01243 28 | 00551 29 | 00905 30 | 00779 31 | 00588 32 | 01129 33 | 01244 34 | 01395 35 | 01101 36 | 01347 37 | 00793 38 | 01018 39 | 01098 40 | 01035 41 | 01073 42 | 00382 43 | 00846 44 | 01373 45 | 00783 46 | 00501 47 | 00309 48 | 01238 49 | 01349 50 | 01028 51 | 01256 52 | 01189 53 | 00704 54 | 00526 55 | 01114 56 | 00493 57 | 01235 58 | 01167 59 | 01069 60 | 00435 61 | 00881 62 | 01296 
63 | 01260 64 | 00645 65 | 00327 66 | 01327 67 | 00988 68 | 00633 69 | 00907 70 | 00977 71 | 00703 72 | 00679 73 | 01111 74 | 01077 75 | 00622 76 | 00200 77 | 00814 78 | 01323 79 | 01269 80 | 01275 81 | 00482 82 | 00870 83 | 00340 84 | 00535 85 | 01399 86 | 00354 87 | 00604 88 | 00517 89 | 01410 90 | 00203 91 | 01083 92 | 00815 93 | 01394 94 | 00433 95 | 01126 96 | 01063 97 | 00768 98 | 01015 99 | 00503 100 | 00549 101 | 01136 102 | 00753 103 | 00765 104 | 00913 105 | 00424 106 | 00380 107 | 01146 108 | 00635 109 | 00688 110 | 00751 111 | 00195 112 | 00772 113 | 01059 114 | 00694 115 | 00766 116 | 00771 117 | 01088 118 | 01267 119 | 00864 120 | 01298 121 | 00306 122 | 00610 123 | 01125 124 | 01121 125 | 00448 126 | 00324 127 | 00862 128 | 00225 129 | 00483 130 | 01017 131 | 01313 132 | 00318 133 | 01149 134 | 00906 135 | 01000 136 | 01013 137 | 01147 138 | 01368 139 | 01171 140 | 00774 141 | 01137 142 | 01356 143 | 01412 144 | 00601 145 | 01374 146 | 00605 147 | 01384 148 | 00314 149 | 00499 150 | 00479 151 | 00612 152 | 00514 153 | 00587 154 | 00352 155 | 00360 156 | 00461 157 | 01282 158 | 01345 159 | 00682 160 | 00396 161 | 00901 162 | 00383 163 | 00411 164 | 00534 165 | 01208 166 | 00720 167 | 00538 168 | 01109 169 | 00422 170 | 00888 171 | 00388 172 | 00845 173 | 00623 174 | 01097 175 | 00955 176 | 01179 177 | 00798 178 | 00325 179 | 00531 180 | 00691 181 | 00349 182 | 00658 183 | 00614 184 | 00571 185 | 00546 186 | 00219 187 | 01236 188 | 01186 189 | 00426 190 | 01086 191 | 00669 192 | 00823 193 | 00337 194 | 00872 195 | 01409 196 | 00594 197 | 00364 198 | 00599 199 | 01127 200 | 00686 201 | 01272 202 | 00567 203 | 00373 204 | 00423 205 | 01181 206 | 01070 207 | 01286 208 | 00822 209 | 01309 210 | 00481 211 | 01289 212 | 00521 213 | 01322 214 | 00908 215 | 00217 216 | 01040 217 | 00466 218 | 01316 219 | 01024 220 | 00987 221 | 01331 222 | 00609 223 | 00577 224 | 01142 225 | 00224 226 | 00839 227 | 00662 228 | 00740 229 | 00884 230 | 01033 231 | 01255 232 | 00212 233 | 00564 234 | 01020 235 | 01265 236 | 00406 237 | 01377 238 | 00409 239 | 01261 240 | 00579 241 | 01062 242 | 01076 243 | 01006 244 | 01389 245 | 00891 246 | 00842 247 | 00681 248 | 00639 249 | 00477 250 | 00572 251 | 00816 252 | 00529 253 | 00554 254 | 00741 255 | 00239 256 | 00560 257 | 01012 258 | 00348 259 | 00450 260 | 00344 261 | 01169 262 | 00417 263 | 00663 264 | 00397 265 | 00220 266 | 00527 267 | 00794 268 | 01056 269 | 00476 270 | 01172 271 | 01387 272 | 01145 273 | 00970 274 | 00598 275 | 01196 276 | 01414 277 | 00430 278 | 00462 279 | 01403 280 | 00732 281 | 00812 282 | 01385 283 | 01312 284 | 01157 285 | 00484 286 | 00500 287 | 01231 288 | 01072 289 | 00502 290 | 01400 291 | 01321 292 | 00817 293 | 00512 294 | 01112 295 | 01288 296 | 00964 297 | 01350 298 | 00889 299 | 00586 300 | 00228 301 | 01036 302 | 00649 303 | 01401 304 | 00333 305 | 00410 306 | 01037 307 | 00683 308 | 00995 309 | 01378 310 | 00660 311 | 01130 312 | 01252 313 | 01057 314 | 01075 315 | 00532 316 | 01051 317 | 01192 318 | 01386 319 | 00930 320 | 01337 321 | 00797 322 | 01210 323 | 00824 324 | 00223 325 | 00439 326 | 00676 327 | 00762 328 | 00600 329 | 01308 330 | 00757 331 | 00856 332 | 00802 333 | 01001 334 | 00555 335 | 00353 336 | 00427 337 | 00403 338 | 00786 339 | 01380 340 | 00205 341 | 01002 342 | 00407 343 | 00385 344 | 01247 345 | 00997 346 | 00563 347 | 01263 348 | 01032 349 | 00865 350 | 00852 351 | 00530 352 | 00805 353 | 01026 354 | 00395 355 | 00494 356 | 00747 357 | 01197 358 | 01025 359 | 00451 360 | 00829 361 | 01371 362 
| 00207 363 | 00909 364 | 00595 365 | 00377 366 | 00717 367 | 00668 368 | 00602 369 | 01234 370 | 00764 371 | 00785 372 | 00210 373 | 00347 374 | 00312 375 | 00969 376 | 00357 377 | 00882 378 | 00784 379 | 00392 380 | 00419 381 | 00416 382 | 01227 383 | 01085 384 | 00506 385 | 00513 386 | 01314 387 | 01408 388 | 00596 389 | 00234 390 | 00627 391 | 00413 392 | 00198 393 | 01353 394 | 01209 395 | 00808 396 | 00308 397 | 00967 398 | 00218 399 | 01273 400 | 01307 401 | 00910 402 | 00321 403 | 00582 404 | 01061 405 | 00498 406 | 01140 407 | 00928 408 | 00631 409 | 01166 410 | 00231 411 | 01118 412 | 01095 413 | 00495 414 | 00592 415 | 00505 416 | 00414 417 | 00621 418 | 01119 419 | 00566 420 | 01248 421 | 00994 422 | 01249 423 | 00998 424 | 01311 425 | 00559 426 | 01027 427 | 00578 428 | 00902 429 | 01041 430 | 01218 431 | 00520 432 | 00656 433 | 00912 434 | 01022 435 | 00745 436 | 00863 437 | 00496 438 | 00659 439 | 00338 440 | 00958 441 | 00975 442 | 00657 443 | 00402 444 | 00305 445 | 00522 446 | 00813 447 | 00990 448 | 00673 449 | 01089 450 | 00744 451 | 00626 452 | 01079 453 | 00391 454 | 00315 455 | 01090 456 | 01369 457 | 01199 458 | 01219 459 | 00689 460 | 00611 461 | 00442 462 | 01134 463 | 00379 464 | 01150 465 | 00989 466 | 00746 467 | 00302 468 | 00628 469 | 00457 470 | 00332 471 | 01340 472 | 01066 473 | 01292 474 | 00728 475 | 00597 476 | 01257 477 | 01148 478 | 00375 479 | 01212 480 | 01329 481 | 00692 482 | 00371 483 | 01071 484 | 01284 485 | 01144 486 | 00208 487 | 01226 488 | 00420 489 | 00665 490 | 01211 491 | 01407 492 | 00345 493 | 00350 494 | 00467 495 | 01173 496 | 01383 497 | 01217 498 | 01303 499 | 00788 500 | 00460 501 | 01246 502 | 00230 503 | 00705 504 | 00444 505 | 00557 506 | 01237 507 | 00993 508 | 01029 509 | 00362 510 | 01375 511 | 00867 512 | 00996 513 | 01417 514 | 01262 515 | 00458 516 | 00518 517 | 01014 518 | 01187 519 | 00938 520 | 00974 521 | 00574 522 | 00883 523 | 00206 524 | 00394 525 | 00331 526 | 01355 527 | 01110 528 | 00390 529 | 00840 530 | 01194 531 | 01180 532 | 00826 533 | 01128 534 | 00381 535 | 01031 536 | 00489 537 | 00674 538 | 00843 539 | 00436 540 | 01274 541 | 00386 542 | 01154 543 | 00334 544 | 01050 545 | 00404 546 | 00576 547 | 01068 548 | 00370 549 | 00326 550 | 01053 551 | 00335 552 | 00755 553 | 01301 554 | 01164 555 | 00684 556 | 01280 557 | 00541 558 | 01049 559 | 00644 560 | 01281 561 | 00803 562 | 00197 563 | 00943 564 | 00201 565 | 01131 566 | 00758 567 | 00523 568 | 00903 569 | 01120 570 | 00980 571 | 00428 572 | 01091 573 | 01201 574 | 00472 575 | 00608 576 | 00756 577 | 01193 578 | 00491 579 | 01165 580 | 01067 581 | 01206 582 | -------------------------------------------------------------------------------- /argument_relation_transformer/infer.py: -------------------------------------------------------------------------------- 1 | """Methods for evaluation and inference""" 2 | import argparse 3 | import datetime 4 | import glob 5 | import json 6 | import os 7 | 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | import torch 11 | from sklearn.metrics import precision_recall_fscore_support 12 | 13 | from argument_relation_transformer.system import ArgumentRelationClassificationSystem 14 | from argument_relation_transformer.utils import DOMAINS, find_best_ckpt 15 | 16 | 17 | def run_evaluation(output_path, label_path, exp_name, task): 18 | """Given the goldstandard data, calculate macro-F1 for relation, and binary F1 for link prediction. 19 | Also report how many links are predicted vs. how many exist in labels. 
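    Expected formats, following the parsing code below: each line of `label_path` is a
    document with `doc_id`, `text`, and `relations` (items with `head`, `tail`, `type`);
    each line of `output_path` has `doc_id` and `candidates` (items with `tail`, `head`,
    and the predicted label under `prediction`).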
20 | """ 21 | true_data = dict() # doc_id -> (tail, head) -> label 22 | pred_data = dict() # doc_id -> (tail, head) -> prediction 23 | y_true, y_pred = [], [] 24 | link_y_true, link_y_pred = [], [] 25 | acc = [] 26 | 27 | for ln in open(label_path): 28 | cur_obj = json.loads(ln) 29 | doc_id = cur_obj["doc_id"] 30 | if isinstance(doc_id, list): 31 | doc_id = "_".join(doc_id) 32 | 33 | true_data[doc_id] = dict() 34 | n_text = len(cur_obj["text"]) 35 | 36 | support_pairs, attack_pairs = [], [] 37 | head_set = set() 38 | for item in cur_obj["relations"]: 39 | head = item["head"] 40 | tail = item["tail"] 41 | rel_type = item["type"] 42 | 43 | if rel_type == "support": 44 | support_pairs.append((tail, head)) 45 | else: 46 | attack_pairs.append((tail, head)) 47 | head_set.add(head) 48 | 49 | # we assume the head are known, and iterate over all possible tails 50 | # to create the gold-standard set (true_data) 51 | for head_id in sorted(head_set): 52 | for i in range(n_text): 53 | if i == head_id: 54 | continue 55 | cur_pair = (i, head_id) 56 | if cur_pair in support_pairs: 57 | label = "support" 58 | elif cur_pair in attack_pairs: 59 | label = "attack" 60 | else: 61 | label = "no-rel" 62 | true_data[doc_id][cur_pair] = label 63 | 64 | for ln in open(output_path): 65 | cur_obj = json.loads(ln) 66 | doc_id = cur_obj["doc_id"] 67 | 68 | if isinstance(doc_id, list): 69 | doc_id = "_".join(doc_id) 70 | if doc_id not in pred_data: 71 | pred_data[doc_id] = dict() 72 | for pair in cur_obj["candidates"]: 73 | pair_idx = (pair["tail"], pair["head"]) 74 | pair_pred = pair["prediction"] 75 | pred_data[doc_id][pair_idx] = pair_pred 76 | 77 | pred_cnt, true_cnt = 0, 0 78 | support_cnt, attack_cnt, link_cnt = 0, 0, 0 79 | 80 | for doc_id, true_pairs in true_data.items(): 81 | pred_pairs = pred_data[doc_id] if doc_id in pred_data else {} 82 | for t in true_pairs: 83 | true_cnt += 1 84 | if t not in pred_pairs: 85 | pred = "no-rel" 86 | else: 87 | pred = pred_pairs[t] 88 | pred_cnt += 1 89 | y_true.append(true_pairs[t]) 90 | y_pred.append(pred) 91 | acc.append(1 if pred == true_pairs[t] else 0) 92 | 93 | if task == "binary": 94 | if pred == "link": 95 | link_cnt += 1 96 | support_cnt += 1 97 | else: 98 | if pred == "support": 99 | support_cnt += 1 100 | link_cnt += 1 101 | elif pred == "attack": 102 | attack_cnt += 1 103 | link_cnt += 1 104 | 105 | link_y_true.append(0 if true_pairs[t] == "no-rel" else 1) 106 | link_y_pred.append(0 if pred == "no-rel" else 1) 107 | 108 | print(f"{true_cnt} pairs found in label, {pred_cnt} predicted") 109 | macro_prec, macro_rec, macro_f1, _ = precision_recall_fscore_support( 110 | y_true, y_pred, average="macro", zero_division=0 111 | ) 112 | binary_prec, binary_rec, binary_f1, _ = precision_recall_fscore_support( 113 | link_y_true, link_y_pred, average="binary", zero_division=0 114 | ) 115 | 116 | print( 117 | f"Macro F1: {macro_f1}\tBinary F1: {binary_f1}\tPredicted: {pred_cnt}\tTotal: {true_cnt}" 118 | ) 119 | 120 | fout = open(output_path + ".scores", "w") 121 | results = { 122 | "macro_prec": macro_prec, 123 | "macro_rec": macro_rec, 124 | "macro_f1": macro_f1, 125 | "binary_prec": binary_prec, 126 | "binary_rec": binary_rec, 127 | "binary_f1": binary_f1, 128 | "pred_samples": pred_cnt, 129 | "labeled_samples": true_cnt, 130 | "accuracy": np.mean(acc), 131 | "support_ratio": support_cnt / true_cnt, 132 | "attack_ratio": attack_cnt / true_cnt, 133 | "link_ratio": link_cnt / true_cnt, 134 | } 135 | fout.write(json.dumps(results)) 136 | fout.close() 137 | 138 | 139 | def main(): 
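    # Inference flow (summary of the code below): parse args, locate the checkpoint under
    # `{ckptdir}/{exp-name}` (picking the best validation F1 when several exist), run
    # trainer.test() to write predictions to `outputs/{exp-name}.jsonl`, then score them
    # against the gold jsonl with run_evaluation().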
140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("--datadir", type=str, default="./data/") 142 | parser.add_argument("--ckptdir", type=str, default="./checkpoints/") 143 | parser.add_argument("--exp-name", type=str, required=True) 144 | parser.add_argument("--dataset", type=str, required=True, choices=DOMAINS) 145 | 146 | parser.add_argument( 147 | "--end-to-end", action="store_true", help="Whether to assume heads are given" 148 | ) 149 | parser.add_argument("--eval-set", type=str, choices=["val", "test"], default="test") 150 | 151 | parser.add_argument("--batch-size", type=int, default=32) 152 | parser.add_argument("--window-size", type=int, default=20) 153 | args = parser.parse_args() 154 | 155 | ckpt_path = os.path.join(args.ckptdir, args.exp_name) 156 | candidate_ckpt = glob.glob(f"{ckpt_path}/*.ckpt") 157 | assert len(candidate_ckpt) > 0, f"No checkpoint found under {ckpt_path}" 158 | if len(candidate_ckpt) > 1: 159 | best_ckpt, _ = find_best_ckpt(ckpt_path) 160 | else: 161 | best_ckpt = candidate_ckpt[0] 162 | print(f"Loading checkpoint from {best_ckpt}") 163 | 164 | eval_output_path = f"outputs/{args.exp_name}.jsonl" 165 | trainer = pl.Trainer.from_argparse_args(args, gpus=1) 166 | model = ArgumentRelationClassificationSystem.load_from_checkpoint( 167 | checkpoint_path=best_ckpt 168 | ) 169 | trainer.test(model) 170 | 171 | run_evaluation( 172 | output_path=eval_output_path, 173 | label_path=f"{args.datadir}/{args.dataset}_{args.eval_set}.jsonl", 174 | exp_name=args.exp_name, 175 | task=model.task, 176 | ) 177 | 178 | 179 | if __name__ == "__main__": 180 | main() 181 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2022 Bloomberg Finance L.P. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | https://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | -------------------------------------------------------------------------------- /argument_relation_transformer/dataset.py: -------------------------------------------------------------------------------- 1 | """Data loading and processing.""" 2 | import json 3 | import os 4 | import random 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | 10 | 11 | class ArgumentRelationDataset(Dataset): 12 | """Dataset class that creates context window and batchify the dataset.""" 13 | 14 | def __init__( 15 | self, 16 | dataset_name, 17 | datadir, 18 | set_type, 19 | tokenizer, 20 | end_to_end=False, 21 | window_size=20, 22 | seed=42, 23 | sampled_ids=None, 24 | ): 25 | """ 26 | Args: 27 | dataset_name (str): dataset domain, e.g., `ampere`, `ukp`. 28 | datadir (str): path to where the dataset jsonl files are stored 29 | set_type (str): one of `train`, `val`, `test` 30 | tokenizer (transformers.Tokenizer): the tokenizer used to map text to symbols 31 | end_to_end (bool): if set to True, head propositions are not given 32 | window_size (int): the number of propositions encoded to the left and right 33 | seed (int): random seed 34 | sampled_ids (list): list of sample ids that are to be included, use None if all should be included. This is not None only for active learning setting. 35 | 36 | """ 37 | super().__init__() 38 | self.dataset_name = dataset_name 39 | self.datadir = datadir 40 | self.set_type = set_type 41 | self.window_size = window_size 42 | self.end_to_end = end_to_end 43 | self.tokenizer = tokenizer 44 | self.seed = seed 45 | random.seed(seed) 46 | 47 | self.disc_token_id = tokenizer.cls_token_id 48 | 49 | self.ID = [] 50 | self.input_str = [] 51 | self.input_ids = [] 52 | self.labels = [] 53 | 54 | # statistics to print 55 | self.supp_cnt = 0 56 | self.att_cnt = 0 57 | self.no_rel_cnt = 0 58 | 59 | self.skipped_supp_label = 0 60 | self.skipped_att_label = 0 61 | 62 | self.sampled_ids = None 63 | if sampled_ids is not None: 64 | print("select subset by AL") 65 | self.sampled_ids = dict() 66 | for ln in sampled_ids: 67 | doc_head_id = tuple(ln[:-1]) 68 | prop_id = ln[-1] 69 | if doc_head_id not in self.sampled_ids: 70 | self.sampled_ids[doc_head_id] = [] 71 | self.sampled_ids[doc_head_id].append(prop_id) 72 | 73 | self._load_data() 74 | 75 | def _load_data(self): 76 | """Load data split and report statistics""" 77 | path = os.path.join(self.datadir, f"{self.dataset_name}_{self.set_type}.jsonl") 78 | print(path) 79 | for ln in open(path): 80 | cur_obj = json.loads(ln) 81 | self._include_bidir_context_with_window_size(cur_obj) 82 | 83 | print("=" * 50) 84 | print(f"label distribution for ({self.set_type}):") 85 | total_cnt = self.supp_cnt + self.att_cnt + self.no_rel_cnt 86 | print( 87 | f"Support: {self.supp_cnt} ({100 * self.supp_cnt / total_cnt:.2f}%) ({self.skipped_supp_label} skipped due to context limit)", 88 | end="\t", 89 | ) 90 | print( 91 | f"Attack {self.att_cnt} ({100 * self.att_cnt / total_cnt:.2f}%) ({self.skipped_att_label} skipped due to context limit)", 92 | end="\t", 93 | ) 94 | print(f"No-rel: {self.no_rel_cnt} ({100 * self.no_rel_cnt / total_cnt:.2f}%)") 95 | print(f"{len(self.ID)} sequence loaded.") 96 | print(f"=" * 50) 97 | 98 | def _include_bidir_context_with_window_size(self, doc_obj): 99 | """Create actual training sample in sequences, by extracting context from left and right. 
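        For each head proposition, two sequences are built: a `forward` one covering up to
        `window_size` propositions to its right, and a `backward` one covering up to
        `window_size` propositions to its left (see `_extract_right_context` and
        `_extract_left_context`). Every proposition in a sequence is prefixed with
        `disc_token_id` (the tokenizer's CLS token).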
100 | 101 | Results are stored in the following lists: 102 | - self.input_ids : store the token ids of the input sequence 103 | - self.input_str : store the raw string of the input sequence 104 | - self.labels : store the label id of the input sequence (0: no-rel, 1: support, 2: attack, -2: head itself) 105 | - self.ID 106 | """ 107 | 108 | doc_id = doc_obj["doc_id"] 109 | cur_sents = doc_obj["text"] 110 | cur_toks = [ 111 | self.tokenizer.encode(sent, add_special_tokens=False) for sent in cur_sents 112 | ] 113 | relation_list = doc_obj["relations"] 114 | head_to_tails = dict() # head_id -> list of tail ids 115 | for item in relation_list: 116 | head = item["head"] 117 | tail = item["tail"] 118 | rel_type = item["type"] 119 | 120 | if head not in head_to_tails: 121 | head_to_tails[head] = [] 122 | head_to_tails[head].append((tail, rel_type)) 123 | 124 | for head in range(len(cur_toks)): 125 | 126 | if head in head_to_tails: 127 | tail_list = head_to_tails[head] 128 | 129 | elif self.end_to_end: 130 | # do not assume heads are given, include all propositions as potential head 131 | tail_list = [] 132 | 133 | else: 134 | # assumes heads are given, therefore skip non-head cases 135 | continue 136 | 137 | self._extract_left_context(doc_id, cur_toks, cur_sents, head, tail_list) 138 | self._extract_right_context(doc_id, cur_toks, cur_sents, head, tail_list) 139 | return 140 | 141 | def _extract_right_context(self, doc_id, tokens, strs, head, tail_list): 142 | """Extract right context with window size, also truncate to at most 500 tokens (not counting disc_token)""" 143 | 144 | # right context is `forward` 145 | id_proposal = (doc_id, head, "forward") 146 | right_id = min(head + self.window_size, len(tokens) - 1) 147 | proposal_tokens = tokens[head : right_id + 1] 148 | proposal_str = strs[head : right_id + 1] 149 | labels = [0] * len(proposal_tokens) 150 | 151 | # use -2 to indicate head in `labels` 152 | labels[0] = -2 153 | 154 | for r_id, rel_type in tail_list: 155 | if r_id <= head: # left context 156 | continue 157 | 158 | # the offset for prop that is immediately to the right of head is 1, head itself would be 0 159 | positive_offset = r_id - head 160 | if ( 161 | positive_offset > len(labels) - 1 162 | ): # too far right, exceeding window size 163 | continue 164 | 165 | if rel_type == "support": 166 | labels[positive_offset] = 1 167 | 168 | elif rel_type == "attack": 169 | labels[positive_offset] = 2 170 | 171 | cur_lens = sum([len(x) for x in proposal_tokens]) 172 | while cur_lens > 500 - self.window_size: 173 | 174 | # shrink from the rightmost 175 | cur_lens -= len(proposal_tokens[-1]) 176 | proposal_tokens = proposal_tokens[:-1] 177 | proposal_str = proposal_str[:-1] 178 | if labels[-1] == 1: 179 | self.skipped_supp_label += 1 180 | elif labels[-1] == 2: 181 | self.skipped_att_label += 1 182 | labels = labels[:-1] 183 | 184 | if len(labels) <= 1: 185 | return 186 | 187 | input_ids = [] 188 | input_str = [] 189 | for _ix, toks in enumerate(proposal_tokens): 190 | input_ids.append(self.disc_token_id) 191 | input_ids.extend(toks) 192 | input_str.append(proposal_str[_ix]) 193 | 194 | skip = False 195 | if self.set_type == "train" and self.sampled_ids is not None: 196 | uncov_ids = ( 197 | self.sampled_ids[id_proposal] if id_proposal in self.sampled_ids else [] 198 | ) 199 | masked_labels = [] 200 | skip = True 201 | for i, l in enumerate(labels): 202 | if l == -2: 203 | masked_labels.append(l) 204 | elif i in uncov_ids: # this label is uncovered during active learning 205 | 
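# Worked example of the bidirectional windowing implemented in this class,
# for a hypothetical 4-proposition document (ids and texts are invented):
#
#   {"doc_id": "toy",
#    "text": ["p0", "p1", "p2", "p3"],
#    "relations": [{"head": 2, "tail": 0, "type": "support"},
#                  {"head": 2, "tail": 3, "type": "attack"}]}
#
# For head proposition 2 two training sequences are created:
#   backward (left context):  p0 p1 p2  ->  labels [1, 0, -2]
#   forward  (right context): p2 p3     ->  labels [-2, 2]
# with 1 = supports the head, 2 = attacks the head, 0 = no relation, and
# -2 marking the head itself. Each proposition is prefixed with the
# tokenizer's CLS token (disc_token_id) so that one vector per proposition
# can later be pooled for classification.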
masked_labels.append(l) 206 | skip = False 207 | else: # this label is not uncovered, so hidden as -1 (padding) 208 | masked_labels.append(-1) 209 | labels = masked_labels 210 | 211 | if not skip: 212 | self.input_ids.append(input_ids) 213 | self.input_str.append(input_str) 214 | self.labels.append(labels) 215 | self.ID.append(id_proposal) 216 | 217 | # record statistics 218 | cur_supp = labels.count(1) 219 | cur_att = labels.count(2) 220 | self.supp_cnt += cur_supp 221 | self.att_cnt += cur_att 222 | self.no_rel_cnt += labels.count(0) 223 | 224 | def _extract_left_context(self, doc_id, tokens, strs, head, tail_list): 225 | """Extract left context with window size, also truncate to at most 500 tokens (not counting disc_token)""" 226 | 227 | # left context is `backward` 228 | id_proposal = (doc_id, head, "backward") 229 | left_id = max(0, head - self.window_size) 230 | proposal_tokens = tokens[left_id : head + 1] 231 | proposal_str = strs[left_id : head + 1] 232 | 233 | labels = [0] * len(proposal_tokens) 234 | for r_id, rel_type in tail_list: 235 | if r_id >= head: # belongs to the right context 236 | continue 237 | 238 | # the offset for prop that is immediately to the left of head is -2, head itself would be -1 239 | negative_offset = r_id - head - 1 240 | if negative_offset < -1 * len( 241 | labels 242 | ): # too far left, exceeding window size 243 | continue 244 | 245 | if rel_type == "support": 246 | labels[negative_offset] = 1 247 | 248 | elif rel_type == "attack": 249 | labels[negative_offset] = 2 250 | 251 | # make sure the entire sequence is within 500 token range 252 | cur_lens = sum([len(x) for x in proposal_tokens]) 253 | while cur_lens > 500 - self.window_size: 254 | # shrink from the leftmost 255 | cur_lens -= len(proposal_tokens[0]) 256 | proposal_tokens = proposal_tokens[1:] 257 | proposal_str = proposal_str[1:] 258 | if labels[0] == 1: 259 | self.skipped_supp_label += 1 260 | elif labels[0] == 2: 261 | self.skipped_att_label += 1 262 | labels = labels[1:] 263 | 264 | # only one proposition is left, likely because some proposition in this dataset is too long 265 | if len(labels) <= 1: 266 | return 267 | 268 | # if prop3 is head, encode as: 269 | # prop1 prop2 prop3 270 | # use -2 to indicate the head proposition in `labels` 271 | labels = labels[:-1] + [-2] 272 | input_str = [] 273 | input_ids = [] 274 | for _ix, toks in enumerate(proposal_tokens): 275 | input_ids.append(self.disc_token_id) 276 | input_ids.extend(toks) 277 | input_str.append(proposal_str[_ix]) 278 | 279 | skip = False 280 | if self.set_type == "train" and self.sampled_ids is not None: 281 | uncov_ids = ( 282 | self.sampled_ids[id_proposal] if id_proposal in self.sampled_ids else [] 283 | ) 284 | masked_labels = [] 285 | skip = True 286 | for i, l in enumerate(labels): 287 | if l == -2: 288 | masked_labels.append(l) 289 | elif i in uncov_ids: # this label is uncovered during active learning 290 | masked_labels.append(l) 291 | skip = False 292 | else: # this label is not uncovered, so hidden as -1 (padding) 293 | masked_labels.append(-1) 294 | labels = masked_labels 295 | 296 | if not skip: 297 | self.input_ids.append(input_ids) 298 | self.input_str.append(input_str) 299 | self.labels.append(labels) 300 | self.ID.append(id_proposal) 301 | cur_supp = labels.count(1) 302 | cur_att = labels.count(2) 303 | self.supp_cnt += cur_supp 304 | self.att_cnt += cur_att 305 | self.no_rel_cnt += labels.count(0) 306 | 307 | def __len__(self): 308 | return len(self.ID) 309 | 310 | def __getitem__(self, index): 311 | result = 
{ 312 | "id": self.ID[index], 313 | "input_ids": self.input_ids[index], 314 | "labels": self.labels[index], 315 | "input_str": self.input_str[index], 316 | } 317 | return result 318 | 319 | def collater(self, samples): 320 | """Consolidate list of samples into a batch""" 321 | batch = dict() 322 | batch["ids"] = [s["id"] for s in samples] 323 | batch["input_str"] = [s["input_str"] for s in samples] 324 | 325 | batch_size = len(samples) 326 | max_input_len = max([len(s["input_ids"]) for s in samples]) 327 | input_ids = np.full( 328 | shape=[batch_size, max_input_len], 329 | fill_value=self.tokenizer.pad_token_id, 330 | dtype=np.int, 331 | ) 332 | 333 | # 1 - [disc], 0 - otherwise 334 | disc_token_mask = np.full( 335 | shape=[batch_size, max_input_len], fill_value=0, dtype=np.int 336 | ) 337 | 338 | max_sent_num = max([len(s["labels"]) for s in samples]) 339 | sentence_boundary = np.full( 340 | shape=[batch_size, max_sent_num], fill_value=0, dtype=np.int 341 | ) 342 | sentence_boundary_mask = np.zeros( 343 | shape=[batch_size, max_sent_num], dtype=np.int 344 | ) 345 | 346 | # 0 - no-rel; 1 - support; 2 - attack; -1 - (pad, self) 347 | labels = np.full(shape=[batch_size, max_sent_num], fill_value=-1, dtype=np.int) 348 | target_ids = torch.zeros([batch_size, 1], dtype=torch.long) 349 | 350 | for ix, s in enumerate(samples): 351 | cur_input_ids = s["input_ids"] 352 | input_ids[ix][: len(cur_input_ids)] = cur_input_ids 353 | 354 | cur_labels = [x if x != -2 else -1 for x in s["labels"]] 355 | 356 | disc_token_pos = [] 357 | cur_disc = cur_input_ids.index(self.disc_token_id) 358 | while cur_disc < len(cur_input_ids) - 1: 359 | disc_token_pos.append(cur_disc) 360 | disc_token_mask[ix, cur_disc] = 1 361 | try: 362 | cur_disc = cur_input_ids.index(self.disc_token_id, cur_disc + 1) 363 | except ValueError: 364 | break 365 | 366 | assert len(disc_token_pos) == len(cur_labels) 367 | 368 | sentence_boundary[ix][: len(disc_token_pos)] = disc_token_pos 369 | sentence_boundary_mask[ix][: len(disc_token_pos)] = 1 370 | 371 | labels[ix][: len(cur_labels)] = cur_labels 372 | target_ids[ix][0] = s["labels"].index(-2) 373 | 374 | batch["input_ids"] = torch.LongTensor(input_ids) 375 | assert ( 376 | batch["input_ids"].shape[1] < 512 377 | ), f'size too large! 
{batch["input_ids"].shape}' 378 | batch["labels"] = torch.LongTensor( 379 | labels 380 | ) # 0 -> no-real, -1 -> attack, 1 -> support 381 | batch["sequence_boundary_ids"] = torch.LongTensor(sentence_boundary) 382 | batch["sequence_boundary_mask"] = torch.LongTensor(sentence_boundary_mask) 383 | batch["attention_mask"] = ( 384 | batch["input_ids"] != self.tokenizer.pad_token_id 385 | ).long() 386 | batch["target_ids"] = target_ids 387 | batch["disc_token_mask"] = torch.LongTensor(disc_token_mask) 388 | return batch 389 | -------------------------------------------------------------------------------- /argument_relation_transformer/system.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import pytorch_lightning as pl 6 | from sklearn.metrics import precision_recall_fscore_support 7 | from torch.utils.data import DataLoader 8 | from transformers import ( 9 | AdamW, 10 | RobertaTokenizer, 11 | get_constant_schedule_with_warmup, 12 | get_linear_schedule_with_warmup, 13 | ) 14 | 15 | from argument_relation_transformer.dataset import ArgumentRelationDataset 16 | from argument_relation_transformer.modeling import RobertaForDocumentSpanClassification 17 | 18 | 19 | class ArgumentRelationClassificationSystem(pl.LightningModule): 20 | def __init__(self, hparams, al_selected_data_path=None): 21 | super().__init__() 22 | 23 | if isinstance(hparams, dict): 24 | hparams = argparse.Namespace(**hparams) 25 | self.save_hyperparameters(hparams) 26 | 27 | self.datadir = hparams.datadir 28 | self.dataset_name = hparams.dataset 29 | self.batch_size = hparams.batch_size 30 | self.window_size = hparams.window_size 31 | self.adam_epsilon = hparams.adam_epsilon 32 | self.learning_rate = hparams.learning_rate 33 | self.scheduler_type = hparams.scheduler_type 34 | self.warmup_steps = hparams.warmup_steps 35 | self.exp_name = hparams.exp_name 36 | self.seed = hparams.seed 37 | self.task = hparams.task 38 | self.end_to_end = hparams.end_to_end 39 | 40 | self.sampled_ids = None 41 | if al_selected_data_path is not None: 42 | self.sampled_ids = [] 43 | for ln in open(al_selected_data_path): 44 | _id, _label = json.loads(ln) 45 | self.sampled_ids.append(_id) 46 | 47 | if hparams.huggingface_path is None: 48 | model_name_or_path = "roberta-base" 49 | else: 50 | model_name_or_path = os.path.join(hparams.huggingface_path, "roberta-base") 51 | 52 | self.tokenizer = RobertaTokenizer.from_pretrained(model_name_or_path) 53 | self.model = RobertaForDocumentSpanClassification.from_pretrained( 54 | model_name_or_path, 55 | num_labels=3 if self.task == "ternary" else 2, 56 | config=model_name_or_path, 57 | ) 58 | 59 | def training_step(self, batch, batch_idx): 60 | logits, loss = self.model(**batch) 61 | pred = logits.argmax(-1) 62 | labels = batch["labels"] 63 | 64 | accuracy = (pred == labels)[labels != -1].float().mean() 65 | pred_masked = pred[labels > -1].tolist() 66 | if len(pred_masked) > 0: 67 | 68 | if self.task == "binary": 69 | self.log( 70 | "train_step_pos_ratio", 71 | sum(pred_masked) / len(pred_masked), 72 | on_step=True, 73 | prog_bar=False, 74 | logger=True, 75 | ) 76 | else: 77 | supp_cnt = pred_masked.count(1) 78 | att_cnt = pred_masked.count(2) 79 | supp_ratio = supp_cnt / len(pred_masked) 80 | att_ratio = att_cnt / len(pred_masked) 81 | self.log( 82 | "train_step_support_ratio", 83 | supp_ratio, 84 | on_step=True, 85 | prog_bar=False, 86 | logger=True, 87 | ) 88 | self.log( 89 | "train_step_attack_ratio", 90 | 
att_ratio, 91 | on_step=True, 92 | prog_bar=False, 93 | logger=True, 94 | ) 95 | 96 | self.log("train_loss", loss, on_step=True, prog_bar=False, logger=True) 97 | self.log("train_acc", accuracy, on_step=True, prog_bar=False, logger=True) 98 | for i, param in enumerate(self.opt.param_groups): 99 | self.log( 100 | f"lr_group_{i}", param["lr"], on_step=True, prog_bar=False, logger=True 101 | ) 102 | return {"loss": loss, "pred": pred, "labels": labels} 103 | 104 | def test_step(self, batch, batch_idx): 105 | logits, loss = self.model(**batch) 106 | 107 | pred = logits.argmax(-1) 108 | labels = batch["labels"] 109 | accuracy = (pred == labels)[labels != -1].float().mean() 110 | pred_unmasked = pred[labels != -1].tolist() 111 | 112 | if len(pred_unmasked) > 0: 113 | if self.task == "binary": 114 | self.log( 115 | "test_pos_ratio", 116 | sum(pred_unmasked) / len(pred_unmasked), 117 | on_step=False, 118 | prog_bar=False, 119 | logger=True, 120 | ) 121 | else: 122 | supp_cnt = pred_unmasked.count(1) 123 | att_cnt = pred_unmasked.count(2) 124 | supp_ratio = supp_cnt / len(pred_unmasked) 125 | att_ratio = att_cnt / len(pred_unmasked) 126 | self.log( 127 | "test_support_ratio", 128 | supp_ratio, 129 | on_step=False, 130 | prog_bar=False, 131 | logger=True, 132 | ) 133 | self.log( 134 | "test_attack_ratio", 135 | att_ratio, 136 | on_step=False, 137 | prog_bar=False, 138 | logger=True, 139 | ) 140 | 141 | self.log("test_loss", loss, on_step=False, prog_bar=False, logger=True) 142 | self.log("test_acc", accuracy, on_step=False, prog_bar=False, logger=True) 143 | 144 | # recover the original predictions for more accurate evaluation 145 | pred_results = dict() # (src, tgt) -> [pred, label] 146 | for ids, p, l, i_str in zip(batch["ids"], pred, labels, batch["input_str"]): 147 | doc_id, head_prop_id, rel_dir = ids 148 | cur_samples = len(i_str) - 1 149 | if rel_dir == "backward": 150 | effective_l = l[:cur_samples].tolist() 151 | effective_p = p[:cur_samples].tolist() 152 | for tail_i in range(cur_samples): 153 | tail_real_idx = head_prop_id - cur_samples + tail_i 154 | pred_results[(doc_id, tail_real_idx, head_prop_id)] = ( 155 | effective_p[tail_i], 156 | effective_l[tail_i], 157 | ) 158 | else: 159 | effective_l = l[1 : cur_samples + 1].tolist() 160 | effective_p = p[1 : cur_samples + 1].tolist() 161 | for tail_i in range(cur_samples): 162 | tail_real_idx = head_prop_id + tail_i + 1 163 | pred_results[(doc_id, tail_real_idx, head_prop_id)] = ( 164 | effective_p[tail_i], 165 | effective_l[tail_i], 166 | ) 167 | 168 | return { 169 | "loss": loss, 170 | "acc": accuracy, 171 | "pred": pred, 172 | "labels": labels, 173 | "results": pred_results, 174 | } 175 | 176 | def validation_step(self, batch, batch_idx): 177 | logits, loss = self.model(**batch) 178 | 179 | pred = logits.argmax(-1) 180 | labels = batch["labels"] 181 | accuracy = (pred == labels)[labels != -1].float().mean() 182 | pred_unmasked = pred[labels != -1].tolist() 183 | 184 | if len(pred_unmasked) > 0: 185 | if self.task == "binary": 186 | self.log( 187 | "val_pos_ratio", 188 | sum(pred_unmasked) / len(pred_unmasked), 189 | on_step=False, 190 | prog_bar=False, 191 | logger=True, 192 | ) 193 | else: 194 | supp_cnt = pred_unmasked.count(1) 195 | att_cnt = pred_unmasked.count(2) 196 | supp_ratio = supp_cnt / len(pred_unmasked) 197 | att_ratio = att_cnt / len(pred_unmasked) 198 | self.log( 199 | "val_support_ratio", 200 | supp_ratio, 201 | on_step=False, 202 | prog_bar=False, 203 | logger=True, 204 | ) 205 | self.log( 206 | "val_attack_ratio", 207 | 
att_ratio, 208 | on_step=False, 209 | prog_bar=False, 210 | logger=True, 211 | ) 212 | 213 | self.log("val_loss", loss, on_step=False, prog_bar=False, logger=True) 214 | self.log("val_acc", accuracy, on_step=False, prog_bar=False, logger=True) 215 | 216 | return {"loss": loss, "acc": accuracy, "pred": pred, "labels": labels} 217 | 218 | def validation_epoch_end(self, validation_step_outputs): 219 | y_true, y_pred = [], [] 220 | for out in validation_step_outputs: 221 | for p, l in zip(out["pred"], out["labels"]): 222 | p = p[l > -1] 223 | l = l[l > -1] 224 | y_pred.extend(p.tolist()) 225 | y_true.extend(l.tolist()) 226 | 227 | if self.task == "binary": 228 | prec, rec, f1, _ = precision_recall_fscore_support( 229 | y_true, y_pred, average="binary" 230 | ) 231 | self.log("val_link_f1", f1, on_epoch=True, logger=True) 232 | self.log("val_link_prec", prec, on_epoch=True, logger=True) 233 | self.log("val_link_rec", rec, on_epoch=True, logger=True) 234 | else: 235 | prec, rec, f1, _ = precision_recall_fscore_support( 236 | y_true, y_pred, average="macro" 237 | ) 238 | self.log("val_macro_f1", f1, on_epoch=True, logger=True) 239 | self.log("val_macro_prec", prec, on_epoch=True, logger=True) 240 | self.log("val_macro_rec", rec, on_epoch=True, logger=True) 241 | 242 | def test_epoch_end(self, test_step_outputs): 243 | LABEL_NAMES = ["no-rel", "support", "attack"] 244 | y_true, y_pred = [], [] 245 | total_results = dict() # doc -> [tail, head] -> (pred, label) 246 | for out in test_step_outputs: 247 | for p, l in zip(out["pred"], out["labels"]): 248 | p = p[l > -1] 249 | l = l[l > -1] 250 | y_pred.extend(p.tolist()) 251 | y_true.extend(l.tolist()) 252 | 253 | for k, v in out["results"].items(): 254 | doc_id, tail, head = k 255 | if doc_id not in total_results: 256 | total_results[doc_id] = dict() 257 | total_results[doc_id][(tail, head)] = ( 258 | LABEL_NAMES[v[0]], 259 | LABEL_NAMES[v[1]], 260 | ) 261 | 262 | # log results to disk 263 | output_path = f"outputs/{self.exp_name}.jsonl" 264 | if not os.path.exists("outputs/"): 265 | os.makedirs("./outputs/") 266 | fout = open(output_path, "w") 267 | for doc, pairs in total_results.items(): 268 | _pairs = [ 269 | {"tail": tail, "head": head, "prediction": p, "label": l} 270 | for ((tail, head), (p, l)) in pairs.items() 271 | ] 272 | fout.write(json.dumps({"doc_id": doc, "candidates": _pairs}) + "\n") 273 | fout.close() 274 | if self.task == "binary": 275 | prec, rec, f1, _ = precision_recall_fscore_support( 276 | y_true, y_pred, average="binary" 277 | ) 278 | self.log("test_link_f1", f1, on_epoch=True, logger=True) 279 | self.log("test_link_prec", prec, on_epoch=True, logger=True) 280 | self.log("test_link_rec", rec, on_epoch=True, logger=True) 281 | else: 282 | prec, rec, f1, _ = precision_recall_fscore_support( 283 | y_true, y_pred, average="macro" 284 | ) 285 | self.log("test_macro_f1", f1, on_epoch=True, logger=True) 286 | self.log("test_macro_prec", prec, on_epoch=True, logger=True) 287 | self.log("test_macro_rec", rec, on_epoch=True, logger=True) 288 | 289 | def training_epoch_end(self, outputs) -> None: 290 | y_true, y_pred = [], [] 291 | for out in outputs: 292 | for p, l in zip(out["pred"], out["labels"]): 293 | p = p[l > -1] 294 | l = l[l > -1] 295 | y_pred.extend(p.tolist()) 296 | y_true.extend(l.tolist()) 297 | 298 | if self.task == "binary": 299 | prec, rec, f1, _ = precision_recall_fscore_support( 300 | y_true, y_pred, average="binary" 301 | ) 302 | self.log("train_link_f1", f1, on_epoch=True, logger=True) 303 | 
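# Worked example (with hypothetical numbers) of the index recovery done in
# test_step above. Suppose head_prop_id = 5 and the window contains 3 context
# propositions, i.e. cur_samples = len(input_str) - 1 = 3:
#   backward window: positions 0..2 hold the left context, so the real tail
#       indices are 5 - 3 + 0, 5 - 3 + 1, 5 - 3 + 2  ->  2, 3, 4
#   forward window:  position 0 is the head itself and positions 1..3 hold
#       the right context, so the real tail indices are 5 + 1, 5 + 2, 5 + 3
#       ->  6, 7, 8
# Each recovered (doc_id, tail, head) triple is stored with its predicted and
# gold label, and test_epoch_end writes them to outputs/<exp_name>.jsonl.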
self.log("train_link_prec", prec, on_epoch=True, logger=True) 304 | self.log("train_link_rec", rec, on_epoch=True, logger=True) 305 | self.log( 306 | "train_link_pos_ratio", 307 | sum(y_pred) / len(y_pred), 308 | on_epoch=True, 309 | logger=True, 310 | ) 311 | else: 312 | prec, rec, f1, _ = precision_recall_fscore_support( 313 | y_true, y_pred, average="macro" 314 | ) 315 | self.log("train_macro_f1", f1, on_epoch=True, logger=True) 316 | self.log("train_macro_prec", prec, on_epoch=True, logger=True) 317 | self.log("train_macro_rec", rec, on_epoch=True, logger=True) 318 | supp_ratio = y_pred.count(1) / len(y_pred) 319 | att_ratio = y_pred.count(2) / len(y_pred) 320 | self.log("train_support_ratio", supp_ratio, on_epoch=True, logger=True) 321 | self.log("train_attack_ratio", att_ratio, on_epoch=True, logger=True) 322 | self.log( 323 | "train_link_ratio", supp_ratio + att_ratio, on_epoch=True, logger=True 324 | ) 325 | 326 | def get_dataloader(self, set_type, shuffle): 327 | dataset = ArgumentRelationDataset( 328 | dataset_name=self.dataset_name, 329 | datadir=self.datadir, 330 | set_type=set_type, 331 | tokenizer=self.tokenizer, 332 | end_to_end=self.end_to_end, 333 | window_size=self.window_size, 334 | seed=self.seed, 335 | sampled_ids=self.sampled_ids if set_type == "train" else None, 336 | ) 337 | dataloader = DataLoader( 338 | dataset, 339 | batch_size=self.batch_size, 340 | collate_fn=dataset.collater, 341 | shuffle=shuffle, 342 | num_workers=0, 343 | ) 344 | return dataloader 345 | 346 | def train_dataloader(self): 347 | return self.train_loader 348 | 349 | def val_dataloader(self): 350 | return self.get_dataloader(set_type="val", shuffle=False) 351 | 352 | def test_dataloader(self, test_set="test", use_pipeline=False): 353 | return self.get_dataloader(set_type=test_set, shuffle=False) 354 | 355 | def total_steps(self): 356 | return (self.dataset_size / self.hparams.batch_size) * self.hparams.max_epochs 357 | 358 | def setup(self, stage): 359 | self.train_loader = self.get_dataloader("train", shuffle=True) 360 | self.dataset_size = len(self.train_loader.dataset) 361 | 362 | def get_lr_scheduler(self): 363 | if self.scheduler_type == "linear": 364 | scheduler = get_linear_schedule_with_warmup( 365 | self.opt, 366 | num_warmup_steps=self.warmup_steps, 367 | num_training_steps=self.total_steps(), 368 | ) 369 | else: 370 | scheduler = get_constant_schedule_with_warmup( 371 | self.opt, 372 | num_warmup_steps=self.warmup_steps, 373 | ) 374 | scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} 375 | return scheduler 376 | 377 | def configure_optimizers(self): 378 | model = self.model 379 | optimizer_grouped_parameters = [ 380 | {"params": [p for n, p in model.named_parameters()], "weight_decay": 0.0} 381 | ] 382 | print( 383 | f'{len(optimizer_grouped_parameters[0]["params"])} parameters will be trained' 384 | ) 385 | 386 | optimizer = AdamW( 387 | optimizer_grouped_parameters, 388 | lr=self.hparams.learning_rate, 389 | eps=self.hparams.adam_epsilon, 390 | ) 391 | self.opt = optimizer 392 | 393 | scheduler = self.get_lr_scheduler() 394 | return [optimizer], [scheduler] 395 | -------------------------------------------------------------------------------- /argument_relation_transformer/active.py: -------------------------------------------------------------------------------- 1 | """Active Learning routines that select next batch of unlabeled data for annotation.""" 2 | import argparse 3 | import glob 4 | import json 5 | import os 6 | import random 7 | from collections 
import Counter 8 | 9 | import numpy as np 10 | import torch 11 | from torch.distributions.categorical import Categorical 12 | from torch.utils.data import DataLoader 13 | from tqdm import tqdm 14 | from transformers import RobertaTokenizer 15 | 16 | from argument_relation_transformer.coreset import CoresetSampler 17 | from argument_relation_transformer.dataset import ArgumentRelationDataset 18 | from argument_relation_transformer.system import ArgumentRelationClassificationSystem 19 | from argument_relation_transformer.utils import AL_METHODS, move_to_cuda, set_seeds 20 | 21 | 22 | class AL: 23 | """Active Learning class that implements various acqusition methods.""" 24 | 25 | def __init__(self, args, existing_samples): 26 | self.dataset = args.dataset 27 | self.datadir = args.datadir 28 | self.method = args.method 29 | if self.method == "vocab": 30 | self.vocab_path = args.vocab_path 31 | 32 | self.interval = args.interval 33 | self.existing_samples = existing_samples 34 | self.batch_size = args.batch_size 35 | self.seed = args.seed 36 | self.huggingface_path = args.huggingface_path 37 | 38 | def acquire(self, model): 39 | if self.huggingface_path is not None: 40 | tokenizer_name_or_path = self.huggingface_path + "roberta-base/" 41 | else: 42 | tokenizer_name_or_path = "roberta-base" 43 | tokenizer = RobertaTokenizer.from_pretrained(tokenizer_name_or_path) 44 | dataset = ArgumentRelationDataset( 45 | dataset_name=self.dataset, 46 | datadir=self.datadir, 47 | set_type="train", 48 | window_size=20, 49 | tokenizer=tokenizer, 50 | seed=self.seed, 51 | ) 52 | dataloader = DataLoader( 53 | dataset=dataset, 54 | batch_size=self.batch_size, 55 | collate_fn=dataset.collater, 56 | shuffle=False, 57 | num_workers=0, 58 | ) 59 | 60 | print( 61 | f"------------------ SELECTION METHOD: {self.method.upper()} ---------------------" 62 | ) 63 | if self.method == "disc": 64 | return self.acquire_discourse_marker(dataloader) 65 | elif self.method == "no-disc": 66 | return self.acquire_discourse_marker(dataloader, True) 67 | elif self.method == "distance": 68 | return self.acquire_distance(dataloader) 69 | elif self.method == "vocab": 70 | return self.acquire_vocab(dataloader) 71 | elif self.method == "random" or model is None: 72 | return self.acquire_random(dataloader) 73 | elif self.method == "max-entropy": 74 | return self.acquire_entropy(dataloader, model) 75 | elif self.method == "bald": 76 | return self.acquire_bald(dataloader, model) 77 | elif self.method == "coreset": 78 | return self.acquire_coreset(dataloader, model) 79 | 80 | def acquire_vocab(self, dataloader): 81 | """Sample propositions based on the novelty of vocabulary. 
82 | 83 | Score(prop) = \sum_w {freq} 84 | """ 85 | from nltk.tokenize import word_tokenize 86 | 87 | vocab_idf = dict() 88 | for ln in open(self.vocab_path): 89 | word, df = ln.strip().split("\t") 90 | vocab_idf[word] = 1 / int(df) 91 | 92 | def _score_proposition(input_str): 93 | score = 0 94 | for w in word_tokenize(input_str): 95 | w = w.lower() 96 | if str.isalpha(w): 97 | w_idf = vocab_idf[w] if w in vocab_idf else 0.0 98 | w_score = 1 / (1 + w_idf) 99 | score += w_score 100 | return score 101 | 102 | id2score = dict() 103 | for (i, batch) in tqdm( 104 | enumerate(dataloader), desc="sampling using vocabulary novelty" 105 | ): 106 | for (j, (_id, _labels)) in enumerate(zip(batch["ids"], batch["labels"])): 107 | _labels = _labels[_labels > -1] 108 | for k in range(len(_labels)): 109 | new_id = list(_id) 110 | new_id.append(k) 111 | new_id = (tuple(new_id), _labels[k].item()) 112 | 113 | if new_id in self.existing_samples: 114 | continue 115 | if "forward" in new_id[0]: 116 | cur_input_str = batch["input_str"][j][k + 1] 117 | else: 118 | cur_input_str = batch["input_str"][j][k] 119 | 120 | cur_score = _score_proposition(cur_input_str) 121 | id2score[new_id] = cur_score 122 | selected = [] 123 | for item in sorted(id2score.items(), key=lambda x: x[1], reverse=True): 124 | selected.append(item[0]) 125 | if len(selected) == self.interval: 126 | break 127 | self.existing_samples.extend(selected) 128 | all_labels = [item[1] for item in self.existing_samples] 129 | label_dist = Counter(all_labels) 130 | 131 | return self.existing_samples, label_dist 132 | 133 | def acquire_distance(self, dataloader): 134 | """Sample propositions based on their distance to the target, using a Possion distribution.""" 135 | bins = list(range(1, 21)) 136 | current_round_samples = [[] for _ in bins] 137 | 138 | for (i, batch) in tqdm(enumerate(dataloader), desc="sampling using distance"): 139 | for (j, (_id, _labels)) in enumerate(zip(batch["ids"], batch["labels"])): 140 | _labels = _labels[_labels > -1] 141 | for k in range(len(_labels)): 142 | new_id = list(_id) 143 | new_id.append(k) 144 | new_id = (tuple(new_id), _labels[k].item()) 145 | 146 | if new_id in self.existing_samples: 147 | continue 148 | 149 | if "backward" in new_id[0]: 150 | cur_distance = len(_labels) - k 151 | else: 152 | cur_distance = k + 1 153 | cur_bin = cur_distance - 1 154 | current_round_samples[cur_bin].append(new_id) 155 | possion_samples = np.random.poisson(lam=4, size=self.interval) 156 | bin_dist = Counter(possion_samples) 157 | extra = 0 158 | 159 | selected = [] 160 | for ix in range(len(bins)): 161 | count = bin_dist[ix] 162 | cur_pool = current_round_samples[ix] 163 | if count > len(cur_pool): 164 | cur_sample = cur_pool 165 | extra += count - len(cur_pool) 166 | else: 167 | cur_sample = random.sample(cur_pool, k=count) 168 | current_round_samples[ix] = [_i for _i in cur_pool if _i not in cur_sample] 169 | for item in cur_sample: 170 | selected.append(item) 171 | 172 | last_extra = 0 173 | if extra > 0: 174 | print( 175 | f"{extra} extra need to be sampled by lambda=2, step={len(self.existing_samples)}" 176 | ) 177 | extra_poisson_samples = np.random.poisson(lam=2, size=extra) 178 | extra_bin_dist = Counter(extra_poisson_samples) 179 | for ix in range(len(bins)): 180 | count = extra_bin_dist[ix] 181 | if count >= len(current_round_samples[ix]): 182 | cur_sample = current_round_samples[ix] 183 | last_extra += count - len(current_round_samples) 184 | else: 185 | cur_sample = random.sample(current_round_samples[ix], k=count) 186 | 
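# Sketch of how the Poisson draw above becomes per-distance quotas (assuming
# lam=4 and the 20 distance bins used by acquire_distance; counts are only
# indicative):
#
#   import numpy as np
#   from collections import Counter
#   draws = np.random.poisson(lam=4, size=500)   # one draw per sample to select
#   bin_dist = Counter(draws)                    # e.g. {3: 98, 4: 95, 2: 73, ...}
#
# Bin index ix corresponds to a tail-to-head distance of ix + 1, so most of
# the budget falls on distances around 4 to 5. When a bin holds fewer
# unlabeled candidates than its quota, the leftover budget is redrawn with
# lam=2 and, if still unfilled, sampled uniformly from whatever remains.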
187 | cur_pool = current_round_samples[ix] 188 | current_round_samples[ix] = [ 189 | _i for _i in cur_pool if _i not in cur_sample 190 | ] 191 | 192 | for item in cur_sample: 193 | selected.append(item) 194 | if last_extra > 0: 195 | 196 | total = [] 197 | for b in current_round_samples: 198 | total.extend(b) 199 | print( 200 | f"{last_extra} needs to be sampled from the remaining, using random, sampling from the modified universe of {len(total)} samples" 201 | ) 202 | cur_sample = random.sample(total, k=last_extra) 203 | selected.extend(cur_sample) 204 | 205 | self.existing_samples.extend(selected) 206 | all_labels = [item[1] for item in self.existing_samples] 207 | label_dist = Counter(all_labels) 208 | return self.existing_samples, label_dist 209 | 210 | def acquire_random(self, dataloader): 211 | 212 | unsampled = [] 213 | for i, batch in tqdm(enumerate(dataloader)): 214 | for j, (_id, _label) in enumerate(zip(batch["ids"], batch["labels"])): 215 | 216 | # _id: (doc_id, tgt_id, direction) 217 | # _label: [ 0, 0, 0, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1] 218 | _label = _label[_label > -1] 219 | 220 | for k in range(len(_label)): 221 | cur_label = _label[k].item() 222 | new_id = list(_id) 223 | new_id.append(k) 224 | new_id = (tuple(new_id), cur_label) 225 | if new_id in self.existing_samples: 226 | continue 227 | 228 | unsampled.append(new_id) 229 | 230 | if len(unsampled) <= self.interval: 231 | selected_tuples = unsampled 232 | else: 233 | selected_tuples = random.sample(unsampled, k=self.interval) 234 | 235 | self.existing_samples.extend(selected_tuples) 236 | all_labels = [item[1] for item in self.existing_samples] 237 | label_dist = Counter(all_labels) 238 | 239 | return self.existing_samples, label_dist 240 | 241 | def acquire_entropy(self, dataloader, model): 242 | scores = dict() # id -> score 243 | label_dist = Counter() 244 | for i, batch in tqdm(enumerate(dataloader)): 245 | labels = batch["labels"] 246 | batch_ids = batch.pop("ids") 247 | batch.pop("input_str") 248 | batch.pop("disc_token_mask") 249 | batch = move_to_cuda(batch) 250 | logits, loss = model.model(**batch) 251 | categorical = Categorical(logits=logits) 252 | entropies = categorical.entropy() 253 | 254 | for j, (_id, _label) in enumerate(zip(batch_ids, batch["labels"])): 255 | sample_entropy = entropies[j] 256 | sample_entropy = sample_entropy[_label > -1] 257 | _label = _label[_label > -1] 258 | 259 | for (k, ent) in enumerate(sample_entropy): 260 | new_id = list(_id) 261 | new_id.append(k) 262 | cur_label = _label[k].item() 263 | new_id = (tuple(new_id), cur_label) 264 | 265 | if new_id in self.existing_samples: 266 | continue 267 | 268 | scores[new_id] = ent.item() 269 | 270 | sorted_scores = sorted(scores.items(), key=lambda x: x[1], reverse=True) 271 | unsampled = [] 272 | for item in sorted_scores[: self.interval]: 273 | unsampled.append(item[0]) 274 | 275 | self.existing_samples.extend(unsampled) 276 | all_labels = [item[1] for item in self.existing_samples] 277 | label_dist = Counter(all_labels) 278 | 279 | return self.existing_samples, label_dist 280 | 281 | @torch.no_grad() 282 | def acquire_coreset(self, dataloader, model): 283 | """ 284 | Sample using coreset sampling 285 | """ 286 | sampler = CoresetSampler(dataloader=dataloader, model=model.model) 287 | selected = sampler.select_batch_(self.existing_samples) 288 | self.existing_samples.extend(selected) 289 | 290 | all_labels = [item[1] for item in self.existing_samples] 291 | label_dist = Counter(all_labels) 292 | 
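# CoresetSampler lives in argument_relation_transformer/coreset.py. As
# background, a minimal greedy k-center sketch of the core-set idea
# (repeatedly pick the unlabeled point farthest from the current pool in
# embedding space) could look like the following; embeddings, n_select and
# the Euclidean metric are assumptions, not the repo's exact implementation:
#
#   import numpy as np
#
#   def greedy_k_center(embeddings, labeled_idx, n_select):
#       selected = list(labeled_idx)
#       # distance of every point to its nearest already-selected point
#       dists = np.min(
#           np.linalg.norm(embeddings[:, None] - embeddings[selected][None], axis=-1),
#           axis=1,
#       )
#       picked = []
#       for _ in range(n_select):
#           new = int(np.argmax(dists))          # farthest point from the pool
#           picked.append(new)
#           new_d = np.linalg.norm(embeddings - embeddings[new], axis=-1)
#           dists = np.minimum(dists, new_d)     # update nearest-center distances
#       return picked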
return self.existing_samples, label_dist 293 | 294 | @torch.no_grad() 295 | def acquire_bald(self, dataloader, model): 296 | """Run BALD based sampling, code adapted from: 297 | 298 | https://github.com/siddk/vqa-outliers/blob/main/active.py#L728 299 | 300 | Need to run Monte-Carlo dropout in train mode; collect disagreement of k different forward passes 301 | """ 302 | import numpy as np 303 | import torch 304 | from scipy.stats import entropy 305 | 306 | model.train() 307 | 308 | def mc_step(batch, k=10): 309 | bsz = batch["labels"].shape[0] 310 | probs, disagreements = [], [] 311 | with torch.no_grad(): 312 | for _ in range(k): 313 | logits, _ = model.model(**batch) 314 | prob = torch.softmax(logits, dim=-1).detach().cpu().numpy() 315 | prob = prob.reshape(-1, prob.shape[-1]) # [16, 21, 3] -> [16*21, 3] 316 | probs.append(prob) 317 | disagreements.append(entropy(prob.transpose(1, 0))) 318 | 319 | entropies = entropy(np.mean(probs, axis=0).transpose(1, 0)) 320 | disagreements = np.mean(disagreements, axis=0) 321 | diff = entropies - disagreements 322 | return diff.reshape(bsz, -1) 323 | 324 | id_list = [] 325 | score_list = [] 326 | for (i, batch) in tqdm(enumerate(dataloader)): 327 | batch = move_to_cuda(batch) 328 | info = mc_step(batch, k=10) # info: [bsz, window] 329 | labels = batch["labels"].detach().cpu().numpy() 330 | 331 | for j, (_id, _label) in enumerate(zip(batch["ids"], labels)): 332 | cur_info = info[j][_label > -1] 333 | _label = _label[_label > -1] 334 | 335 | for k, _info in enumerate(cur_info): 336 | new_id = list(_id) 337 | new_id.append(k) 338 | cur_label = _label[k].item() 339 | new_id = (tuple(new_id), cur_label) 340 | if new_id in self.existing_samples: 341 | continue 342 | 343 | id_list.append(new_id) 344 | score_list.append(_info) 345 | added_ids = [ 346 | id_list[x] 347 | for x in np.argpartition(score_list, -self.interval)[-self.interval :] 348 | ] 349 | self.existing_samples.extend(added_ids) 350 | all_labels = [item[1] for item in added_ids] 351 | label_dist = Counter(all_labels) 352 | 353 | return self.existing_samples, label_dist 354 | 355 | def acquire_discourse_marker(self, dataloader, exclude_marker=False): 356 | DISC_MARKER = [ 357 | "because", 358 | "however", 359 | "therefore", 360 | "although", 361 | "though", 362 | "nevertheless", 363 | "nonetheless", 364 | "thus", 365 | "hence", 366 | "consequently", 367 | "for this reason", 368 | "due to", 369 | "in particular", 370 | "particularly", 371 | "specifically", 372 | "but", 373 | "in fact", 374 | "actually", 375 | ] 376 | 377 | unsampled_to_keep = [] 378 | unsampled_to_discard = [] 379 | for (i, batch) in tqdm(enumerate(dataloader)): 380 | for j, (_id, _label) in enumerate(zip(batch["ids"], batch["labels"])): 381 | _label = _label[_label > -1] 382 | for k in range(len(_label)): 383 | new_id = list(_id) 384 | new_id.append(k) 385 | cur_label = _label[k].item() 386 | new_id = (tuple(new_id), cur_label) 387 | if new_id in self.existing_samples: 388 | continue 389 | if "forward" in new_id[0]: 390 | cur_input_str = batch["input_str"][j][k + 1].lower() 391 | else: 392 | cur_input_str = batch["input_str"][j][k].lower() 393 | 394 | contains_disc = False 395 | for w in DISC_MARKER: 396 | if w in cur_input_str: 397 | contains_disc = True 398 | break 399 | 400 | if exclude_marker: 401 | if contains_disc: 402 | unsampled_to_discard.append(new_id) 403 | else: 404 | unsampled_to_keep.append(new_id) 405 | else: 406 | if contains_disc: 407 | unsampled_to_keep.append(new_id) 408 | else: 409 | 
unsampled_to_discard.append(new_id) 410 | 411 | if self.interval > len(unsampled_to_keep): 412 | selected = unsampled_to_keep 413 | additional_count = self.interval - len(unsampled_to_keep) 414 | selected_extra = random.sample(unsampled_to_discard, k=additional_count) 415 | selected.extend(selected_extra) 416 | else: 417 | selected = random.sample(unsampled_to_keep, k=self.interval) 418 | 419 | self.existing_samples.extend(selected) 420 | all_labels = [item[1] for item in self.existing_samples] 421 | label_dist = Counter(all_labels) 422 | 423 | return self.existing_samples, label_dist 424 | 425 | 426 | def main(): 427 | parser = argparse.ArgumentParser() 428 | parser.add_argument("--dataset", type=str, required=True) 429 | parser.add_argument("--datadir", type=str, default="./data/") 430 | parser.add_argument("--ckptdir", type=str, default="checkpoints/") 431 | parser.add_argument("--exp-name", type=str, required=True) 432 | parser.add_argument("--method", type=str, required=True, choices=AL_METHODS) 433 | parser.add_argument("--seed", type=int, required=True) 434 | 435 | parser.add_argument( 436 | "--current-sample-size", 437 | type=int, 438 | required=True, 439 | help="The number of samples that's annotated now (before AL).", 440 | ) 441 | parser.add_argument( 442 | "--interval", 443 | type=int, 444 | default=500, 445 | help="The number of samples to be selected for annotation.", 446 | ) 447 | parser.add_argument("--batch-size", type=int, default=16) 448 | 449 | parser.add_argument( 450 | "--huggingface-path", 451 | type=str, 452 | default=None, 453 | help="Path to local copy of the huggingface model, if not specified, will attempt to download remotely", 454 | ) 455 | 456 | parser.add_argument( 457 | "--vocab-path", 458 | type=str, 459 | help="Path to the vocabulary, only required when `--method` set to `vocab`", 460 | ) 461 | args = parser.parse_args() 462 | 463 | set_seeds(args.seed) 464 | 465 | output_dir = os.path.join(args.ckptdir, args.exp_name) 466 | if not os.path.exists(output_dir): 467 | os.makedirs(output_dir) 468 | 469 | if args.current_sample_size == 0: 470 | # this is the first iteration, we have no model and no existing samples 471 | model = None 472 | existing_samples = [] 473 | 474 | else: 475 | if args.method in ["random", "disc", "distance", "vocab", "no-disc"]: 476 | # methods that do not require model 477 | model = None 478 | 479 | else: 480 | # load model from previous round 481 | model_dir = os.path.join( 482 | args.ckptdir, 483 | f"{args.exp_name}_model-trained-on-{args.current_sample_size}/", 484 | ) 485 | model_path = glob.glob(model_dir + "*.ckpt") 486 | assert len(model_path) == 1, model_dir 487 | model_path = model_path[0] 488 | print(f"loading model from {model_path}") 489 | model = ArgumentRelationClassificationSystem.load_from_checkpoint( 490 | checkpoint_path=model_path 491 | ) 492 | model.cuda() 493 | model.eval() 494 | 495 | prev_output_path = os.path.join( 496 | output_dir, f"{args.method}.{args.current_sample_size}.jsonl" 497 | ) 498 | existing_samples = [] 499 | for ln in open(prev_output_path): 500 | _id, _label = json.loads(ln) 501 | _id = tuple(_id) 502 | existing_samples.append((_id, _label)) 503 | 504 | print(f"{len(existing_samples)} existing samples loaded") 505 | 506 | al = AL(args, existing_samples) 507 | selected, label_dist = al.acquire(model) 508 | 509 | output_path = os.path.join( 510 | output_dir, f"{args.method}.{args.current_sample_size+args.interval}.jsonl" 511 | ) 512 | with open(output_path, "w") as fout: 513 | for ln in selected: 
514 | fout.write(json.dumps(ln) + "\n") 515 | 516 | # save ratio of labels for analysis 517 | label_ratio_output_path = output_path.replace(".jsonl", ".label_ratio.jsonl") 518 | with open(label_ratio_output_path, "w") as fout_log: 519 | fout_log.write(json.dumps(label_dist)) 520 | 521 | 522 | if __name__ == "__main__": 523 | main() 524 | -------------------------------------------------------------------------------- /scripts/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import glob 4 | from itertools import permutations 5 | import numpy as np 6 | 7 | 8 | class Statistics: 9 | 10 | def __init__(self, dataset): 11 | self.dataset = dataset 12 | self.stats_keys = [ 13 | 'num_doc', 'num_prop', 'num_segments', 'num_unique_target', 14 | 'num_unique_support_target', 'num_unique_attack_target', 15 | 'num_link', 'num_supports', 'num_attacks', 'avg_prop_len' 16 | ] 17 | self.stats = { 18 | _split: { 19 | k: 0 if k.startswith('num') else [] for k in self.stats_keys 20 | } 21 | for _split in ['train', 'val', 'test'] 22 | } 23 | 24 | def update(self, doc, _split): 25 | self.stats[_split]['num_doc'] += 1 26 | self.stats[_split]['num_prop'] += doc.num_args 27 | self.stats[_split]['num_segments'] += len(doc.segments) 28 | self.stats[_split]['num_unique_target'] += len(doc.target_prop_ids["union"]) 29 | self.stats[_split]['num_link'] += (len(doc.relations)) 30 | self.stats[_split]['num_supports'] += len(doc.supports) 31 | self.stats[_split]['num_attacks'] += len(doc.attacks) 32 | self.stats[_split]['num_unique_support_target'] += len(doc.target_prop_ids["support"]) 33 | self.stats[_split]['num_unique_attack_target'] += len(doc.target_prop_ids["attack"]) 34 | self.stats[_split]['avg_prop_len'].extend(doc.segment_lengths) 35 | 36 | def print(self): 37 | print(f'---------- {self.dataset} -----------') 38 | # print title 39 | print(f'{"split":<6s} |', end='') 40 | for k in self.stats_keys: 41 | print(f'{k:<10s} |', end='') 42 | print() 43 | for _split in ['train', 'val', 'test']: 44 | print(f'{_split:<6s} |', end='') 45 | for k in self.stats_keys: 46 | cur_s = self.stats[_split][k] 47 | if k.startswith('num'): 48 | print(f'{cur_s:<10} |', end='') 49 | else: 50 | print(f'{np.mean(cur_s):<10.1f}', end='') 51 | print() 52 | 53 | print(f'{"total":<6s} |', end='') 54 | for k in self.stats_keys: 55 | if k.startswith('num'): 56 | cur_s = sum([item[k] for item in self.stats.values()]) 57 | print(f'{cur_s:<10} |', end='') 58 | else: 59 | cur_s = [] 60 | for _split in self.stats: 61 | cur_s.extend(self.stats[_split][k]) 62 | cur_s = np.mean(cur_s) 63 | print(f'{cur_s:<10.1f} |', end='') 64 | 65 | print() 66 | 67 | 68 | class ArgumentDocument: 69 | 70 | dataset_path_prefix = None 71 | dataset_type = None 72 | dataset_glob_expression = None 73 | dataset_train_ids = [] 74 | dataset_val_ids = [] 75 | dataset_test_ids = [] 76 | 77 | def __init__(self, file_id): 78 | self.file_id = file_id 79 | self.num_args = 0 80 | # raw segmentation, might not be arguments 81 | self.segments = [] 82 | # for UKP, non-args are not considered for relations 83 | self.types = [] 84 | # target proposition ids, for UKP, this is NOT their segment ids 85 | self.target_prop_ids = {"support" : set(), "attack" : set(), "union" : set()} 86 | 87 | # indicate the target information for each segment 88 | self.target_labels = {"support" : [], "attack" : []} 89 | 90 | self.supports = [] 91 | self.attacks = [] 92 | self.relations = [] 93 | 94 | def load_document(self): 95 | 
"""Load raw data from disk, and store in the following fields: 96 | self.num_args 97 | self.segments = [] 98 | self.types = [] 99 | self.target_prop_ids = set() # use proposition ids, not segment ids 100 | self.supports = [] # use segment ids 101 | self.supports_in_prop_ids = [] # use proposition ids 102 | self.attacks = [] 103 | ... 104 | """ 105 | raise NotImplementedError 106 | 107 | def generate_relation_candidates(self): 108 | """Create possible link prediction candidate pairs.""" 109 | for (src, tgt) in permutations(range(self.num_args), 2): 110 | if (src, tgt) in self.supports: 111 | label = "support" 112 | elif (src, tgt) in self.attacks: 113 | label = "attack" 114 | else: 115 | label = "none" 116 | src_text = self.segments[src] 117 | tgt_text = self.segments[tgt] 118 | yield ((src, src_text), (tgt, tgt_text), label) 119 | 120 | @classmethod 121 | def find_split(cls, file_id): 122 | raise NotImplementedError 123 | 124 | 125 | @classmethod 126 | def make_all_data(cls): 127 | docs = {'train': [], 'val': [], 'test': []} 128 | stats = Statistics(cls.dataset_type) 129 | for path in glob.glob(cls.dataset_path_prefix + cls.dataset_glob_expression): 130 | if 'ids' in path: 131 | continue 132 | 133 | file_id = os.path.basename(path).split('.')[0] 134 | cur_doc = cls(file_id) 135 | cur_doc.load_document() 136 | _split = cls.find_split(file_id) 137 | 138 | docs[_split].append(cur_doc) 139 | stats.update(cur_doc, _split) 140 | 141 | stats.print() 142 | cls.make_relation_doc_level_data(docs) 143 | 144 | 145 | @classmethod 146 | def make_relation_doc_level_data(cls, docs): 147 | fout_list = { 148 | s: open(f'trainable/{cls.dataset_type}_{s}.jsonl', 'w') 149 | for s in docs.keys() 150 | } 151 | for s in docs: 152 | for doc in docs[s]: 153 | output_obj = { 154 | 'doc_id': doc.file_id, 155 | 'text': doc.segments, 156 | 'relations': [{"head": item[1], "tail": item[0], "type": item[2]} for item in doc.relations], 157 | } 158 | fout_list[s].write(json.dumps(output_obj) + "\n") 159 | fout_list[s].close() 160 | 161 | 162 | class CDCPDocument(ArgumentDocument): 163 | 164 | dataset_path_prefix = "raw/cdcp/" 165 | dataset_glob_expression = "*.txt" 166 | dataset_type = "cdcp" 167 | raw_train_ids = [ln.strip() for ln in open('cdcp_train_ids.txt')] 168 | 169 | dataset_train_ids = raw_train_ids[:-80] 170 | dataset_val_ids = raw_train_ids[-80:] 171 | dataset_test_ids = [ln.strip() for ln in open('cdcp_test_ids.txt')] 172 | 173 | def __init__(self, file_id): 174 | # file_id: "01418" 175 | super().__init__(file_id) 176 | self.text_path = CDCPDocument.dataset_path_prefix + f'{file_id}.txt' 177 | self.ann_path = CDCPDocument.dataset_path_prefix + f'{file_id}.ann.json' 178 | 179 | def load_document(self): 180 | raw_text = open(self.text_path).read() 181 | ann_obj = json.loads(open(self.ann_path).read()) 182 | 183 | for (ch_start, ch_end) in ann_obj['prop_offsets']: 184 | self.segments.append(raw_text[ch_start:ch_end]) 185 | self.segment_lengths = [len(sent.split()) for sent in self.segments] 186 | self.num_args = len(self.segments) 187 | self.proposition_ids = [i for i in range(self.num_args)] 188 | 189 | self.target_prop_ids = {"support" : set(), "attack" : set(), "union" : set()} 190 | for ((src_start, src_end), tgt) in ann_obj['reasons'] + ann_obj['evidences']: 191 | self.target_prop_ids["support"].add(tgt) 192 | self.target_prop_ids["union"].add(tgt) 193 | for i in range(src_start, src_end + 1): 194 | self.relations.append((i, tgt, "support")) 195 | self.relations_in_prop_ids = self.relations 196 | 
self.target_labels["support"] = [(i in self.target_prop_ids) for i in range(self.num_args)] 197 | self.target_labels["attack"] = [False for i in range(self.num_args)] 198 | 199 | 200 | @classmethod 201 | def find_split(cls, file_id): 202 | if file_id in cls.dataset_train_ids: 203 | return 'train' 204 | elif file_id in cls.dataset_val_ids: 205 | return 'val' 206 | else: 207 | return 'test' 208 | 209 | 210 | class UKPDocument(ArgumentDocument): 211 | 212 | dataset_path_prefix = "raw/ArgumentAnnotatedEssays-2.0/brat-project-final/" 213 | dataset_glob_expression = "*.ann" 214 | dataset_type = "essays" 215 | dataset_train_ids = [1, 2, 3, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 216 | 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 217 | 37, 38, 39, 40, 41, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 218 | 54, 55, 56, 57, 58, 59, 60, 62, 63, 64, 65, 66, 67, 69, 70, 219 | 73, 74, 75, 76, 78, 79, 80, 81, 83, 84, 85, 87, 88, 89, 90, 220 | 92, 93, 94, 95, 96, 99, 100, 101, 102, 105, 106, 107, 109, 221 | 110, 111, 112, 113, 114, 115, 116, 118, 120, 121, 122, 123, 222 | 124, 125, 127, 128, 130, 131, 132, 133, 134, 135, 137, 138, 223 | 140, 141, 143, 144, 145, 146, 147, 148, 150, 151, 152, 153, 224 | 155, 156, 157, 158, 159, 161, 162, 164, 165, 166, 167, 168, 225 | 170, 171, 173, 174, 175, 176, 177, 178, 179, 181, 183, 184, 226 | 185, 186, 188, 189, 190, 191, 194, 195, 196, 197, 198, 200, 227 | 201, 203, 205, 206, 207, 208, 209, 210, 213, 214, 215, 216, 228 | 217, 219, 222, 223, 224, 225, 226, 228, 230, 231, 232, 233, 229 | 235, 236, 237, 238, 239, 242, 244, 246, 247, 248, 249, 250, 230 | 251, 253, 254, 256, 257, 258, 260, 261, 262, 263, 264, 267, 231 | 268, 269, 270, 271, 272, 273, 274, 275, 276, 279, 280, 281, 232 | 282, 283, 284, 285, 286, 288, 290, 291, 292, 293, 294, 295, 233 | 296, 297, 298, 299, 300, 302, 303, 304, 305, 307, 308, 309, 234 | 311, 312, 313, 314, 315, 317, 318, 319, 320, 321, 323, 324, 235 | 325, 326, 327, 329, 330, 332, 333, 334, 336, 337, 338, 339, 236 | 340, 342, 343, 344, 345, 346, 347, 349, 350, 351, 353, 354, 237 | ] 238 | dataset_val_ids = [ 239 | 356, 357, 358, 360, 361, 362, 363, 365, 366, 367, 368, 369, 240 | 370, 371, 372, 374, 375, 376, 377, 378, 379, 380, 381, 383, 241 | 384, 385, 387, 388, 389, 390, 391, 392, 394, 395, 396, 397, 242 | 399, 400, 401, 402 243 | ] 244 | dataset_test_ids = [ 245 | 4, 5, 6, 21, 42, 52, 61, 68, 71, 72, 77, 82, 86, 91, 97, 98, 246 | 103, 104, 108, 117, 119, 126, 129, 136, 139, 142, 149, 154, 247 | 160, 163, 169, 172, 180, 182, 187, 192, 193, 199, 202, 204, 248 | 211, 212, 218, 220, 221, 227, 229, 234, 240, 241, 243, 245, 249 | 252, 255, 259, 265, 266, 277, 278, 287, 289, 301, 306, 310, 250 | 316, 322, 328, 331, 335, 341, 348, 352, 355, 359, 364, 373, 251 | 382, 386, 393, 398 252 | ] 253 | 254 | 255 | def __init__(self, file_id): 256 | # file_id: "essay012" 257 | super().__init__(file_id) 258 | self.adjusted_prop_para = [] 259 | self.text_path = UKPDocument.dataset_path_prefix + f'{file_id}.txt' 260 | self.ann_path = UKPDocument.dataset_path_prefix + f'{file_id}.ann' 261 | 262 | 263 | def generate_relation_candidates(self): 264 | """For all argumentative segments, create possible link prediction 265 | candidate pairs. Note that only arguments within the same paragraph can 266 | be candidates. 
267 | """ 268 | for (src, tgt) in permutations(range(self.num_args), 2): 269 | if self.prop_para[src] != self.prop_para[tgt]: 270 | continue 271 | 272 | if (src, tgt) in self.supports: 273 | label = "support" 274 | elif (src, tgt) in self.attacks: 275 | label = "attack" 276 | else: 277 | label = "none" 278 | src_text = self.prop_id_to_text[src] 279 | tgt_text = self.prop_id_to_text[tgt] 280 | yield ((src, src_text), (tgt, tgt_text), label) 281 | 282 | 283 | def load_document(self): 284 | raw_text = open(self.text_path).read() 285 | 286 | props_info = dict() # T1 -> (begin, end) 287 | raw_rels = {'supports': [], 'attacks': []} 288 | 289 | para_offsets = [] 290 | ix = 0 291 | while True: 292 | para_offsets.append(ix) 293 | try: 294 | ix = raw_text.index("\n", ix + 1) 295 | except ValueError: 296 | break 297 | para_offsets = np.array(para_offsets) 298 | for line in open(self.ann_path): 299 | if line[0] == 'T': 300 | fields = line.split('\t') 301 | position = fields[1].split() 302 | char_start = int(position[1]) 303 | char_end = int(position[2]) 304 | props_info[fields[0]] = (char_start, char_end) 305 | 306 | elif line[0] == 'R': 307 | fields = line.split('\t') 308 | rel_info = fields[1].split() 309 | rel_type = rel_info[0] 310 | src_id = rel_info[1].split(':')[1] 311 | tgt_id = rel_info[2].split(':')[1] 312 | raw_rels[rel_type].append((src_id, tgt_id)) 313 | 314 | # old_ix: (T2, T1, T3) 315 | # sorted_prop_ids: [(5, 10), (20, 30), (35, 40)] 316 | old_ix, sorted_prop_ids = zip(*sorted(props_info.items(), key=lambda x: x[1])) 317 | inv_idx = {k: v for v, k in enumerate(old_ix)} 318 | self.supports = [(inv_idx[src], inv_idx[tgt]) for (src, tgt) in raw_rels['supports']] 319 | self.attacks = [(inv_idx[src], inv_idx[tgt]) for (src, tgt) in raw_rels['attacks']] 320 | self.target_prop_ids = {"support": set(), "attack": set(), "union": set()} 321 | for (s, t) in self.supports: 322 | self.target_prop_ids["support"].add(t) 323 | self.target_prop_ids["union"].add(t) 324 | 325 | for (s, t) in self.attacks: 326 | self.target_prop_ids["attack"].add(t) 327 | self.target_prop_ids["union"].add(t) 328 | 329 | self.prop_para = [int(np.searchsorted(para_offsets, start)) - 1 330 | for start, _ in sorted_prop_ids] 331 | 332 | self.segments = [] 333 | self.is_argument = [] 334 | self.target_labels = {"support": [], "attack": []} 335 | self.num_args = len(sorted_prop_ids) 336 | self.prop_id_to_text = [] 337 | 338 | prop_id_to_seg_id = dict() 339 | cur_char_ptr = 0 340 | 341 | self.proposition_ids = [] 342 | self.adjusted_prop_para = [] 343 | for ix, item in enumerate(sorted_prop_ids): 344 | 345 | # first add previous segment 346 | if cur_char_ptr < item[0]: 347 | prev_seg = raw_text[cur_char_ptr:item[0]] 348 | self.is_argument.append(False) 349 | self.segments.append(prev_seg) 350 | self.proposition_ids.append(-1) 351 | self.target_labels["support"].append(None) 352 | self.target_labels["attack"].append(None) 353 | self.adjusted_prop_para.append(int(np.searchsorted(para_offsets, cur_char_ptr))) 354 | 355 | # add self 356 | cur_seg = raw_text[item[0]:item[1]] 357 | prop_id_to_seg_id[ix] = len(self.segments) 358 | self.segments.append(cur_seg) 359 | self.is_argument.append(True) 360 | self.proposition_ids.append(ix) 361 | self.target_labels["support"].append(ix in self.target_prop_ids["support"]) 362 | self.target_labels["attack"].append(ix in self.target_prop_ids["attack"]) 363 | self.prop_id_to_text.append(cur_seg) 364 | cur_char_ptr = item[1] 365 | self.adjusted_prop_para.append(int(np.searchsorted(para_offsets, 
item[0]))) 366 | 367 | if cur_char_ptr < len(raw_text) - 1: 368 | self.segments.append(raw_text[cur_char_ptr:]) 369 | self.is_argument.append(False) 370 | self.target_labels["support"].append(None) 371 | self.target_labels["attack"].append(None) 372 | self.proposition_ids.append(-1) 373 | self.adjusted_prop_para.append(int(np.searchsorted(para_offsets, cur_char_ptr))) 374 | 375 | self.segment_lengths = [len(sent.split()) for sent in self.segments] 376 | 377 | self.relations = [] # using global ids 378 | self.relations_in_prop_ids = [] # using prop ids 379 | 380 | tgt_to_ids = dict() 381 | rel_tgt_to_ids = dict() 382 | 383 | 384 | for ix, relation_type in enumerate([self.supports, self.attacks]): 385 | for (src, tgt) in relation_type: 386 | src_real_id = prop_id_to_seg_id[src] 387 | tgt_real_id = prop_id_to_seg_id[tgt] 388 | 389 | if tgt_real_id not in rel_tgt_to_ids: 390 | rel_tgt_to_ids[tgt_real_id] = [[], []] 391 | # support -> [src], attack -> [src] 392 | rel_tgt_to_ids[tgt_real_id][ix].append(src_real_id) 393 | 394 | if tgt not in tgt_to_ids: 395 | tgt_to_ids[tgt] = [[], []] 396 | tgt_to_ids[tgt][ix].append(src) 397 | 398 | for t, (sup, att) in rel_tgt_to_ids.items(): 399 | for s in sup: 400 | self.relations.append((s, t, "support")) 401 | for s in att: 402 | self.relations.append((s, t, "attack")) 403 | 404 | for t, (sup, att) in tgt_to_ids.items(): 405 | for s in sup: 406 | self.relations_in_prop_ids.append((s, t, "support")) 407 | for s in att: 408 | self.relations_in_prop_ids.append((s, t, "attack")) 409 | 410 | 411 | @classmethod 412 | def find_split(cls, file_id): 413 | doc_num_id = int(file_id[5:]) 414 | if doc_num_id in cls.dataset_train_ids: 415 | return 'train' 416 | elif doc_num_id in cls.dataset_val_ids: 417 | return 'val' 418 | else: 419 | return 'test' 420 | 421 | 422 | class ECHRDocument(ArgumentDocument): 423 | 424 | dataset_type = 'ECHR' 425 | dataset_path_prefix = 'raw/echr_corpus/' 426 | train_ids = list(range(27)) 427 | val_ids = list(range(27, 27+7)) 428 | test_ids = list(range(27 + 7, 42)) 429 | 430 | def __init__(self, file_id, doc_items): 431 | super().__init__(file_id) 432 | self.doc_items = doc_items 433 | 434 | def load_document(self): 435 | self.segments = self.doc_items['sentences'] 436 | self.segment_lengths = [len(sent.split()) for sent in self.segments] 437 | self.proposition_ids = [] 438 | 439 | prop_id = 0 440 | id2prop_id = dict() 441 | id2seg_id = dict() 442 | for i, item in enumerate(self.doc_items['is_argument']): 443 | cur_id = self.doc_items['clause_id'][i] 444 | if item: 445 | self.proposition_ids.append(prop_id) 446 | id2prop_id[cur_id] = prop_id 447 | id2seg_id[cur_id] = i 448 | prop_id += 1 449 | else: 450 | self.proposition_ids.append(-1) 451 | 452 | self.relations = [] 453 | 454 | for (src, tgt) in self.doc_items['relations']: 455 | self.relations.append([id2seg_id[src], id2seg_id[tgt], 'support']) 456 | self.types = None 457 | self.target_labels = None 458 | 459 | 460 | @classmethod 461 | def make_all_data(cls): 462 | docs = {'train': [], 'val': [], 'test': []} 463 | stats = Statistics(cls.dataset_type) 464 | 465 | assert os.path.exists(cls.dataset_path_prefix + "ECHR_sentences.jsonl"), "To convert ECHR, please run `echr_preproc.py` first." 
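# Rough shape of one record in raw/echr_corpus/ECHR_sentences.jsonl that
# load_document above relies on (field values are invented; the real file is
# produced by scripts/echr_preproc.py):
#
#   {
#     "sentences": ["Sentence 0 ...", "Sentence 1 ...", "Sentence 2 ..."],
#     "is_argument": [true, true, false],
#     "clause_id": ["c1", "c2", null],
#     "relations": [["c2", "c1"]]
#   }
#
# "relations" pairs are (source clause id, target clause id) and are all
# treated as "support"; clause ids of argumentative sentences are remapped to
# running segment indices in load_document.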
466 | 467 | raw_data = [json.loads(ln) for ln in open(cls.dataset_path_prefix + "ECHR_sentences.jsonl")] 468 | for i, doc in enumerate(raw_data): 469 | 470 | cur_doc = cls(i, doc) 471 | cur_doc.load_document() 472 | 473 | if i in cls.train_ids: 474 | dsplit = 'train' 475 | docs['train'].append(cur_doc) 476 | elif i in cls.val_ids: 477 | dsplit = 'val' 478 | docs['val'].append(cur_doc) 479 | else: 480 | dsplit = 'test' 481 | docs['test'].append(cur_doc) 482 | stats.update(cur_doc, dsplit) 483 | 484 | stats.print() 485 | #cls.make_target_prediction_data(docs) 486 | cls.make_relation_doc_level_data(docs) 487 | 488 | 489 | class AbstCRTDocument(ArgumentDocument): 490 | 491 | dataset_path_prefix = 'raw/abstrct-master/AbstRCT_corpus/data/' 492 | dataset_type = 'abst_rct' 493 | 494 | def __init__(self, file_id, path): 495 | super().__init__(file_id) 496 | self.text_path = path[:-3] + 'txt' 497 | self.ann_path = path 498 | assert os.path.exists(self.text_path), f'Path not found! {self.text_path}' 499 | 500 | def load_document(self): 501 | raw_text = open(self.text_path).read() 502 | props_info = dict() # T1 -> (begin, end) 503 | raw_rels = {'Support': [], 'Attack': []} 504 | 505 | for line in open(self.ann_path): 506 | if line[0] == 'T': 507 | fields = line.split("\t") 508 | position = fields[1].split() 509 | char_start = int(position[1]) 510 | char_end = int(position[2]) 511 | props_info[fields[0]] = (char_start, char_end) 512 | 513 | elif line[0] == 'R': 514 | fields = line.split("\t") 515 | rel_info = fields[1].split() 516 | rel_type = rel_info[0] 517 | if rel_type not in raw_rels: 518 | continue 519 | src_id = rel_info[1].split(':')[1] 520 | tgt_id = rel_info[2].split(':')[1] 521 | raw_rels[rel_type].append((src_id, tgt_id)) 522 | 523 | # old_ix: (T3, T1, T2) 524 | # sorted_prop_ids ((639, 1018), (1611, 1673), (1674, 1776)) 525 | old_ix, sorted_prop_ids = zip(*sorted(props_info.items(), key=lambda x: x[1])) 526 | inv_idx = {k: v for v, k in enumerate(old_ix)} 527 | self.supports = [(inv_idx[src], inv_idx[tgt]) for (src, tgt) in raw_rels['Support']] 528 | self.attacks = [(inv_idx[src], inv_idx[tgt]) for (src, tgt) in raw_rels['Attack']] 529 | self.target_prop_ids = {"support" : set(), "attack" : set(), "union" : set()} 530 | 531 | for (s, t) in self.supports: 532 | self.target_prop_ids["support"].add(t) 533 | self.target_prop_ids["union"].add(t) 534 | 535 | for (s, t) in self.attacks: 536 | self.target_prop_ids["attack"].add(t) 537 | self.target_prop_ids["union"].add(t) 538 | 539 | prop_id_to_seg_id = dict() # map proposition id to natural segment id 540 | cur_char_ptr = 0 541 | self.segments = [] 542 | self.target_labels = [] 543 | self.is_argument = [] 544 | self.num_args = len(sorted_prop_ids) 545 | self.prop_id_to_text = [] 546 | self.proposition_ids = [] # -1 for non-arg, otherwise count from 0 547 | 548 | for ix, item in enumerate(sorted_prop_ids): 549 | # if there's anything, add as non-arg 550 | if cur_char_ptr < item[0] and item[0] - cur_char_ptr > 2: 551 | prev_seg = raw_text[cur_char_ptr: item[0]] 552 | self.is_argument.append(False) 553 | self.segments.append(prev_seg) 554 | self.proposition_ids.append(-1) 555 | self.target_labels.append(None) 556 | 557 | # add current argument (proposition) 558 | cur_seg = raw_text[item[0]: item[1]] 559 | prop_id_to_seg_id[ix] = len(self.segments) 560 | self.segments.append(cur_seg) 561 | self.is_argument.append(True) 562 | self.proposition_ids.append(ix) 563 | self.target_labels.append(ix in self.target_prop_ids) 564 | 
self.prop_id_to_text.append(cur_seg) 565 | cur_char_ptr = item[1] 566 | 567 | if cur_char_ptr < len(raw_text) - 1: 568 | cur_seg = raw_text[cur_char_ptr:] 569 | if cur_seg.strip() != '': 570 | self.segments.append(cur_seg) 571 | self.is_argument.append(False) 572 | self.target_labels.append(None) 573 | self.proposition_ids.append(-1) 574 | 575 | self.relations = [] 576 | self.relations_in_prop_ids = [] 577 | rel_tgt_to_ids = dict() 578 | tgt_to_ids = dict() 579 | 580 | for ix, relation_type in enumerate([self.supports, self.attacks]): 581 | for (src, tgt) in relation_type: 582 | src_real_id = prop_id_to_seg_id[src] 583 | tgt_real_id = prop_id_to_seg_id[tgt] 584 | if tgt_real_id not in rel_tgt_to_ids: 585 | rel_tgt_to_ids[tgt_real_id] = [[], []] 586 | rel_tgt_to_ids[tgt_real_id][ix].append(src_real_id) 587 | 588 | if tgt not in tgt_to_ids: 589 | tgt_to_ids[tgt] = [[], []] 590 | tgt_to_ids[tgt][ix].append(src) 591 | 592 | for t, (sup, att) in rel_tgt_to_ids.items(): 593 | for s in sup: 594 | self.relations.append((s, t, "support")) 595 | for s in att: 596 | self.relations.append((s, t, "attack")) 597 | 598 | self.segment_lengths = [len(sent.split()) for sent in self.segments] 599 | 600 | 601 | @classmethod 602 | def make_all_data(cls): 603 | docs = {'train': [], 'val': [], 'test': []} 604 | stats = Statistics(cls.dataset_type) 605 | 606 | for dsplit in ['train', 'dev', 'test']: 607 | paths = glob.glob(cls.dataset_path_prefix + dsplit + '/*/*.ann') 608 | if dsplit == 'dev': 609 | dsplit = 'val' 610 | for path in paths: 611 | file_id = os.path.basename(path).split('.')[0] 612 | topic = path.split('/')[-2] 613 | file_id = (topic, file_id) 614 | cur_doc = cls(file_id, path) 615 | cur_doc.load_document() 616 | docs[dsplit].append(cur_doc) 617 | stats.update(cur_doc, dsplit) 618 | stats.print() 619 | cls.make_relation_doc_level_data(docs) -------------------------------------------------------------------------------- /argument_relation_transformer/modeling.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import re 4 | from typing import Optional, Union 5 | 6 | import torch 7 | import torch.nn as nn 8 | from transformers import RobertaForTokenClassification, RobertaModel 9 | from transformers.configuration_utils import PretrainedConfig 10 | from transformers.file_utils import ( 11 | TF2_WEIGHTS_NAME, 12 | TF_WEIGHTS_NAME, 13 | WEIGHTS_NAME, 14 | cached_path, 15 | hf_bucket_url, 16 | is_remote_url, 17 | ) 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class RobertaForDocumentSpanClassification(RobertaForTokenClassification): 23 | def __init__(self, config, **kwargs): 24 | super().__init__(config) 25 | assert "num_labels" in kwargs 26 | num_labels = kwargs["num_labels"] 27 | output_layers = kwargs.pop("output_layers", 1) 28 | self.num_labels = num_labels 29 | self.roberta = RobertaModel(config) 30 | self.init_weights() 31 | 32 | self.dropout = nn.Dropout(config.hidden_dropout_prob) 33 | self.output_layers = output_layers 34 | if output_layers == 1: 35 | self.output_dense = nn.Sequential( 36 | self.dropout, 37 | nn.Linear( 38 | in_features=2 * config.hidden_size, 39 | out_features=config.hidden_size, 40 | bias=True, 41 | ), 42 | nn.Tanh(), 43 | self.dropout, 44 | nn.Linear( 45 | in_features=config.hidden_size, 46 | out_features=self.num_labels, 47 | bias=True, 48 | ), 49 | ) 50 | 51 | else: 52 | self.output_dense = nn.Sequential( 53 | self.dropout, 54 | nn.Linear( 55 | in_features=2 * config.hidden_size, 56 | 
out_features=config.hidden_size, 57 | bias=True, 58 | ), 59 | nn.Tanh(), 60 | self.dropout, 61 | nn.Linear( 62 | in_features=config.hidden_size, 63 | out_features=config.hidden_size, 64 | bias=True, 65 | ), 66 | nn.Tanh(), 67 | self.dropout, 68 | nn.Linear( 69 | in_features=config.hidden_size, 70 | out_features=self.num_labels, 71 | bias=True, 72 | ), 73 | ) 74 | 75 | @classmethod 76 | def from_pretrained( 77 | cls, 78 | pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], 79 | *model_args, 80 | **kwargs, 81 | ): 82 | r""" 83 | Instantiate a pretrained pytorch model from a pre-trained model configuration. 84 | The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated). To 85 | train the model, you should first set it back in training mode with ``model.train()``. 86 | The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come 87 | pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning 88 | task. 89 | The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those 90 | weights are discarded. 91 | Parameters: 92 | pretrained_model_name_or_path (:obj:`str` or :obj:`os.PathLike`, `optional`): 93 | Can be either: 94 | - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. 95 | Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under 96 | a user or organization name, like ``dbmdz/bert-base-german-cased``. 97 | - A path to a `directory` containing model weights saved using 98 | :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. 99 | - A path or url to a `tensorflow index checkpoint file` (e.g, ``./tf_model/model.ckpt.index``). In 100 | this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided 101 | as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in 102 | a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. 103 | - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword 104 | arguments ``config`` and ``state_dict``). 105 | model_args (sequence of positional arguments, `optional`): 106 | All remaning positional arguments will be passed to the underlying model's ``__init__`` method. 107 | config (:obj:`Union[PretrainedConfig, str, os.PathLike]`, `optional`): 108 | Can be either: 109 | - an instance of a class derived from :class:`~transformers.PretrainedConfig`, 110 | - a string or path valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`. 111 | Configuration for the model to use instead of an automatically loaded configuation. Configuration can 112 | be automatically loaded when: 113 | - The model is a model provided by the library (loaded with the `model id` string of a pretrained 114 | model). 115 | - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded 116 | by supplying the save directory. 117 | - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a 118 | configuration JSON file named `config.json` is found in the directory. 119 | state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`): 120 | A state dictionary to use instead of a state dictionary loaded from saved weights file. 
121 | This option can be used if you want to create a model from a pretrained configuration but load your own 122 | weights. In this case though, you should check if using 123 | :func:`~transformers.PreTrainedModel.save_pretrained` and 124 | :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option. 125 | cache_dir (:obj:`Union[str, os.PathLike]`, `optional`): 126 | Path to a directory in which a downloaded pretrained model configuration should be cached if the 127 | standard cache should not be used. 128 | from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`): 129 | Load the model weights from a TensorFlow checkpoint save file (see docstring of 130 | ``pretrained_model_name_or_path`` argument). 131 | force_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 132 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the 133 | cached versions if they exist. 134 | resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`): 135 | Whether or not to delete incompletely received files. Will attempt to resume the download if such a 136 | file exists. 137 | proxies (:obj:`Dict[str, str], `optional`): 138 | A dictionary of proxy servers to use by protocol or endpoint, e.g., :obj:`{'http': 'foo.bar:3128', 139 | 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 140 | output_loading_info(:obj:`bool`, `optional`, defaults to :obj:`False`): 141 | Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. 142 | local_files_only(:obj:`bool`, `optional`, defaults to :obj:`False`): 143 | Whether or not to only look at local files (i.e., do not try to download the model). 144 | use_auth_token (:obj:`str` or `bool`, `optional`): 145 | The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token 146 | generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). 147 | revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): 148 | The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a 149 | git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any 150 | identifier allowed by git. 151 | mirror(:obj:`str`, `optional`, defaults to :obj:`None`): 152 | Mirror source to accelerate downloads in China. If you are from China and have an accessibility 153 | problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety. 154 | Please refer to the mirror site for more information. 155 | kwargs (remaining dictionary of keyword arguments, `optional`): 156 | Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., 157 | :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or 158 | automatically loaded: 159 | - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the 160 | underlying model's ``__init__`` method (we assume all relevant updates to the configuration have 161 | already been done) 162 | - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class 163 | initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of 164 | ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute 165 | with the supplied ``kwargs`` value. 
Remaining keys that do not correspond to any configuration 166 | attribute will be passed to the underlying model's ``__init__`` function. 167 | .. note:: 168 | Passing :obj:`use_auth_token=True` is required when you want to use a private model. 169 | """ 170 | config = kwargs.pop("config", None) 171 | state_dict = kwargs.pop("state_dict", None) 172 | cache_dir = kwargs.pop("cache_dir", None) 173 | from_tf = kwargs.pop("from_tf", False) 174 | force_download = kwargs.pop("force_download", False) 175 | resume_download = kwargs.pop("resume_download", False) 176 | proxies = kwargs.pop("proxies", None) 177 | output_loading_info = kwargs.pop("output_loading_info", False) 178 | 179 | revision = kwargs.pop("revision", None) 180 | mirror = kwargs.pop("mirror", None) 181 | 182 | # Load config if we don't provide a configuration 183 | if not isinstance(config, PretrainedConfig): 184 | config_path = ( 185 | config if config is not None else pretrained_model_name_or_path 186 | ) 187 | config, model_kwargs = cls.config_class.from_pretrained( 188 | config_path, 189 | *model_args, 190 | cache_dir=cache_dir, 191 | return_unused_kwargs=True, 192 | force_download=force_download, 193 | resume_download=resume_download, 194 | proxies=proxies, 195 | revision=revision, 196 | **kwargs, 197 | ) 198 | 199 | # Load model 200 | if pretrained_model_name_or_path is not None: 201 | pretrained_model_name_or_path = str(pretrained_model_name_or_path) 202 | if os.path.isdir(pretrained_model_name_or_path): 203 | if os.path.isfile( 204 | os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) 205 | ): 206 | # Load from a PyTorch checkpoint 207 | archive_file = os.path.join( 208 | pretrained_model_name_or_path, WEIGHTS_NAME 209 | ) 210 | else: 211 | raise EnvironmentError( 212 | "Error no file named {} found in directory {} or `from_tf` set to False".format( 213 | [ 214 | WEIGHTS_NAME, 215 | TF2_WEIGHTS_NAME, 216 | TF_WEIGHTS_NAME + ".index", 217 | ], 218 | pretrained_model_name_or_path, 219 | ) 220 | ) 221 | elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url( 222 | pretrained_model_name_or_path 223 | ): 224 | archive_file = pretrained_model_name_or_path 225 | else: 226 | archive_file = hf_bucket_url( 227 | pretrained_model_name_or_path, 228 | filename=(TF2_WEIGHTS_NAME if from_tf else WEIGHTS_NAME), 229 | revision=revision, 230 | mirror=mirror, 231 | ) 232 | 233 | try: 234 | # Load from URL or cache if already cached 235 | resolved_archive_file = cached_path( 236 | archive_file, 237 | cache_dir=cache_dir, 238 | force_download=force_download, 239 | proxies=proxies, 240 | resume_download=resume_download, 241 | ) 242 | except EnvironmentError as err: 243 | logger.error(err) 244 | msg = ( 245 | f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n" 246 | f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n" 247 | f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n" 248 | ) 249 | raise EnvironmentError(msg) 250 | 251 | if resolved_archive_file == archive_file: 252 | logger.info("loading weights file {}".format(archive_file)) 253 | else: 254 | logger.info( 255 | "loading weights file {} from cache at {}".format( 256 | archive_file, resolved_archive_file 257 | ) 258 | ) 259 | else: 260 | resolved_archive_file = None 261 | 262 | config.name_or_path = pretrained_model_name_or_path 263 | 264 | # Instantiate model. 
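# The logic below mirrors transformers' PreTrainedModel.from_pretrained, with two
# additions: (1) checkpoints saved by pytorch-lightning are normalized before loading
# (tensors may sit under a nested "state_dict" entry, keys may carry a "model." prefix,
# and legacy "gamma"/"beta" parameter names are renamed to "weight"/"bias");
# (2) if the checkpoint and the target model disagree on `num_labels`, the final
# classification layers are skipped during loading and then either partially copied
# (3 -> 2 labels) or left newly initialized.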
265 | model = cls(config, **kwargs) 266 | 267 | if state_dict is None and not from_tf: 268 | try: 269 | state_dict = torch.load(resolved_archive_file, map_location="cpu") 270 | except Exception: 271 | raise OSError( 272 | f"Unable to load weights from pytorch checkpoint file for '{pretrained_model_name_or_path}' " 273 | f"at '{resolved_archive_file}'" 274 | "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True. " 275 | ) 276 | if "state_dict" in state_dict: 277 | state_dict = state_dict["state_dict"] 278 | 279 | missing_keys = [] 280 | unexpected_keys = [] 281 | error_msgs = [] 282 | 283 | # Convert old format to new format if needed from a PyTorch state_dict 284 | old_keys = [] 285 | new_keys = [] 286 | for key in state_dict.keys(): 287 | new_key = None 288 | if "gamma" in key: 289 | new_key = key.replace("gamma", "weight") 290 | if "beta" in key: 291 | new_key = key.replace("beta", "bias") 292 | if key.startswith("model."): 293 | new_key = key[6:] 294 | old_keys.append(key) 295 | new_keys.append(new_key) 296 | 297 | # if pivot_pretrained: 298 | # if key.startswith('model.'): 299 | # new_key = key[6:] 300 | # old_keys.append(key) 301 | # new_keys.append(new_key) 302 | # else: 303 | # # NOTE: special treatment for pre-training tasks 304 | # if 'model.roberta' in key: 305 | # new_key = '.'.join(key.split('.')[2:]) 306 | # if new_key: 307 | # old_keys.append(key) 308 | # new_keys.append(new_key) 309 | 310 | for old_key, new_key in zip(old_keys, new_keys): 311 | state_dict[new_key] = state_dict.pop(old_key) 312 | 313 | # check if checkpoint (src model) and target model has same `num_labels` 314 | if "roberta-base" not in pretrained_model_name_or_path: 315 | src_num_labels = state_dict["output_dense.4.weight"].shape[0] 316 | else: 317 | src_num_labels = model.num_labels 318 | tgt_num_labels = model.num_labels 319 | 320 | # copy state_dict so _load_from_state_dict can modify it 321 | metadata = getattr(state_dict, "_metadata", None) 322 | state_dict = state_dict.copy() 323 | if metadata is not None: 324 | state_dict._metadata = metadata 325 | 326 | # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants 327 | # so we need to apply the function recursively. 328 | def load(module: nn.Module, prefix=""): 329 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 330 | 331 | module._load_from_state_dict( 332 | state_dict, 333 | prefix, 334 | local_metadata, 335 | True, 336 | missing_keys, 337 | unexpected_keys, 338 | error_msgs, 339 | ) 340 | for name, child in module._modules.items(): 341 | # the output layer needs to be loaded separately 342 | # print("prefix: ", prefix, "name: ", name) 343 | if src_num_labels != tgt_num_labels: 344 | 345 | if prefix == "output_dense." and name == "4": 346 | print(">>>>>> skip copying output_dense.") 347 | continue 348 | if prefix == "" and name == "classifier": 349 | print(">>>>>> skip copying classifier") 350 | continue 351 | 352 | if child is not None: 353 | load(child, prefix + name + ".") 354 | 355 | # Make sure we are able to load base models as well as derived models (with heads) 356 | start_prefix = "" 357 | model_to_load = model 358 | has_prefix_module = any( 359 | s.startswith(cls.base_model_prefix) for s in state_dict.keys() 360 | ) 361 | 362 | if not hasattr(model, cls.base_model_prefix) and has_prefix_module: 363 | start_prefix = cls.base_model_prefix + "." 
364 | if hasattr(model, cls.base_model_prefix) and not has_prefix_module: 365 | model_to_load = getattr(model, cls.base_model_prefix) 366 | load(model_to_load, prefix=start_prefix) 367 | 368 | if src_num_labels != tgt_num_labels: 369 | 370 | if src_num_labels == 3 and tgt_num_labels == 2: 371 | 372 | # load output_dense.4 in full if current model is trinary 373 | # load partially if current model is binary 374 | 375 | # print("ckpt weight: ", state_dict["output_dense.4.weight"][:2, :5]) 376 | # print("ckpt bias: ", state_dict["output_dense.4.bias"]) 377 | # print("ckpt classifier weight", state_dict["classifier.weight"][:2, :5]) 378 | # print("ckpt classifier bias", state_dict["classifier.bias"]) 379 | 380 | # print(">>>>>>>>>>> BEFORE COPY >>>>>>>>>>>>>>>>") 381 | # print("model.output_dense.weight", model_to_load.output_dense[4].weight[:, :5]) 382 | # print("model.output_dense.bias", model_to_load.output_dense[4].bias) 383 | # print("model.classifier.weight", model_to_load.classifier.weight[:, :5]) 384 | # print("model.classifier.bias", model_to_load.classifier.bias) 385 | 386 | with torch.no_grad(): 387 | model_to_load.output_dense[4].bias.copy_( 388 | state_dict["output_dense.4.bias"][:2] 389 | ) 390 | model_to_load.output_dense[4].weight.copy_( 391 | state_dict["output_dense.4.weight"][:2, :] 392 | ) 393 | model_to_load.classifier.bias.copy_( 394 | state_dict["classifier.bias"][:2] 395 | ) 396 | model_to_load.classifier.weight.copy_( 397 | state_dict["classifier.weight"][:2, :] 398 | ) 399 | # print(">>>>>>>>>>> AFTER COPY >>>>>>>>>>>>>>>>") 400 | # print("model.output_dense.weight", model_to_load.output_dense[4].weight[:, :5]) 401 | # print("model.output_dense.bias", model_to_load.output_dense[4].bias) 402 | # print("model.classifier.weight", model_to_load.classifier.weight[:, :5]) 403 | # print("model.classifier.bias", model_to_load.classifier.bias) 404 | 405 | else: 406 | # in this case, the entire head will be reinitialized 407 | pass 408 | 409 | if model.__class__.__name__ != model_to_load.__class__.__name__: 410 | base_model_state_dict = model_to_load.state_dict().keys() 411 | head_model_state_dict_without_base_prefix = [ 412 | key.split(cls.base_model_prefix + ".")[-1] 413 | for key in model.state_dict().keys() 414 | ] 415 | missing_keys.extend( 416 | head_model_state_dict_without_base_prefix - base_model_state_dict 417 | ) 418 | 419 | # Some models may have keys that are not in the state by design, removing them before needlessly warning 420 | # the user. 421 | if cls._keys_to_ignore_on_load_missing is not None: 422 | for pat in cls._keys_to_ignore_on_load_missing: 423 | missing_keys = [k for k in missing_keys if re.search(pat, k) is None] 424 | 425 | if cls._keys_to_ignore_on_load_unexpected is not None: 426 | for pat in cls._keys_to_ignore_on_load_unexpected: 427 | unexpected_keys = [ 428 | k for k in unexpected_keys if re.search(pat, k) is None 429 | ] 430 | 431 | if len(unexpected_keys) > 0: 432 | logger.warning( 433 | f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " 434 | f"initializing {model.__class__.__name__}: {unexpected_keys}\n" 435 | f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " 436 | f"or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n" 437 | f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect " 438 | f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." 439 | ) 440 | else: 441 | logger.info( 442 | f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n" 443 | ) 444 | if len(missing_keys) > 0: 445 | logger.warning( 446 | f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} " 447 | f"and are newly initialized: {missing_keys}\n" 448 | f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." 449 | ) 450 | else: 451 | logger.info( 452 | f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n" 453 | f"If your task is similar to the task the model of the checkpoint was trained on, " 454 | f"you can already use {model.__class__.__name__} for predictions without further training." 455 | ) 456 | if len(error_msgs) > 0: 457 | raise RuntimeError( 458 | "Error(s) in loading state_dict for {}:\n\t{}".format( 459 | model.__class__.__name__, "\n\t".join(error_msgs) 460 | ) 461 | ) 462 | # make sure token embedding weights are still tied if needed 463 | model.tie_weights() 464 | 465 | # Set model in evaluation mode to deactivate DropOut modules by default 466 | model.eval() 467 | 468 | if output_loading_info: 469 | loading_info = { 470 | "missing_keys": missing_keys, 471 | "unexpected_keys": unexpected_keys, 472 | "error_msgs": error_msgs, 473 | } 474 | return model, loading_info 475 | 476 | return model 477 | 478 | def forward( 479 | self, 480 | input_ids, 481 | sequence_boundary_ids, 482 | target_ids, 483 | attention_mask=None, 484 | labels=None, 485 | **kwargs, 486 | ): 487 | """Run forward pass over entire documents. 
488 | sequence_boundary_ids: torch.LongTensor of shape (batch_size, max_boundary_num), 0 for non-selected 489 | sequence_boundary_mask: torch.FloatTensor of shape (batch_size, max_boundary_num), 1 for not masked, 0 for masked 490 | """ 491 | 492 | input_embeddings = self.roberta.embeddings(input_ids=input_ids) 493 | 494 | outputs = self.roberta( 495 | inputs_embeds=input_embeddings, 496 | attention_mask=attention_mask, 497 | head_mask=None, 498 | output_attentions=None, 499 | output_hidden_states=None, 500 | return_dict=False, 501 | ) 502 | sequence_output = outputs[0] 503 | sequence_output = self.dropout(sequence_output) 504 | boundary_outputs = torch.gather( 505 | input=sequence_output, 506 | index=sequence_boundary_ids.unsqueeze(-1).expand( 507 | -1, -1, sequence_output.shape[-1] 508 | ), 509 | dim=1, 510 | ) 511 | 512 | target_outputs = torch.gather( 513 | input=boundary_outputs, 514 | index=target_ids.unsqueeze(-1).expand(-1, -1, boundary_outputs.shape[-1]), 515 | dim=1, 516 | ) 517 | candidate_outputs = boundary_outputs 518 | concat = torch.cat( 519 | ( 520 | target_outputs.expand(-1, candidate_outputs.shape[1], -1), 521 | candidate_outputs, 522 | ), 523 | 2, 524 | ) 525 | 526 | logits = self.output_dense(concat) 527 | 528 | if labels is not None: 529 | loss_fct = nn.CrossEntropyLoss(ignore_index=-1) 530 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 531 | return logits, loss 532 | else: 533 | return logits 534 | --------------------------------------------------------------------------------
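A minimal usage sketch for the span-classification head defined in modeling.py above, assuming network access to (or a local copy of) roberta-base and the pinned transformers==4.10.3 API; the dummy batch, tensor sizes, and binary label choice are illustrative assumptions, not part of the repository:

import torch
from argument_relation_transformer.modeling import RobertaForDocumentSpanClassification

# Instantiate the custom head on top of roberta-base weights
# (binary labels assumed here purely for illustration).
model = RobertaForDocumentSpanClassification.from_pretrained(
    "roberta-base", num_labels=2, output_layers=1
)

batch_size, seq_len, num_props = 2, 64, 5
input_ids = torch.randint(0, model.config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
# Token positions of the proposition boundary markers within each document.
sequence_boundary_ids = torch.randint(0, seq_len, (batch_size, num_props))
# Index (into the boundary dimension) of the target proposition per example.
target_ids = torch.zeros(batch_size, 1, dtype=torch.long)
# One label per candidate proposition; -1 entries are ignored by the loss.
labels = torch.randint(0, 2, (batch_size, num_props))

logits, loss = model(
    input_ids,
    sequence_boundary_ids,
    target_ids,
    attention_mask=attention_mask,
    labels=labels,
)
print(logits.shape)  # torch.Size([batch_size, num_props, num_labels])

The forward pass gathers the hidden state at each boundary position, pairs every candidate proposition with the gathered target proposition, and scores each pair with the output_dense head, so the logits carry one row per candidate span.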