├── maverick ├── utils │ ├── __init__.py │ ├── loggingl.py │ ├── logging.py │ ├── sampler.py │ └── upload_to_hf.py ├── models │ ├── __init__.py │ ├── test.py │ ├── maverick_model.py │ ├── pl_modules.py │ ├── model_incr.py │ ├── model_s2e.py │ └── model_mes.py ├── __init__.py ├── common │ ├── constants.py │ ├── metrics.py │ └── util.py ├── data │ ├── pl_data_modules.py │ └── datasets.py ├── train.py └── evaluate.py ├── conf ├── evaluation │ └── default_evaluation.yaml ├── logging │ └── wandb_logging.yaml ├── root.yaml ├── model │ ├── s2e │ │ ├── deberta-base.yaml │ │ ├── deberta-large.yaml │ │ ├── longformer-base.yaml │ │ └── longformer-large.yaml │ ├── mes │ │ ├── deberta-base.yaml │ │ ├── deberta-large.yaml │ │ ├── longformer-base.yaml │ │ └── longformer-large.yaml │ └── incr │ │ ├── deberta-base.yaml │ │ ├── deberta-large.yaml │ │ ├── longformer-base.yaml │ │ └── longformer-large.yaml ├── data │ ├── gap.yaml │ ├── litbank-ns.yaml │ ├── preco-ns.yaml │ ├── litbank.yaml │ ├── preco.yaml │ ├── wikicoref.yaml │ └── ontonotes.yaml └── train │ └── default_train.yaml ├── requirements.txt ├── setup.py ├── setup.sh ├── .gitignore ├── README.md └── LICENSE.txt /maverick/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /maverick/models/__init__.py: -------------------------------------------------------------------------------- 1 | from maverick.models import * 2 | -------------------------------------------------------------------------------- /conf/evaluation/default_evaluation.yaml: -------------------------------------------------------------------------------- 1 | # put args here 2 | checkpoint: "-" 3 | device: "cuda:0" #cpu 4 | -------------------------------------------------------------------------------- /maverick/__init__.py: -------------------------------------------------------------------------------- 1 | from maverick.models.maverick_model import Maverick 2 | from maverick.models.model_mes import Maverick_mes 3 | from maverick.models.model_s2e import Maverick_s2e 4 | from maverick.models.model_incr import Maverick_incr 5 | -------------------------------------------------------------------------------- /conf/logging/wandb_logging.yaml: -------------------------------------------------------------------------------- 1 | # don't forget loggers.login() for the first usage. 
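# (in practice this means authenticating once, e.g. by running `wandb login` from the CLI or calling `wandb.login()` in Python, before enabling online logging)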
2 | 3 | log: True # set to False to avoid the logging 4 | 5 | wandb_arg: 6 | _target_: pytorch_lightning.loggers.WandbLogger 7 | name: ${train.model_name} 8 | project: ${train.project_name} 9 | save_dir: ./ 10 | log_model: False 11 | mode: 'online' 12 | 13 | watch: 14 | log: 'all' 15 | log_freq: 100 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1.tar.gz 2 | datasets>=2.14 3 | hydra-core==1.3 4 | hydra_colorlog==1.2 5 | jsonlines==4.0.0 6 | nltk==3.8.1 7 | pandas>=2.0 8 | protobuf==3.20 9 | pytorch-lightning>=1.8.0 10 | rich==13.6 11 | scipy>=1.10 12 | sentencepiece==0.2.0 13 | spacy>=3.7 14 | transformers>=4.34 15 | wandb==0.20.1 16 | -------------------------------------------------------------------------------- /conf/root.yaml: -------------------------------------------------------------------------------- 1 | # Required to make the "experiments" dir the default one for the output of the models 2 | hydra: 3 | run: 4 | dir: ./experiments/${train.model_name}/${now:%Y-%m-%d}/${now:%H-%M-%S} 5 | 6 | model_name: "maverick-coref/${model.module.model.huggingface_model_name}" 7 | 8 | defaults: 9 | - _self_ 10 | - train: default_train 11 | - model: mes/deberta-large 12 | - data: ontonotes 13 | - evaluation: default_evaluation 14 | - logging: wandb_logging 15 | - override hydra/job_logging: colorlog 16 | - override hydra/hydra_logging: colorlog 17 | 18 | -------------------------------------------------------------------------------- /conf/model/s2e/deberta-base.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_s2e.Model 18 | language_model: "deberta-v3-base" 19 | huggingface_model_name: "microsoft/deberta-v3-base" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | -------------------------------------------------------------------------------- /conf/model/s2e/deberta-large.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_s2e.Model 18 | language_model: "deberta-v3-large" 19 | huggingface_model_name: "microsoft/deberta-v3-large" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" -------------------------------------------------------------------------------- /conf/model/mes/deberta-base.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: 
transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_mes.Maverick_mes 18 | language_model: "deberta-v3-base" 19 | huggingface_model_name: "microsoft/deberta-v3-base" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | -------------------------------------------------------------------------------- /conf/model/mes/deberta-large.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_mes.Maverick_mes 18 | language_model: "deberta-v3-large" 19 | huggingface_model_name: "microsoft/deberta-v3-large" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" -------------------------------------------------------------------------------- /conf/model/mes/longformer-base.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_mes.Maverick_mes 18 | language_model: "longformer-base-4096" 19 | huggingface_model_name: "allenai/longformer-base-4096" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" -------------------------------------------------------------------------------- /conf/model/s2e/longformer-base.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_s2e.Model 18 | language_model: "longformer-base-4096" 19 | huggingface_model_name: "allenai/longformer-base-4096" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | -------------------------------------------------------------------------------- /conf/model/s2e/longformer-large.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_s2e.Model 18 | language_model: "longformer-large-4096" 19 | huggingface_model_name: "allenai/longformer-large-4096" 20 | 
freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | -------------------------------------------------------------------------------- /conf/data/gap.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: data.pl_data_modules.BasePLDataModule 3 | dataset: 4 | test: 5 | _target_: data.datasets.OntonotesDataset 6 | name: 'test' 7 | path: 'data/gap/data/gap-test-onto.jsonl' 8 | batch_size: "${data.datamodule.batch_sizes.train}" 9 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/test' 10 | tokenizer: "${model.module.model.huggingface_model_name}" 11 | 12 | 13 | 14 | batch_sizes: 15 | train: 1 16 | val: 1 17 | test: 1 18 | 19 | num_workers: 20 | train: 0 21 | val: 0 22 | test: 0 23 | 24 | -------------------------------------------------------------------------------- /conf/model/mes/longformer-large.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_mes.Maverick_mes 18 | language_model: "longformer-large-4096" 19 | huggingface_model_name: "allenai/longformer-large-4096" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | -------------------------------------------------------------------------------- /conf/data/litbank-ns.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: data.pl_data_modules.BasePLDataModule 3 | dataset: 4 | test: 5 | _target_: data.datasets.OntonotesDataset 6 | name: 'test' 7 | path: 'data/litbank/official/test.english.jsonlines' 8 | batch_size: "${data.datamodule.batch_sizes.train}" 9 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/test' 10 | tokenizer: "${model.module.model.huggingface_model_name}" 11 | 12 | 13 | 14 | batch_sizes: 15 | train: 1 16 | val: 1 17 | test: 1 18 | 19 | num_workers: 20 | train: 0 21 | val: 0 22 | test: 0 23 | 24 | -------------------------------------------------------------------------------- /conf/data/preco-ns.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: data.pl_data_modules.BasePLDataModule 3 | dataset: 4 | test: 5 | _target_: data.datasets.OntonotesDataset 6 | name: 'test' 7 | path: 'data/prepare_preco/test_preprocessed_preco_no_singletons.jsonl' 8 | batch_size: "${data.datamodule.batch_sizes.train}" 9 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/test' 10 | tokenizer: "${model.module.model.huggingface_model_name}" 11 | 12 | 13 | 14 | batch_sizes: 15 | train: 1 16 | val: 1 17 | test: 1 18 | 19 | num_workers: 20 | train: 0 21 | val: 0 22 | test: 0 23 | 24 | -------------------------------------------------------------------------------- /conf/model/incr/deberta-base.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 
| RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_incr.Maverick_incr 18 | language_model: "deberta-v3-base" 19 | huggingface_model_name: "microsoft/deberta-v3-base" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | incremental_model_hidden_size: 768 23 | incremental_model_num_layers: 1 24 | -------------------------------------------------------------------------------- /conf/model/incr/deberta-large.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_incr.Maverick_incr 18 | language_model: "deberta-v3-large" 19 | huggingface_model_name: "microsoft/deberta-v3-large" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | incremental_model_hidden_size: 768 23 | incremental_model_num_layers: 1 24 | -------------------------------------------------------------------------------- /conf/model/incr/longformer-base.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_incr.Maverick_incr 18 | language_model: "longformer-base-4096" 19 | huggingface_model_name: "allenai/longformer-base-4096" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | incremental_model_hidden_size: 768 23 | incremental_model_num_layers: 1 24 | 25 | -------------------------------------------------------------------------------- /conf/model/incr/longformer-large.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | _target_: models.pl_modules.BasePLModule 3 | RAdam: 4 | _target_: torch.optim.RAdam 5 | lr: 2e-5 6 | Adafactor: 7 | _target_: transformers.Adafactor 8 | lr: 3e-5 9 | weight_decay: 0.01 10 | scale_parameter: False 11 | relative_step: False 12 | lr_scheduler: 13 | num_warmup_steps: 6000 14 | num_training_steps: 80000 15 | opt: "Adafactor" #RAdam 16 | model: 17 | _target_: models.model_incr.Maverick_incr 18 | language_model: "longformer-large-4096" 19 | huggingface_model_name: "allenai/longformer-large-4096" 20 | freeze_encoder: False 21 | span_representation: "concat_start_end" 22 | incremental_model_hidden_size: 768 23 | incremental_model_num_layers: 1 24 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="maverick-coref", 5 | version="1.0.7", 6 | 
long_description=open("README.md").read(), 7 | long_description_content_type="text/markdown", 8 | url="https://github.com/sapienzanlp/maverick-coref", 9 | packages=["maverick", "maverick.models", "maverick.data", "maverick.utils", "maverick.common"], 10 | install_requires=[ 11 | "huggingface-hub>=0.19.0", 12 | "hydra-core>=1.2", 13 | "jsonlines==4.0.0", 14 | "nltk==3.8.1", 15 | "pandas>=1.3.5", 16 | "protobuf==3.20", 17 | "pytorch-lightning>=1.8.0", 18 | "sentencepiece==0.2.0", 19 | "scipy>=1.10", 20 | "spacy>=3.7.2", 21 | "transformers>=4.34", 22 | ], 23 | python_requires=">=3.8.0", 24 | author="Giuliano Martinelli", 25 | author_email="giuliano.martinelli97@gmail.com", 26 | zip_safe=False, 27 | ) 28 | -------------------------------------------------------------------------------- /maverick/common/constants.py: -------------------------------------------------------------------------------- 1 | PRONOUNS_GROUPS = { 2 | "i": 0, 3 | "me": 0, 4 | "my": 0, 5 | "mine": 0, 6 | "myself": 0, 7 | "you": 1, 8 | "your": 1, 9 | "yours": 1, 10 | "yourself": 1, 11 | "yourselves": 1, 12 | "he": 2, 13 | "him": 2, 14 | "his": 2, 15 | "himself": 2, 16 | "she": 3, 17 | "her": 3, 18 | "hers": 3, 19 | "herself": 3, 20 | "it": 4, 21 | "its": 4, 22 | "itself": 4, 23 | "we": 5, 24 | "us": 5, 25 | "our": 5, 26 | "ours": 5, 27 | "ourselves": 5, 28 | "they": 6, 29 | "them": 6, 30 | "their": 6, 31 | "themselves": 6, 32 | "that": 7, 33 | "this": 7, 34 | } 35 | 36 | STOPWORDS = { 37 | "'s", 38 | "a", 39 | "all", 40 | "an", 41 | "and", 42 | "at", 43 | "for", 44 | "from", 45 | "in", 46 | "into", 47 | "more", 48 | "of", 49 | "on", 50 | "or", 51 | "some", 52 | "the", 53 | "these", 54 | "those", 55 | } 56 | 57 | MODELS = ["mes", "s2e", "incr", None] 58 | DATASETS = ["ontonotes", "litbank", "preco"] 59 | CATEGORIES = {"pron-pron-comp": 0, "pron-pron-no-comp": 1, "pron-ent": 2, "match": 3, "contain": 4, "other": 5} 60 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | conda activate maverick-coref 3 | 4 | cd data/prepare_ontonotes 5 | conda create -y --name py27 python=2.7 6 | 7 | eval "$(conda shell.bash hook)" 8 | 9 | dlx() { 10 | wget $1/$2 11 | tar -xvzf $2 12 | rm $2 13 | } 14 | 15 | conll_url=http://conll.cemantix.org/2012/download 16 | dlx $conll_url conll-2012-train.v4.tar.gz 17 | dlx $conll_url conll-2012-development.v4.tar.gz 18 | dlx $conll_url/test conll-2012-test-key.tar.gz 19 | dlx $conll_url/test conll-2012-test-official.v9.tar.gz 20 | dlx $conll_url conll-2012-scripts.v3.tar.gz 21 | 22 | dlx http://conll.cemantix.org/download reference-coreference-scorers.v8.01.tar.gz 23 | mv reference-coreference-scorers conll-2012/scorer 24 | 25 | tar -xzvf ontonotes-release-5.0_LDC2013T19.tgz 26 | ontonotes_path=ontonotes-release-5.0 27 | 28 | conda activate py27 29 | bash conll-2012/v3/scripts/skeleton2conll.sh -D $ontonotes_path/data/files/data conll-2012 30 | 31 | function compile_partition() { 32 | rm -f $2.$5.$3$4 33 | cat conll-2012/$3/data/$1/data/$5/annotations/*/*/*/*.$3$4 >> $2.$5.$3$4 34 | } 35 | 36 | function compile_language() { 37 | compile_partition development dev v4 _gold_conll $1 38 | compile_partition train train v4 _gold_conll $1 39 | compile_partition test test v4 _gold_conll $1 40 | } 41 | 42 | compile_language english 43 | 44 | conda activate maverick-coref 45 | python minimize.py 46 | python prepare_tests_files.py 47 | 48 | 
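# At this point the CoNLL-2012 annotations have been merged by compile_language into
# train/dev/test *.english.v4_gold_conll files; minimize.py and prepare_tests_files.py
# are expected to convert them into the *.english.jsonlines files referenced by the
# data configs under conf/data/ (exact output paths may differ per setup).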
-------------------------------------------------------------------------------- /conf/train/default_train.yaml: -------------------------------------------------------------------------------- 1 | # reproducibility 2 | seed: 30 3 | 4 | model_name: ${model_name} # used to name the directory in which model's checkpoints will be stored (experiments/model_name/...) 5 | project_name: maverick # used to name the project in wandb 6 | export: False 7 | 8 | # pl_trainer 9 | pl_trainer: 10 | _target_: pytorch_lightning.Trainer 11 | log_every_n_steps: 10 12 | accelerator: gpu 13 | devices: 1 14 | num_nodes: 1 15 | accumulate_grad_batches: 4 16 | gradient_clip_val: 1.0 17 | val_check_interval: 0.5 # you can specify an int "n" here => validation every "n" steps 18 | max_epochs: 300 19 | deterministic: True 20 | fast_dev_run: False 21 | precision: 16 22 | num_sanity_val_steps: 0 23 | 24 | 25 | # early stopping callback 26 | # "early_stopping_callback: null" will disable early stopping 27 | early_stopping_callback: 28 | _target_: pytorch_lightning.callbacks.EarlyStopping 29 | monitor: val/conll2012_f1_score 30 | mode: max 31 | patience: 120 32 | 33 | # model_checkpoint_callback 34 | # "model_checkpoint_callback: null" will disable model checkpointing 35 | model_checkpoint_callback: 36 | _target_: pytorch_lightning.callbacks.ModelCheckpoint 37 | monitor: val/conll2012_f1_score 38 | mode: max 39 | verbose: True 40 | save_top_k: 1 41 | filename: 'checkpoint-val_f1_{val/conll2012_f1_score:.4f}-epoch_{epoch:02d}' 42 | auto_insert_metric_name: False 43 | 44 | learning_rate_callback: 45 | _target_: pytorch_lightning.callbacks.LearningRateMonitor 46 | logging_interval: "step" 47 | log_momentum: False -------------------------------------------------------------------------------- /conf/data/litbank.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: data.pl_data_modules.BasePLDataModule 3 | dataset: 4 | train: 5 | _target_: data.datasets.OntonotesDataset 6 | name: 'train' 7 | path: 'data/litbank/official/train-splitted.english.jsonlines' 8 | batch_size: "${data.datamodule.batch_sizes.train}" 9 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/train' 10 | tokenizer: "${model.module.model.huggingface_model_name}" 11 | max_doc_len: 1500 12 | val: 13 | _target_: data.datasets.OntonotesDataset 14 | name: 'val' 15 | path: 'data/litbank/official/dev.english.jsonlines' 16 | batch_size: "${data.datamodule.batch_sizes.train}" 17 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/val' 18 | tokenizer: "${model.module.model.huggingface_model_name}" 19 | test: 20 | _target_: data.datasets.OntonotesDataset 21 | name: 'test' 22 | path: 'data/litbank/official/test.english.jsonlines' 23 | 24 | batch_size: "${data.datamodule.batch_sizes.train}" 25 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/test' 26 | tokenizer: "${model.module.model.huggingface_model_name}" 27 | 28 | 29 | 30 | batch_sizes: 31 | train: 1 32 | val: 1 33 | test: 1 34 | 35 | num_workers: 36 | train: 0 37 | val: 0 38 | test: 0 39 | 40 | -------------------------------------------------------------------------------- /conf/data/preco.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | 
_target_: data.pl_data_modules.BasePLDataModule 3 | dataset: 4 | train: 5 | _target_: data.datasets.OntonotesDataset 6 | name: 'train' 7 | path: 'data/prepare_preco/train_36120_preprocessed_preco.jsonl' 8 | batch_size: "${data.datamodule.batch_sizes.train}" 9 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/train' 10 | tokenizer: "${model.module.model.huggingface_model_name}" 11 | max_doc_len: 1500 12 | val: 13 | _target_: data.datasets.OntonotesDataset 14 | name: 'val' 15 | path: 'data/prepare_preco/dev_500_preprocessed_preco.jsonl' 16 | batch_size: "${data.datamodule.batch_sizes.train}" 17 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/val' 18 | tokenizer: "${model.module.model.huggingface_model_name}" 19 | test: 20 | _target_: data.datasets.OntonotesDataset 21 | name: 'test' 22 | path: 'data/prepare_preco/test_preprocessed_preco.jsonl' 23 | batch_size: "${data.datamodule.batch_sizes.train}" 24 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/test' 25 | tokenizer: "${model.module.model.huggingface_model_name}" 26 | 27 | 28 | 29 | batch_sizes: 30 | train: 1 31 | val: 1 32 | test: 1 33 | 34 | num_workers: 35 | train: 0 36 | val: 0 37 | test: 0 38 | 39 | -------------------------------------------------------------------------------- /conf/data/wikicoref.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: data.pl_data_modules.BasePLDataModule 3 | dataset: 4 | train: 5 | _target_: data.datasets.OntonotesDataset 6 | name: 'train' 7 | path: 'data/prepare_ontonotes/train.english.jsonlines' 8 | # path: 'data/prepare_preco/train_36120_preprocessed_preco.jsonl' 9 | # path: 'data/litbank/official/train_enhanced_1.english.jsonlines' 10 | # path: 'data/prepare_ontonotes_arabic/train.arabic.jsonlines' 11 | #path: 'data/prepare_ontonotes_chinese/train.chinese.jsonlines' 12 | batch_size: "${data.datamodule.batch_sizes.train}" 13 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/train' 14 | tokenizer: "${model.module.model.huggingface_model_name}" 15 | max_doc_len: 1500 16 | val: 17 | _target_: data.datasets.OntonotesDataset 18 | name: 'val' 19 | path: 'data/prepare_ontonotes/dev.english.jsonlines' 20 | # path: 'data/prepare_preco/dev_500_preprocessed_preco.jsonl' 21 | # path: 'data/litbank/official/dev.english.jsonlines' 22 | # path: 'data/prepare_ontonotes_arabic/dev.arabic.jsonlines' 23 | # path: 'data/prepare_ontonotes_chinese/dev.chinese.jsonlines' 24 | batch_size: "${data.datamodule.batch_sizes.train}" 25 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/val' 26 | tokenizer: "${model.module.model.huggingface_model_name}" 27 | test: 28 | _target_: data.datasets.OntonotesDataset 29 | name: 'test' 30 | path: 'data/wkc/test.jsonlines' 31 | batch_size: "${data.datamodule.batch_sizes.train}" 32 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/test' 33 | tokenizer: "${model.module.model.huggingface_model_name}" 34 | 35 | 36 | 37 | batch_sizes: 38 | train: 1 39 | val: 1 40 | test: 1 41 | 42 | num_workers: 43 | train: 0 
44 | val: 0 45 | test: 0 46 | 47 | -------------------------------------------------------------------------------- /maverick/data/pl_data_modules.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Any, Union, List, Optional, Sequence 3 | 4 | import hydra 5 | import pytorch_lightning as pl 6 | import torch 7 | from omegaconf import DictConfig 8 | from pytorch_lightning.utilities.types import EVAL_DATALOADERS 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | 12 | class BasePLDataModule(pl.LightningDataModule): 13 | def __init__( 14 | self, 15 | dataset: DictConfig, 16 | batch_sizes: DictConfig, 17 | num_workers: DictConfig, 18 | *args, 19 | **kwargs, 20 | ): 21 | super().__init__() 22 | self.dataset = dataset 23 | self.num_workers = num_workers 24 | self.batch_sizes = batch_sizes 25 | 26 | # data 27 | self.train_dataset: Optional[Dataset] = None 28 | self.val_dataset: Optional[Dataset] = None 29 | self.test_dataset: Optional[Dataset] = None 30 | 31 | def setup(self, stage: Optional[str] = None): 32 | if stage == "fit" or stage is None: 33 | self.val_dataset = hydra.utils.instantiate(self.dataset.val) 34 | self.train_dataset = hydra.utils.instantiate(self.dataset.train) 35 | if stage == "test" or stage is None: 36 | self.test_dataset = hydra.utils.instantiate(self.dataset.test) 37 | 38 | def train_dataloader(self, *args, **kwargs) -> DataLoader: 39 | return DataLoader( 40 | self.train_dataset, 41 | shuffle=True, 42 | batch_size=self.batch_sizes.train, 43 | num_workers=self.num_workers.train, 44 | collate_fn=self.train_dataset.collate_fn, 45 | ) 46 | 47 | def val_dataloader(self, *args, **kwargs) -> DataLoader: 48 | return DataLoader( 49 | self.val_dataset, 50 | shuffle=False, 51 | batch_size=self.batch_sizes.val, 52 | num_workers=self.num_workers.val, 53 | collate_fn=self.val_dataset.collate_fn, 54 | ) 55 | 56 | def test_dataloader(self, *args, **kwargs) -> DataLoader: 57 | return DataLoader( 58 | self.test_dataset, 59 | shuffle=False, 60 | batch_size=self.batch_sizes.test, 61 | num_workers=self.num_workers.test, 62 | collate_fn=self.test_dataset.collate_fn, 63 | ) 64 | 65 | def __repr__(self) -> str: 66 | return f"{self.__class__.__name__}(" f"{self.dataset=}, " f"{self.num_workers=}, " f"{self.batch_sizes=})" 67 | -------------------------------------------------------------------------------- /conf/data/ontonotes.yaml: -------------------------------------------------------------------------------- 1 | datamodule: 2 | _target_: data.pl_data_modules.BasePLDataModule 3 | dataset: 4 | train: 5 | _target_: data.datasets.OntonotesDataset 6 | name: 'train' 7 | path: 'data/prepare_ontonotes/train.english.jsonlines' 8 | # path: 'data/prepare_preco/train_36120_preprocessed_preco.jsonl' 9 | # path: 'data/litbank/official/train_enhanced_1.english.jsonlines' 10 | # path: 'data/prepare_ontonotes_arabic/train.arabic.jsonlines' 11 | #path: 'data/prepare_ontonotes_chinese/train.chinese.jsonlines' 12 | batch_size: "${data.datamodule.batch_sizes.train}" 13 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/train' 14 | tokenizer: "${model.module.model.huggingface_model_name}" 15 | max_doc_len: 1500 16 | val: 17 | _target_: data.datasets.OntonotesDataset 18 | name: 'val' 19 | path: 'data/prepare_ontonotes/dev.english.jsonlines' 20 | # path: 'data/prepare_preco/dev_500_preprocessed_preco.jsonl' 21 | # path: 
'data/litbank/official/dev.english.jsonlines' 22 | # path: 'data/prepare_ontonotes_arabic/dev.arabic.jsonlines' 23 | # path: 'data/prepare_ontonotes_chinese/dev.chinese.jsonlines' 24 | batch_size: "${data.datamodule.batch_sizes.train}" 25 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/val' 26 | tokenizer: "${model.module.model.huggingface_model_name}" 27 | test: 28 | _target_: data.datasets.OntonotesDataset 29 | name: 'test' 30 | path: 'data/prepare_ontonotes/test.english.jsonlines' 31 | # path: 'data/prepare_preco/test_preprocessed_preco.jsonl' 32 | # path: 'data/litbank/official/test.english.jsonlines' 33 | # path: 'data/prepare_ontonotes_arabic/test.arabic.jsonlines' 34 | # path: 'data/prepare_ontonotes_chinese/test.chinese.jsonlines' 35 | # path: 'data/longtonotes/test/test.english.jsonlines' 36 | # path: 'data/gap/gap-test-ontoformat_3.jsonl' 37 | # path: 'data/wkc/test.jsonlines' 38 | # path: 'data/prepare_preco/test.english.jsonlines' #gap/gap-ontoformat-jsonl 39 | 40 | batch_size: "${data.datamodule.batch_sizes.train}" 41 | processed_dataset_path: 'data/cache_${data.datamodule.dataset.train.max_doc_len}/${model.module.model.huggingface_model_name}/ontonotes/test' 42 | tokenizer: "${model.module.model.huggingface_model_name}" 43 | 44 | 45 | 46 | batch_sizes: 47 | train: 1 48 | val: 1 49 | test: 1 50 | 51 | num_workers: 52 | train: 0 53 | val: 0 54 | test: 0 55 | 56 | -------------------------------------------------------------------------------- /maverick/models/test.py: -------------------------------------------------------------------------------- 1 | from maverick import Maverick 2 | 3 | a = Maverick() 4 | # r = a.predict("Hi i am Giuliano, i like pasta. 
Giuliano, my cool name") 5 | print( 6 | a.predict( 7 | [ 8 | [ 9 | "--", 10 | "basically", 11 | ",", 12 | "it", 13 | "was", 14 | "unanimously", 15 | "agreed", 16 | "upon", 17 | "by", 18 | "the", 19 | "various", 20 | "relevant", 21 | "parties", 22 | ".", 23 | ], 24 | [ 25 | "To", 26 | "express", 27 | "its", 28 | "determination", 29 | ",", 30 | "the", 31 | "Chinese", 32 | "securities", 33 | "regulatory", 34 | "department", 35 | "compares", 36 | "this", 37 | "stock", 38 | "reform", 39 | "to", 40 | "a", 41 | "die", 42 | "that", 43 | "has", 44 | "been", 45 | "cast", 46 | ".", 47 | ], 48 | [ 49 | "It", 50 | "takes", 51 | "time", 52 | "to", 53 | "prove", 54 | "whether", 55 | "the", 56 | "stock", 57 | "reform", 58 | "can", 59 | "really", 60 | "meet", 61 | "expectations", 62 | ",", 63 | "and", 64 | "whether", 65 | "any", 66 | "deviations", 67 | "that", 68 | "arise", 69 | "during", 70 | "the", 71 | "stock", 72 | "reform", 73 | "can", 74 | "be", 75 | "promptly", 76 | "corrected", 77 | ".", 78 | ], 79 | ["Dear", "viewers", ",", "the", "China", "News", "program", "will", "end", "here", "."], 80 | ["This", "is", "Xu", "Li", "."], 81 | ["Thank", "you", "everyone", "for", "watching", "."], 82 | ["Coming", "up", "is", "the", "Focus", "Today", "program", "hosted", "by", "Wang", "Shilin", "."], 83 | ["Good-bye", ",", "dear", "viewers", "."], 84 | ], 85 | # predefined_mentions=[[16, 16], [19, 23], [42, 44], [57, 59], [25, 27], [83, 83], [82, 82]], 86 | add_gold_clusters=[[[1, 1], [2, 2]]], 87 | singletons=False, 88 | ) 89 | ) 90 | -------------------------------------------------------------------------------- /maverick/utils/loggingl.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import threading 4 | from typing import Optional 5 | 6 | from rich.console import Console 7 | 8 | _lock = threading.Lock() 9 | _default_handler: Optional[logging.Handler] = None 10 | 11 | _default_log_level = logging.WARNING 12 | 13 | # fancy logger 14 | _console = Console() 15 | 16 | 17 | def _get_library_name() -> str: 18 | return __name__.split(".")[0] 19 | 20 | 21 | def _get_library_root_logger() -> logging.Logger: 22 | return logging.getLogger(_get_library_name()) 23 | 24 | 25 | def _configure_library_root_logger() -> None: 26 | 27 | global _default_handler 28 | 29 | with _lock: 30 | if _default_handler: 31 | # This library has already configured the library root logger. 32 | return 33 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 34 | _default_handler.flush = sys.stderr.flush 35 | 36 | # Apply our default configuration to the library root logger. 37 | library_root_logger = _get_library_root_logger() 38 | library_root_logger.addHandler(_default_handler) 39 | library_root_logger.setLevel(_default_log_level) 40 | library_root_logger.propagate = False 41 | 42 | 43 | def _reset_library_root_logger() -> None: 44 | 45 | global _default_handler 46 | 47 | with _lock: 48 | if not _default_handler: 49 | return 50 | 51 | library_root_logger = _get_library_root_logger() 52 | library_root_logger.removeHandler(_default_handler) 53 | library_root_logger.setLevel(logging.NOTSET) 54 | _default_handler = None 55 | 56 | 57 | def set_log_level(level: int, logger: logging.Logger = None) -> None: 58 | """ 59 | Set the log level. 60 | Args: 61 | level (:obj:`int`): 62 | Logging level. 63 | logger (:obj:`logging.Logger`): 64 | Logger to set the log level. 
65 | """ 66 | if not logger: 67 | _configure_library_root_logger() 68 | logger = _get_library_root_logger() 69 | logger.setLevel(level) 70 | 71 | 72 | def get_logger( 73 | name: Optional[str] = None, 74 | level: Optional[int] = None, 75 | formatter: Optional[str] = None, 76 | ) -> logging.Logger: 77 | """ 78 | Return a logger with the specified name. 79 | """ 80 | 81 | if name is None: 82 | name = _get_library_name() 83 | 84 | _configure_library_root_logger() 85 | 86 | if level is not None: 87 | set_log_level(level) 88 | 89 | if formatter is None: 90 | formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s") 91 | _default_handler.setFormatter(formatter) 92 | 93 | return logging.getLogger(name) 94 | 95 | 96 | def get_console_logger(): 97 | return _console 98 | -------------------------------------------------------------------------------- /maverick/utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import threading 4 | from typing import Optional 5 | 6 | from rich.console import Console 7 | 8 | _lock = threading.Lock() 9 | _default_handler: Optional[logging.Handler] = None 10 | 11 | _default_log_level = logging.WARNING 12 | 13 | # fancy logger 14 | _console = Console() 15 | 16 | 17 | def _get_library_name() -> str: 18 | return __name__.split(".")[0] 19 | 20 | 21 | def _get_library_root_logger() -> logging.Logger: 22 | return logging.getLogger(_get_library_name()) 23 | 24 | 25 | def _configure_library_root_logger() -> None: 26 | 27 | global _default_handler 28 | 29 | with _lock: 30 | if _default_handler: 31 | # This library has already configured the library root logger. 32 | return 33 | _default_handler = logging.StreamHandler() # Set sys.stderr as stream. 34 | _default_handler.flush = sys.stderr.flush 35 | 36 | # Apply our default configuration to the library root logger. 37 | library_root_logger = _get_library_root_logger() 38 | library_root_logger.addHandler(_default_handler) 39 | library_root_logger.setLevel(_default_log_level) 40 | library_root_logger.propagate = False 41 | 42 | 43 | def _reset_library_root_logger() -> None: 44 | 45 | global _default_handler 46 | 47 | with _lock: 48 | if not _default_handler: 49 | return 50 | 51 | library_root_logger = _get_library_root_logger() 52 | library_root_logger.removeHandler(_default_handler) 53 | library_root_logger.setLevel(logging.NOTSET) 54 | _default_handler = None 55 | 56 | 57 | def set_log_level(level: int, logger: logging.Logger = None) -> None: 58 | """ 59 | Set the log level. 60 | Args: 61 | level (:obj:`int`): 62 | Logging level. 63 | logger (:obj:`logging.Logger`): 64 | Logger to set the log level. 65 | """ 66 | if not logger: 67 | _configure_library_root_logger() 68 | logger = _get_library_root_logger() 69 | logger.setLevel(level) 70 | 71 | 72 | def get_logger( 73 | name: Optional[str] = None, 74 | level: Optional[int] = None, 75 | formatter: Optional[str] = None, 76 | ) -> logging.Logger: 77 | """ 78 | Return a logger with the specified name. 
79 | """ 80 | 81 | if name is None: 82 | name = _get_library_name() 83 | 84 | _configure_library_root_logger() 85 | 86 | if level is not None: 87 | set_log_level(level) 88 | 89 | if formatter is None: 90 | formatter = logging.Formatter( 91 | "%(asctime)s - %(levelname)s - %(name)s - %(message)s" 92 | ) 93 | _default_handler.setFormatter(formatter) 94 | 95 | return logging.getLogger(name) 96 | 97 | 98 | def get_console_logger(): 99 | return _console 100 | -------------------------------------------------------------------------------- /maverick/utils/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from torch.utils.data.sampler import BatchSampler 4 | from torch.utils.data.sampler import Sampler 5 | from torch.utils.data.sampler import SubsetRandomSampler 6 | 7 | 8 | def identity(x): 9 | return x 10 | 11 | 12 | class SortedSampler(Sampler): 13 | """ 14 | Samples elements sequentially, always in the same order. 15 | 16 | Args: 17 | data (`obj`: `Iterable`): 18 | Iterable data. 19 | sort_key (`obj`: `Callable`): 20 | Specifies a function of one argument that is used to 21 | extract a numerical comparison key from each list element. 22 | 23 | Example: 24 | >>> list(SortedSampler(range(10), sort_key=lambda i: -i)) 25 | [9, 8, 7, 6, 5, 4, 3, 2, 1, 0] 26 | 27 | """ 28 | 29 | def __init__(self, data, sort_key=identity): 30 | super().__init__(data) 31 | self.data = data 32 | self.sort_key = sort_key 33 | zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)] 34 | zip_ = sorted(zip_, key=lambda r: r[1]) 35 | self.sorted_indexes = [item[0] for item in zip_] 36 | 37 | def __iter__(self): 38 | return iter(self.sorted_indexes) 39 | 40 | def __len__(self): 41 | return len(self.data) 42 | 43 | 44 | class BucketBatchSampler(BatchSampler): 45 | """ 46 | `BucketBatchSampler` toggles between `sampler` batches and sorted batches. 47 | Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between 48 | random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted and vice 49 | versa. 50 | Background: 51 | ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like 52 | ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples with a similar 53 | size length to reduce the padding required for each batch while maintaining some noise 54 | through bucketing. 55 | **AllenNLP Implementation:** 56 | https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py 57 | **torchtext Implementation:** 58 | https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225 59 | 60 | Args: 61 | sampler (`obj`: `torch.data.utils.sampler.Sampler): 62 | batch_size (`int`): 63 | Size of mini-batch. 64 | drop_last (`bool`, optional, defaults to `False`): 65 | If `True` the sampler will drop the last batch if its size would be less than `batch_size`. 66 | sort_key (`obj`: `Callable`, optional, defaults to `identity`): 67 | Callable to specify a comparison key for sorting. 68 | bucket_size_multiplier (`int`, optional, defaults to `100`): 69 | Buckets are of size `batch_size * bucket_size_multiplier`. 
70 | Example: 71 | >>> from torchnlp.random import set_seed 72 | >>> set_seed(123) 73 | >>> 74 | >>> from torch.utils.data.sampler import SequentialSampler 75 | >>> sampler = SequentialSampler(list(range(10))) 76 | >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=False)) 77 | [[6, 7, 8], [0, 1, 2], [3, 4, 5], [9]] 78 | >>> list(BucketBatchSampler(sampler, batch_size=3, drop_last=True)) 79 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 80 | 81 | """ 82 | 83 | def __init__( 84 | self, 85 | sampler, 86 | batch_size, 87 | drop_last: bool = False, 88 | sort_key=identity, 89 | bucket_size_multiplier=100, 90 | ): 91 | super().__init__(sampler, batch_size, drop_last) 92 | self.sort_key = sort_key 93 | _bucket_size = batch_size * bucket_size_multiplier 94 | if hasattr(sampler, "__len__"): 95 | _bucket_size = min(_bucket_size, len(sampler)) 96 | self.bucket_sampler = BatchSampler(sampler, _bucket_size, False) 97 | 98 | def __iter__(self): 99 | for bucket in self.bucket_sampler: 100 | sorted_sampler = SortedSampler(bucket, self.sort_key) 101 | for batch in SubsetRandomSampler( 102 | list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last)) 103 | ): 104 | yield [bucket[i] for i in batch] 105 | 106 | def __len__(self): 107 | if self.drop_last: 108 | return len(self.sampler) // self.batch_size 109 | else: 110 | return math.ceil(len(self.sampler) / self.batch_size) 111 | -------------------------------------------------------------------------------- /maverick/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hydra 3 | import omegaconf 4 | import pytorch_lightning as pl 5 | import torch 6 | 7 | from pathlib import Path 8 | from typing import Optional 9 | from omegaconf import OmegaConf 10 | from pytorch_lightning import Trainer 11 | from pytorch_lightning.callbacks import ( 12 | EarlyStopping, 13 | ModelCheckpoint, 14 | LearningRateMonitor, 15 | ) 16 | from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBar 17 | from pytorch_lightning.loggers import WandbLogger 18 | from rich.console import Console 19 | 20 | from maverick.data.pl_data_modules import BasePLDataModule 21 | from maverick.models.pl_modules import BasePLModule 22 | 23 | torch.set_printoptions(edgeitems=100) 24 | 25 | 26 | def train(conf: omegaconf.DictConfig) -> None: 27 | # fancy logger 28 | console = Console() 29 | # reproducibility 30 | pl.seed_everything(conf.train.seed) 31 | set_determinism_the_old_way(conf.train.pl_trainer.deterministic) 32 | conf.train.pl_trainer.deterministic = True 33 | 34 | console.log(f"Starting training for [bold cyan]{conf.train.model_name}[/bold cyan] model") 35 | if conf.train.pl_trainer.fast_dev_run: 36 | console.log(f"Debug mode {conf.train.pl_trainer.fast_dev_run}. 
Forcing debugger configuration") 37 | # Debuggers don't like GPUs nor multiprocessing 38 | conf.train.pl_trainer.accelerator = "cpu" 39 | 40 | conf.train.pl_trainer.precision = 32 41 | conf.data.datamodule.num_workers = {k: 0 for k in conf.data.datamodule.num_workers} 42 | # Switch wandb to offline mode to prevent online logging 43 | conf.logging.log = None 44 | # remove model checkpoint callback 45 | conf.train.model_checkpoint_callback = None 46 | 47 | # data module declaration 48 | console.log(f"Instantiating the Data Module") 49 | pl_data_module: BasePLDataModule = hydra.utils.instantiate(conf.data.datamodule, _recursive_=False) 50 | # force setup to get labels initialized for the model 51 | pl_data_module.prepare_data() 52 | pl_data_module.setup("fit") 53 | 54 | # main module declaration 55 | console.log(f"Instantiating the Model") 56 | 57 | pl_module: BasePLModule = hydra.utils.instantiate(conf.model.module, _recursive_=False) 58 | 59 | # pl_module = BasePLModule.load_from_checkpoint(conf.evaluation.checkpoint, _recursive_=False, map_location="cuda:0") 60 | experiment_logger: Optional[WandbLogger] = None 61 | experiment_path: Optional[Path] = None 62 | if conf.logging.log: 63 | console.log(f"Instantiating Wandb Logger") 64 | experiment_logger = hydra.utils.instantiate(conf.logging.wandb_arg) 65 | experiment_logger.watch(pl_module, **conf.logging.watch) 66 | experiment_path = Path(experiment_logger.experiment.dir) 67 | # Store the YaML config separately into the wandb dir 68 | yaml_conf: str = OmegaConf.to_yaml(cfg=conf) 69 | (experiment_path / "hparams.yaml").write_text(yaml_conf) 70 | 71 | # callbacks declaration 72 | callbacks_store = [RichProgressBar()] 73 | 74 | if conf.train.early_stopping_callback is not None: 75 | early_stopping_callback: EarlyStopping = hydra.utils.instantiate(conf.train.early_stopping_callback) 76 | callbacks_store.append(early_stopping_callback) 77 | 78 | if conf.train.model_checkpoint_callback is not None: 79 | model_checkpoint_callback: ModelCheckpoint = hydra.utils.instantiate( 80 | conf.train.model_checkpoint_callback, 81 | dirpath=experiment_path / "checkpoints", 82 | ) 83 | callbacks_store.append(model_checkpoint_callback) 84 | 85 | if conf.train.learning_rate_callback is not None and not conf.train.pl_trainer.fast_dev_run: 86 | lr: LearningRateMonitor = hydra.utils.instantiate(conf.train.learning_rate_callback) 87 | callbacks_store.append(lr) 88 | # trainer 89 | console.log(f"Instantiating the Trainer") 90 | trainer: Trainer = hydra.utils.instantiate(conf.train.pl_trainer, callbacks=callbacks_store, logger=experiment_logger) 91 | 92 | # module fit 93 | trainer.fit(pl_module, datamodule=pl_data_module) 94 | 95 | # module test 96 | trainer.test(pl_module, datamodule=pl_data_module) 97 | 98 | 99 | def set_determinism_the_old_way(deterministic: bool): 100 | # determinism for cudnn 101 | torch.backends.cudnn.deterministic = deterministic 102 | if deterministic: 103 | # fixing non-deterministic part of horovod 104 | # https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383 105 | os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0) 106 | 107 | 108 | @hydra.main(config_path="../conf", config_name="root") 109 | def main(conf: omegaconf.DictConfig): 110 | print(OmegaConf.to_yaml(conf)) 111 | train(conf) 112 | 113 | 114 | if __name__ == "__main__": 115 | main() 116 | -------------------------------------------------------------------------------- /maverick/utils/upload_to_hf.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import tempfile 5 | import zipfile 6 | from datetime import datetime 7 | from pathlib import Path 8 | from typing import Optional, Union 9 | 10 | from maverick.utils.loggingl import get_console_logger 11 | 12 | logger = get_console_logger() 13 | 14 | import huggingface_hub 15 | 16 | 17 | def get_md5(path: Path): 18 | """ 19 | Get the MD5 value of a path. 20 | """ 21 | import hashlib 22 | 23 | with path.open("rb") as fin: 24 | data = fin.read() 25 | return hashlib.md5(data).hexdigest() 26 | 27 | 28 | def create_info_file(tmpdir: Path): 29 | # logger.debug("Computing md5 of model.zip") 30 | md5 = get_md5(tmpdir / "model.zip") 31 | date = datetime.now() 32 | 33 | # logger.debug("Dumping info.json file") 34 | with (tmpdir / "info.json").open("w") as f: 35 | json.dump(dict(md5=md5, upload_date=date), f, indent=2) 36 | 37 | 38 | def zip_run( 39 | dir_path: Union[str, os.PathLike], 40 | tmpdir: Union[str, os.PathLike], 41 | zip_name: str = "model.zip", 42 | ) -> Path: 43 | # logger.debug(f"zipping {dir_path} to {tmpdir}") 44 | # creates a zip version of the provided dir_path 45 | run_dir = Path(dir_path) 46 | zip_path = tmpdir / zip_name 47 | 48 | with zipfile.ZipFile(zip_path, "w") as zip_file: 49 | # fully zip the run directory maintaining its structure 50 | for file in run_dir.rglob("*.*"): 51 | if file.is_dir(): 52 | continue 53 | 54 | zip_file.write(file, arcname=file.relative_to(run_dir)) 55 | 56 | return zip_path 57 | 58 | 59 | def get_logged_in_username(): 60 | token = huggingface_hub.HfFolder.get_token() 61 | if token is None: 62 | raise ValueError("No HuggingFace token found. You need to execute `huggingface-cli login` first!") 63 | api = huggingface_hub.HfApi() 64 | user = api.whoami(token=token) 65 | return user["name"] 66 | 67 | 68 | def upload( 69 | model_dir: Union[str, os.PathLike], 70 | model_name: str, 71 | filenames: Optional[list[str]] = None, 72 | organization: str | None = None, 73 | repo_name: str | None = None, 74 | commit: str | None = None, 75 | archive: bool = False, 76 | ): 77 | token = huggingface_hub.HfFolder.get_token() 78 | if token is None: 79 | raise ValueError("No HuggingFace token found. 
You need to execute `huggingface-cli login` first!") 80 | 81 | repo_id = repo_name or model_name 82 | if organization is not None: 83 | repo_id = f"{organization}/{repo_id}" 84 | with tempfile.TemporaryDirectory() as tmpdir: 85 | api = huggingface_hub.HfApi() 86 | repo_url = api.create_repo( 87 | token=token, 88 | repo_id=repo_id, 89 | exist_ok=True, 90 | ) 91 | repo = huggingface_hub.Repository(str(tmpdir), clone_from=repo_url, use_auth_token=token) 92 | 93 | tmp_path = Path(tmpdir) 94 | if archive: 95 | # otherwise we zip the model_dir 96 | # logger.debug(f"Zipping {model_dir} to {tmp_path}") 97 | zip_run(model_dir, tmp_path) 98 | create_info_file(tmp_path) 99 | else: 100 | # if the user wants to upload a transformers model, we don't need to zip it 101 | # we just need to copy the files to the tmpdir 102 | # logger.debug(f"Copying {model_dir} to {tmpdir}") 103 | # copy only the files that are needed 104 | if filenames is not None: 105 | for filename in filenames: 106 | os.system(f"cp {model_dir}/{filename} {tmpdir}") 107 | else: 108 | os.system(f"cp -r {model_dir}/* {tmpdir}") 109 | 110 | # this method automatically puts large files (>10MB) into git lfs 111 | repo.push_to_hub(commit_message=commit or "initial commit") 112 | 113 | 114 | def parse_args() -> argparse.Namespace: 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("model_dir", help="The directory of the model you want to upload") 117 | parser.add_argument("model_name", help="The model you want to upload") 118 | parser.add_argument( 119 | "--organization", 120 | help="the name of the organization where you want to upload the model", 121 | ) 122 | parser.add_argument( 123 | "--repo_name", 124 | help="Optional name to use when uploading to the HuggingFace repository", 125 | ) 126 | parser.add_argument("--commit", help="Commit message to use when pushing to the HuggingFace Hub") 127 | parser.add_argument( 128 | "--archive", 129 | action="store_true", 130 | help=""" 131 | Whether to compress the model directory before uploading it. 132 | If True, the model directory will be zipped and the zip file will be uploaded. 133 | If False, the model directory will be uploaded as is.""", 134 | ) 135 | return parser.parse_args() 136 | 137 | 138 | upload(**vars(parse_args())) 139 | -------------------------------------------------------------------------------- /maverick/common/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import Counter 3 | from scipy.optimize import linear_sum_assignment 4 | 5 | 6 | # Mention evaluation performed as scores on set of mentions. 
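# For example: with predicted mentions {(0, 1), (3, 4)} and gold mentions {(0, 1), (5, 6)},
# the set operations below give tp=1, fp=1, fn=1, so precision = recall = f1 = 0.5.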
7 | class OfficialMentionEvaluator: 8 | def __init__(self): 9 | self.tp, self.fp, self.fn = 0, 0, 0 10 | 11 | # takes predicted mentions and gold mentions: 12 | # a mention is [[(wstart, wend)]] 13 | # calculates tp, fp, and fn as set operations 14 | # returns prf using formula 15 | def update(self, predicted_mentions, gold_mentions): 16 | predicted_mentions = set(predicted_mentions) 17 | gold_mentions = set(gold_mentions) 18 | 19 | self.tp += len(predicted_mentions & gold_mentions) 20 | self.fp += len(predicted_mentions - gold_mentions) 21 | self.fn += len(gold_mentions - predicted_mentions) 22 | 23 | def get_f1(self): 24 | pr = self.get_precision() 25 | rec = self.get_recall() 26 | return 2 * pr * rec / (pr + rec) if pr + rec > 0 else 0.0 27 | 28 | def get_recall(self): 29 | return self.tp / (self.tp + self.fn) if (self.tp + self.fn) > 0 else 0.0 30 | 31 | def get_precision(self): 32 | return self.tp / (self.tp + self.fp) if (self.tp + self.fp) > 0 else 0.0 33 | 34 | def get_prf(self): 35 | return self.get_precision(), self.get_recall(), self.get_f1() 36 | 37 | 38 | # CoNLL-2012 metrics https://aclanthology.org/W12-4501.pdf#page=19 39 | class OfficialCoNLL2012CorefEvaluator(object): 40 | def __init__(self): 41 | self.evaluators = [OfficialEvaluator(m) for m in (muc, b_cubed, ceafe)] 42 | self.metrics = {"muc": 0, "b_cubed": 1, "ceafe": 2} 43 | 44 | # pred & golds : [[((wstart, wend),()), (())...]], mention_to_pred and mention_to_gold: [{(wstart,wend):((wstart, wend)...)}] 45 | def update(self, pred, gold, mention_to_predicted, mention_to_gold): 46 | for e in self.evaluators: 47 | e.update(pred, gold, mention_to_predicted, mention_to_gold) 48 | 49 | def get_f1(self, metric): 50 | if metric == "conll2012": 51 | return sum(e.get_f1() for e in self.evaluators) / len(self.evaluators) 52 | else: 53 | return self.evaluators[self.metrics[metric]].get_f1() 54 | 55 | def get_recall(self, metric): 56 | if metric == "conll2012": 57 | return sum(e.get_recall() for e in self.evaluators) / len(self.evaluators) 58 | else: 59 | return self.evaluators[self.metrics[metric]].get_recall() 60 | 61 | def get_precision(self, metric): 62 | if metric == "conll2012": 63 | return sum(e.get_precision() for e in self.evaluators) / len(self.evaluators) 64 | else: 65 | return self.evaluators[self.metrics[metric]].get_precision() 66 | 67 | def get_prf(self, metric): 68 | return self.get_precision(metric), self.get_recall(metric), self.get_f1(metric) 69 | 70 | 71 | class OfficialEvaluator(object): 72 | def __init__(self, metric, beta=1): 73 | self.p_num = 0 74 | self.p_den = 0 75 | self.r_num = 0 76 | self.r_den = 0 77 | self.metric = metric 78 | self.beta = beta 79 | 80 | def update(self, predicted, gold, mention_to_predicted, mention_to_gold): 81 | if self.metric == ceafe: 82 | pn, pd, rn, rd = self.metric(predicted, gold) 83 | else: 84 | pn, pd = self.metric(predicted, mention_to_gold) 85 | rn, rd = self.metric(gold, mention_to_predicted) 86 | self.p_num += pn 87 | self.p_den += pd 88 | self.r_num += rn 89 | self.r_den += rd 90 | 91 | def get_f1(self): 92 | return f1(self.p_num, self.p_den, self.r_num, self.r_den, beta=self.beta) 93 | 94 | def get_recall(self): 95 | return 0 if self.r_num == 0 else self.r_num / float(self.r_den) 96 | 97 | def get_precision(self): 98 | return 0 if self.p_num == 0 else self.p_num / float(self.p_den) 99 | 100 | def get_prf(self): 101 | return self.get_precision(), self.get_recall(), self.get_f1() 102 | 103 | def get_counts(self): 104 | return self.p_num, self.p_den, self.r_num, 
self.r_den 105 | 106 | 107 | def f1(p_num, p_den, r_num, r_den, beta=1): 108 | p = 0 if p_den == 0 else p_num / float(p_den) 109 | r = 0 if r_den == 0 else r_num / float(r_den) 110 | return 0 if p + r == 0 else (1 + beta * beta) * p * r / (beta * beta * p + r) 111 | 112 | 113 | # mention based metric (Bagga and Baldwin, 1998) 114 | def b_cubed(clusters, mention_to_gold): 115 | num, dem = 0, 0 116 | 117 | for c in clusters: 118 | gold_counts = Counter() 119 | correct = 0 120 | for m in c: 121 | if m in mention_to_gold: 122 | gold_counts[tuple(mention_to_gold[m])] += 1 123 | for c2, count in gold_counts.items(): 124 | correct += count * count 125 | if len(c) != 0: 126 | num += correct / float(len(c)) 127 | dem += len(c) 128 | 129 | return num, dem 130 | 131 | 132 | # link based metric (Vilain et al., 1995). 133 | def muc(clusters, mention_to_gold): 134 | tp, p = 0, 0 135 | for c in clusters: 136 | p += len(c) - 1 137 | tp += len(c) 138 | linked = set() 139 | for m in c: 140 | if m in mention_to_gold: 141 | linked.add(mention_to_gold[m]) 142 | else: 143 | tp -= 1 144 | tp -= len(linked) 145 | return tp, p 146 | 147 | 148 | # Entity based metric (Luo, 2005) 149 | def ceafe(clusters, gold_clusters): 150 | # clusters = [c for c in clusters if len(c) != 1] 151 | scores = np.zeros((len(gold_clusters), len(clusters))) 152 | for i in range(len(gold_clusters)): 153 | for j in range(len(clusters)): 154 | scores[i, j] = phi4(gold_clusters[i], clusters[j]) 155 | row_ind, col_ind = linear_sum_assignment(-scores) 156 | similarity = sum(scores[row_ind, col_ind]) 157 | return similarity, len(clusters), similarity, len(gold_clusters) 158 | 159 | 160 | def phi4(c1, c2): 161 | return 2 * len([m for m in c1 if m in c2]) / float(len(c1) + len(c2)) 162 | 163 | 164 | def lea(clusters, mention_to_gold): 165 | num, dem = 0, 0 166 | 167 | for c in clusters: 168 | if len(c) == 1: 169 | continue 170 | 171 | common_links = 0 172 | all_links = len(c) * (len(c) - 1) / 2.0 173 | for i, m in enumerate(c): 174 | if m in mention_to_gold: 175 | for m2 in c[i + 1 :]: 176 | if m2 in mention_to_gold and mention_to_gold[m] == mention_to_gold[m2]: 177 | common_links += 1 178 | 179 | num += len(c) * common_links / float(all_links) 180 | dem += len(c) 181 | 182 | return num, dem 183 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # custom 2 | data/* 3 | experiments/* 4 | modelss/* 5 | .vscode 6 | best/* 7 | resources/* 8 | maverick/utils/corefconversion/* 9 | debugg.py 10 | # Created by https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos 11 | # Edit at https://www.toptal.com/developers/gitignore?templates=jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos 12 | 13 | ### JetBrains+all ### 14 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 15 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 16 | 17 | # User-specific stuff 18 | .idea/**/workspace.xml 19 | .idea/**/tasks.xml 20 | .idea/**/usage.statistics.xml 21 | .idea/**/dictionaries 22 | .idea/**/shelf 23 | 24 | # Generated files 25 | .idea/**/contentModel.xml 26 | 27 | # Sensitive or high-churn files 28 | .idea/**/dataSources/ 29 | .idea/**/dataSources.ids 30 | .idea/**/dataSources.local.xml 31 | .idea/**/sqlDataSources.xml 32 | .idea/**/dynamic.xml 33 | .idea/**/uiDesigner.xml 34
| .idea/**/dbnavigator.xml 35 | 36 | # Gradle 37 | .idea/**/gradle.xml 38 | .idea/**/libraries 39 | 40 | # Gradle and Maven with auto-import 41 | # When using Gradle or Maven with auto-import, you should exclude module files, 42 | # since they will be recreated, and may cause churn. Uncomment if using 43 | # auto-import. 44 | # .idea/artifacts 45 | # .idea/compiler.xml 46 | # .idea/jarRepositories.xml 47 | # .idea/modules.xml 48 | # .idea/*.iml 49 | # .idea/modules 50 | # *.iml 51 | # *.ipr 52 | 53 | # CMake 54 | cmake-build-*/ 55 | 56 | # Mongo Explorer plugin 57 | .idea/**/mongoSettings.xml 58 | 59 | # File-based project format 60 | *.iws 61 | 62 | # IntelliJ 63 | out/ 64 | 65 | # mpeltonen/sbt-idea plugin 66 | .idea_modules/ 67 | 68 | # JIRA plugin 69 | atlassian-ide-plugin.xml 70 | 71 | # Cursive Clojure plugin 72 | .idea/replstate.xml 73 | 74 | # Crashlytics plugin (for Android Studio and IntelliJ) 75 | com_crashlytics_export_strings.xml 76 | crashlytics.properties 77 | crashlytics-build.properties 78 | fabric.properties 79 | 80 | # Editor-based Rest Client 81 | .idea/httpRequests 82 | 83 | # Android studio 3.1+ serialized cache file 84 | .idea/caches/build_file_checksums.ser 85 | 86 | ### JetBrains+all Patch ### 87 | # Ignores the whole .idea folder and all .iml files 88 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 89 | 90 | .idea/ 91 | 92 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 93 | 94 | *.iml 95 | modules.xml 96 | .idea/misc.xml 97 | *.ipr 98 | 99 | # Sonarlint plugin 100 | .idea/sonarlint 101 | 102 | ### JupyterNotebooks ### 103 | # gitignore template for Jupyter Notebooks 104 | # website: http://jupyter.org/ 105 | 106 | .ipynb_checkpoints 107 | */.ipynb_checkpoints/* 108 | 109 | # IPython 110 | profile_default/ 111 | ipython_config.py 112 | 113 | # Remove previous ipynb_checkpoints 114 | # git rm -r .ipynb_checkpoints/ 115 | 116 | ### Linux ### 117 | *~ 118 | 119 | # temporary files which can be created if a process still has a handle open of a deleted file 120 | .fuse_hidden* 121 | 122 | # KDE directory preferences 123 | .directory 124 | 125 | # Linux trash folder which might appear on any partition or disk 126 | .Trash-* 127 | 128 | # .nfs files are created when an open file is removed but is still being accessed 129 | .nfs* 130 | 131 | ### macOS ### 132 | # General 133 | .DS_Store 134 | .AppleDouble 135 | .LSOverride 136 | 137 | # Icon must end with two \r 138 | Icon 139 | 140 | 141 | # Thumbnails 142 | ._* 143 | 144 | # Files that might appear in the root of a volume 145 | .DocumentRevisions-V100 146 | .fseventsd 147 | .Spotlight-V100 148 | .TemporaryItems 149 | .Trashes 150 | .VolumeIcon.icns 151 | .com.apple.timemachine.donotpresent 152 | 153 | # Directories potentially created on remote AFP share 154 | .AppleDB 155 | .AppleDesktop 156 | Network Trash Folder 157 | Temporary Items 158 | .apdisk 159 | 160 | ### Python ### 161 | # Byte-compiled / optimized / DLL files 162 | __pycache__/ 163 | *.py[cod] 164 | *$py.class 165 | 166 | # C extensions 167 | *.so 168 | 169 | # Distribution / packaging 170 | .Python 171 | build/ 172 | develop-eggs/ 173 | dist/ 174 | downloads/ 175 | eggs/ 176 | .eggs/ 177 | lib/ 178 | lib64/ 179 | parts/ 180 | sdist/ 181 | var/ 182 | wheels/ 183 | pip-wheel-metadata/ 184 | share/python-wheels/ 185 | *.egg-info/ 186 | .installed.cfg 187 | *.egg 188 | MANIFEST 189 | 190 | # PyInstaller 191 | # Usually these files are written by a python 
script from a template 192 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 193 | *.manifest 194 | *.spec 195 | 196 | # Installer logs 197 | pip-log.txt 198 | pip-delete-this-directory.txt 199 | 200 | # Unit test / coverage reports 201 | htmlcov/ 202 | .tox/ 203 | .nox/ 204 | .coverage 205 | .coverage.* 206 | .cache 207 | nosetests.xml 208 | coverage.xml 209 | *.cover 210 | *.py,cover 211 | .hypothesis/ 212 | .pytest_cache/ 213 | pytestdebug.log 214 | 215 | # Translations 216 | *.mo 217 | *.pot 218 | 219 | # Django stuff: 220 | *.log 221 | local_settings.py 222 | db.sqlite3 223 | db.sqlite3-journal 224 | 225 | # Flask stuff: 226 | instance/ 227 | .webassets-cache 228 | 229 | # Scrapy stuff: 230 | .scrapy 231 | 232 | # Sphinx documentation 233 | docs/_build/ 234 | doc/_build/ 235 | 236 | # PyBuilder 237 | target/ 238 | 239 | # Jupyter Notebook 240 | 241 | # IPython 242 | 243 | # pyenv 244 | .python-version 245 | 246 | # pipenv 247 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 248 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 249 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 250 | # install all needed dependencies. 251 | #Pipfile.lock 252 | 253 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 254 | __pypackages__/ 255 | 256 | # Celery stuff 257 | celerybeat-schedule 258 | celerybeat.pid 259 | 260 | # SageMath parsed files 261 | *.sage.py 262 | 263 | # Environments 264 | .env 265 | .venv 266 | env/ 267 | venv/ 268 | ENV/ 269 | env.bak/ 270 | venv.bak/ 271 | pythonenv* 272 | 273 | # Spyder project settings 274 | .spyderproject 275 | .spyproject 276 | 277 | # Rope project settings 278 | .ropeproject 279 | 280 | # mkdocs documentation 281 | /site 282 | 283 | # mypy 284 | .mypy_cache/ 285 | .dmypy.json 286 | dmypy.json 287 | 288 | # Pyre type checker 289 | .pyre/ 290 | 291 | # pytype static type analyzer 292 | .pytype/ 293 | 294 | # profiling data 295 | .prof 296 | 297 | ### vscode ### 298 | .vscode/* 299 | !.vscode/settings.json 300 | !.vscode/tasks.json 301 | !.vscode/launch.json 302 | !.vscode/extensions.json 303 | *.code-workspace 304 | 305 | ### Windows ### 306 | # Windows thumbnail cache files 307 | Thumbs.db 308 | Thumbs.db:encryptable 309 | ehthumbs.db 310 | ehthumbs_vista.db 311 | 312 | # Dump file 313 | *.stackdump 314 | 315 | # Folder config file 316 | [Dd]esktop.ini 317 | 318 | # Recycle Bin used on file shares 319 | $RECYCLE.BIN/ 320 | 321 | # Windows Installer files 322 | *.cab 323 | *.msi 324 | *.msix 325 | *.msm 326 | *.msp 327 | 328 | # Windows shortcuts 329 | *.lnk 330 | 331 | # End of https://www.toptal.com/developers/gitignore/api/jetbrains+all,vscode,python,jupyternotebooks,linux,windows,macos 332 | 333 | -------------------------------------------------------------------------------- /maverick/evaluate.py: -------------------------------------------------------------------------------- 1 | import json 2 | import hydra 3 | import subprocess 4 | import torch 5 | from omegaconf import omegaconf 6 | from maverick.common.util import * 7 | from maverick.common.metrics import * 8 | from tqdm import tqdm 9 | from data.pl_data_modules import BasePLDataModule 10 | from models.pl_modules import BasePLModule 11 | from maverick.utils.loggingl import get_console_logger 12 | 13 | logger = get_console_logger() 14 | 15 | 16 | def jsonlines_to_html(jsonlines_input_name, output): 17 | cwd = 
str(hydra.utils.get_original_cwd()) 18 | subprocess.call( 19 | "python3 " 20 | + cwd 21 | + "/maverick/utils/corefconversion/jsonlines2text.py " 22 | + cwd 23 | + "/" 24 | + jsonlines_input_name 25 | + " -i -o " 26 | + cwd 27 | + "/experiments/" 28 | + output 29 | + ".html --sing-color" 30 | ' "black" --cm "common"', 31 | shell=True, 32 | ) 33 | 34 | 35 | @torch.no_grad() 36 | def evaluate(conf: omegaconf.DictConfig): 37 | device = conf.evaluation.device 38 | hydra.utils.log.info("Using {} as device".format(device)) 39 | print(conf.data.datamodule) 40 | pl_data_module: BasePLDataModule = hydra.utils.instantiate(conf.data.datamodule, _recursive_=False) 41 | 42 | pl_data_module.prepare_data() 43 | pl_data_module.setup("test") 44 | 45 | # jsonlines_to_html(pl_data_module.test_dataloader().dataset.path, "test") 46 | logger.log(f"Instantiating the Model from {conf.evaluation.checkpoint}") 47 | model = BasePLModule.load_from_checkpoint(conf.evaluation.checkpoint, _recursive_=False, map_location=device) 48 | if "gap" not in pl_data_module.test_dataloader().dataset.path: 49 | gold = [] 50 | info = [] 51 | with open(hydra.utils.get_original_cwd() + "/" + pl_data_module.test_dataloader().dataset.path, "r") as f: 52 | for line in f.readlines(): 53 | doc = json.loads(line) 54 | if "sentences" in doc: 55 | info.append({"doc_key": doc["doc_key"], "sentences": doc["sentences"]}) 56 | if "text" in doc: 57 | info.append({"doc_key": doc["doc_key"], "text": doc["text"]}) 58 | clusters = [] 59 | if "clusters" in doc: 60 | for cluster in doc["clusters"]: 61 | clusters.append(tuple([(m[0], m[1]) for m in cluster])) 62 | gold.append(clusters) 63 | 64 | mention_to_gold_clusters = [extract_mentions_to_clusters([tuple(g) for g in gold_element]) for gold_element in gold] 65 | 66 | predictions = model_predictions_with_dataloader(model, pl_data_module.test_dataloader(), device) 67 | mention_to_predicted_clusters = [extract_mentions_to_clusters(p) for p in predictions] 68 | 69 | print(evaluate_coref_scores(predictions, gold, mention_to_predicted_clusters, mention_to_gold_clusters)) 70 | 71 | with open(hydra.utils.get_original_cwd() + "/experiments/output.jsonlines", "w") as f: 72 | for pred, infos in zip(predictions, info): 73 | infos["clusters"] = pred 74 | f.write(json.dumps(infos) + "\n") 75 | 76 | # jsonlines_to_html("experiments/output.jsonlines", "output") 77 | else: 78 | predictions = model_predictions_with_dataloader(model, pl_data_module.test_dataloader(), device) 79 | with open("data/gap/gap-test-ontoformat.jsonl", "r") as fr: 80 | with open("data/gap/gap-test-output.tsv", "w") as fw: 81 | for line, pred in zip(fr.readlines(), predictions): 82 | doc = json.loads(line) 83 | # flattened = flatten(doc["sentences"]) 84 | A, B = A_and_B_corefs(pred, doc["gap_pronoun_offsets"], doc["gap_A_offsets"], doc["gap_B_offsets"]) 85 | fw.write(doc["doc_key"] + "\t" + str(A) + "\t" + str(B) + "\n") 86 | return 87 | 88 | 89 | def A_and_B_corefs(clusters, pronoun_offsets, A_offsets, B_offsets): 90 | A_Coref = False 91 | B_Coref = False 92 | pronoun_in_clusters = find_offsets_in_clusters(clusters, pronoun_offsets) 93 | A_in_clusters = find_offsets_in_clusters(clusters, A_offsets) 94 | B_in_clusters = find_offsets_in_clusters(clusters, B_offsets) 95 | # pronoun is in clusters 96 | if pronoun_in_clusters != None: 97 | # A is in clusters 98 | if A_in_clusters != None and A_in_clusters != B_in_clusters: 99 | A_Coref = candidate_pronoun_clustered(clusters, pronoun_in_clusters, A_in_clusters) 100 | if B_in_clusters != None and 
B_in_clusters != A_in_clusters: 101 | B_Coref = candidate_pronoun_clustered(clusters, pronoun_in_clusters, B_in_clusters) 102 | return A_Coref, B_Coref 103 | 104 | 105 | def candidate_pronoun_clustered(clusters, pronoun, candidate): 106 | clustered = False 107 | for cluster in clusters: 108 | if pronoun in cluster and candidate in cluster: 109 | clustered = True 110 | return clustered 111 | 112 | 113 | def find_offsets_in_clusters(clusters, offsets): 114 | clusters = [list(item) for item in clusters] 115 | mentions = [item for sublist in clusters for item in sublist] 116 | result = None 117 | for mention in mentions: 118 | if offsets[0] >= mention[0] and offsets[1] <= mention[1]: 119 | if result == None: 120 | result = mention 121 | else: 122 | if result[1] - result[0] > mention[1] - mention[0]: 123 | result = mention 124 | return result 125 | 126 | 127 | def evaluate_coref_scores(pred, gold, mention_to_pred, mention_to_gold): 128 | evaluator = OfficialCoNLL2012CorefEvaluator() 129 | 130 | for p, g, m2p, m2g in zip(pred, gold, mention_to_pred, mention_to_gold): 131 | evaluator.update(p, g, m2p, m2g) 132 | result = [] 133 | for metric in ["muc", "b_cubed", "ceafe", "conll2012"]: 134 | result.append(dict(zip(["precision", "recall", "f1_score"], evaluator.get_prf(metric)))) 135 | return result 136 | 137 | 138 | def model_predictions_with_dataloader(model, test_dataloader, device): 139 | model.to(device) 140 | model.eval() 141 | predictions = [] 142 | 143 | for batch in tqdm(test_dataloader, desc="Test", total=test_dataloader.__len__()): 144 | output = model.model( 145 | stage="test", 146 | input_ids=batch["input_ids"].to(device), 147 | attention_mask=batch["attention_mask"].to(device), 148 | eos_mask=batch["eos_mask"].to(device), 149 | # gold_starts=batch["gold_starts"].to(device), 150 | # gold_mentions=batch["gold_mentions"].to(device), 151 | # gold_clusters=batch["gold_clusters"].to(device), 152 | tokens=batch["tokens"], 153 | subtoken_map=batch["subtoken_map"], 154 | new_token_map=batch["new_token_map"], 155 | singletons=False, 156 | ) 157 | 158 | clusters_predicted = original_token_offsets( 159 | clusters=output["pred_dict"]["clusters"], 160 | subtoken_map=batch["subtoken_map"][0], 161 | new_token_map=batch["new_token_map"][0], 162 | ) 163 | predictions.append(clusters_predicted) 164 | return predictions 165 | 166 | 167 | @hydra.main(config_path="../conf", config_name="root") 168 | def main(conf: omegaconf.DictConfig): 169 | evaluate(conf) 170 | 171 | 172 | if __name__ == "__main__": 173 | main() 174 | -------------------------------------------------------------------------------- /maverick/common/util.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import spacy 3 | import hydra 4 | import torch 5 | import pytorch_lightning as pl 6 | 7 | from nltk import sent_tokenize 8 | from torch.nn import Module, Linear, LayerNorm, Dropout, GELU 9 | 10 | from maverick.common.constants import * 11 | 12 | 13 | def get_category_id(mention, antecedent): 14 | mention, mention_pronoun_id = mention 15 | antecedent, antecedent_pronoun_id = antecedent 16 | 17 | if mention_pronoun_id > -1 and antecedent_pronoun_id > -1: 18 | if mention_pronoun_id == antecedent_pronoun_id: 19 | return CATEGORIES["pron-pron-comp"] 20 | else: 21 | return CATEGORIES["pron-pron-no-comp"] 22 | 23 | if mention_pronoun_id > -1 or antecedent_pronoun_id > -1: 24 | return CATEGORIES["pron-ent"] 25 | 26 | if mention == antecedent: 27 | return CATEGORIES["match"] 28 | 29 | 
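# "contain" (added note): one mention's token set is a subset of the other's, i.e. their union is no larger than the bigger mention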
union = mention.union(antecedent) 30 | if len(union) == max(len(mention), len(antecedent)): 31 | return CATEGORIES["contain"] 32 | 33 | return CATEGORIES["other"] 34 | 35 | 36 | def get_pronoun_id(span): 37 | if len(span) == 1: 38 | span = list(span) 39 | if span[0] in PRONOUNS_GROUPS: 40 | return PRONOUNS_GROUPS[span[0]] 41 | return -1 42 | 43 | 44 | def flatten(l): 45 | return [item for sublist in l for item in sublist] 46 | 47 | 48 | def ontonotes_to_dataframe(file_path): 49 | # read file 50 | df = pd.read_json(hydra.utils.get_original_cwd() + "/" + file_path, lines=True) 51 | # ontonotes is split into words and sentences, join sentences 52 | if "sentences" in df.columns: 53 | df["tokens"] = df["sentences"].apply(lambda x: flatten(x)) 54 | elif "text" in df.columns: # plain text input 55 | nlp = spacy.load("en_core_web_sm", exclude=["tagger", "parser", "lemmatizer", "ner", "textcat"]) 56 | texts = df["text"].tolist() 57 | df["sentences"] = [[[tok.text for tok in s] for s in nlp.pipe(s)] for s in [sent_tokenize(text) for text in texts]] 58 | 59 | nlp = spacy.load("en_core_web_sm", exclude=["tagger", "parser", "lemmatizer", "ner", "textcat"]) 60 | 61 | df["tokens"] = [[tok.text for tok in doc] for doc in nlp.pipe(texts)] 62 | # compute end of sequence indices 63 | if "preco" in file_path: 64 | df["EOS_indices"] = df["tokens"].apply(lambda x: [i + 1 for i, token in enumerate(x) if token == "."]) 65 | else: 66 | df["EOS_lengths"] = df["sentences"].apply(lambda x: [len(value) for value in x]) 67 | df["EOS_indices"] = df["EOS_lengths"].apply(lambda x: [sum(x[0 : (i[0] + 1)]) for i in enumerate(x)]) 68 | # add speakers 69 | if "speakers" in df.columns and "wkc" not in file_path and "q" not in file_path: 70 | df["speakers"] = df["speakers"].apply(lambda x: flatten(x)) 71 | else: 72 | df["speakers"] = df["tokens"].apply(lambda x: ["-"] * len(x)) 73 | 74 | if "clusters" in df.columns: 75 | df = df[["doc_key", "tokens", "speakers", "clusters", "EOS_indices"]] 76 | else: 77 | df = df[["doc_key", "tokens", "speakers", "EOS_indices"]] 78 | df = df.dropna() 79 | df = df.reset_index(drop=True) 80 | return df 81 | 82 | 83 | def extract_mentions_to_clusters(gold_clusters): 84 | mention_to_gold = {} 85 | for gc in gold_clusters: 86 | for mention in gc: 87 | mention_to_gold[mention] = gc 88 | return mention_to_gold 89 | 90 | 91 | def original_token_offsets(clusters, subtoken_map, new_token_map): 92 | return [ 93 | tuple( 94 | [ 95 | ( 96 | new_token_map[subtoken_map[start]], 97 | new_token_map[subtoken_map[end]], 98 | ) 99 | for start, end in cluster 100 | if subtoken_map[start] is not None 101 | and subtoken_map[end] is not None # only happens in the first evals, when the model predicts special tokens as mentions 102 | and new_token_map[subtoken_map[start]] 103 | is not None # it happens very rarely that in some weirdly formatted sentences the model predicts the speaker name as a possible mention 104 | ] 105 | ) 106 | for cluster in clusters 107 | ] 108 | 109 | 110 | def unpad_gold_clusters(gold_clusters): 111 | new_gold_clusters = [] 112 | for batch in gold_clusters: 113 | new_gold_clusters = [] 114 | for cluster in batch: 115 | new_cluster = [] 116 | for span in cluster: 117 | if span[0].item() != -1: 118 | new_cluster.append((span[0].item(), span[1].item())) 119 | if len(new_cluster) != 0: 120 | new_gold_clusters.append(tuple(new_cluster)) 121 | return new_gold_clusters 122 | 123 | 124 | class FullyConnectedLayer(Module): 125 | def __init__(self, input_dim, output_dim, hidden_size, dropout_prob): 126 | super(FullyConnectedLayer, self).__init__()
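# Descriptive note (added): two dense layers, with dropout, GELU activation and LayerNorm applied between them (see forward below)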
127 | self.input_dim = input_dim 128 | self.output_dim = output_dim 129 | self.dropout_prob = dropout_prob 130 | 131 | self.dense1 = Linear(self.input_dim, hidden_size) 132 | self.dense = Linear(hidden_size, self.output_dim) 133 | self.layer_norm = LayerNorm(hidden_size) 134 | self.activation_func = GELU() 135 | self.dropout = Dropout(self.dropout_prob) 136 | 137 | def forward(self, inputs): 138 | temp = inputs 139 | temp = self.dense1(temp) 140 | temp = self.dropout(temp) 141 | temp = self.activation_func(temp) 142 | temp = self.layer_norm(temp) 143 | temp = self.dense(temp) 144 | return temp 145 | 146 | 147 | class RepresentationLayer(torch.nn.Module): 148 | def __init__(self, type, input_dim, output_dim, hidden_dim, **kwargs) -> None: 149 | super(RepresentationLayer, self).__init__() 150 | self.hidden_dim = hidden_dim 151 | self.lt = type 152 | if type == "Linear": 153 | self.layer = Linear(input_dim, output_dim) 154 | elif type == "FC": 155 | self.layer = FullyConnectedLayer(input_dim, output_dim, hidden_dim, dropout_prob=0.2) 156 | elif type == "LSTM-left": 157 | self.layer = torch.nn.LSTM(input_size=input_dim, hidden_size=output_dim, bidirectional=True) 158 | elif type == "LSTM-right": 159 | self.layer = torch.nn.LSTM(input_size=input_dim, hidden_size=output_dim, bidirectional=True) 160 | elif type == "LSTM-bidirectional": 161 | self.layer = torch.nn.LSTM(input_size=input_dim, hidden_size=output_dim / 2, bidirectional=True) 162 | elif type == "Conv1d": 163 | self.layer = torch.nn.Conv1d(input_size=input_dim, hidden_size=output_dim, kernel_size=7, stride=1, padding=3) 164 | self.dropout = Dropout(0.2) 165 | 166 | def forward(self, inputs): 167 | if self.lt == "Linear": 168 | return self.layer(inputs) 169 | elif self.lt == "FC": 170 | return self.layer(inputs) 171 | elif self.lt == "LSTM-left": 172 | return self.layer(inputs)[0][: self.hidden_dim] 173 | elif self.lt == "LSTM-right": 174 | return self.layer(inputs)[0][self.hidden_dim :] 175 | elif self.lt == "LSTM-bidirectional": 176 | return self.layer(inputs)[0] 177 | elif self.lt == "Conv1d": 178 | return self.layer(self.dropout(inputs)) 179 | 180 | 181 | def download_load_spacy(): 182 | try: 183 | import nltk 184 | 185 | nlp = spacy.load("en_core_web_sm", exclude=["tagger", "parser", "lemmatizer", "ner", "textcat"]) 186 | 187 | # colab fix 188 | try: 189 | import nltk 190 | 191 | nltk.data.find("tokenizers/punkt") 192 | except: 193 | nltk.download("punkt") 194 | except: 195 | from spacy.cli import download 196 | import nltk 197 | 198 | nltk.download("punkt") 199 | download("en_core_web_sm") 200 | nlp = spacy.load("en_core_web_sm", exclude=["tagger", "parser", "lemmatizer", "ner", "textcat"]) 201 | return nlp 202 | -------------------------------------------------------------------------------- /maverick/models/maverick_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from transformers import AutoTokenizer 5 | 6 | from maverick.models.pl_modules import BasePLModule 7 | from maverick.common.util import * 8 | from maverick.common.constants import * 9 | from maverick.models import * 10 | 11 | from transformers.utils.hub import cached_file as hf_cached_file 12 | 13 | 14 | class Maverick: 15 | # put the pip package online 16 | def __init__(self, hf_name_or_path="sapienzanlp/maverick-mes-ontonotes", device="cuda"): 17 | self.device = device 18 | path = self.__get_model_path__(hf_name_or_path) 19 | self.model = BasePLModule.load_from_checkpoint(path, 
_recursive_=False, map_location=self.device) 20 | self.model = self.model.eval() 21 | self.model = self.model.model 22 | self.tokenizer = self.__get_model_tokenizer__() 23 | 24 | def __get_model_path__(self, hf_name_or_path): 25 | try: 26 | print(hf_name_or_path, "loading") 27 | path = hf_cached_file(hf_name_or_path, "weights.ckpt") 28 | except: 29 | print(hf_name_or_path, "not found on huggingface, loading from local path ") 30 | path = hf_name_or_path 31 | return path 32 | 33 | def __get_model_tokenizer__(self): 34 | tokenizer = AutoTokenizer.from_pretrained(self.model.encoder_hf_model_name, use_fast=True, add_prefix_space=True) 35 | special_tokens_dict = {"additional_special_tokens": ["[SPEAKER_START]", "[SPEAKER_END]"]} 36 | tokenizer.add_special_tokens(special_tokens_dict) 37 | return tokenizer 38 | 39 | def __sample_type__(self, sample): 40 | if isinstance(sample, str): 41 | result = "text" 42 | if isinstance(sample, list): 43 | result = "word_tokenized" 44 | if len(sample) != 0 and isinstance(sample[0], list): 45 | result = "sentence_tokenized" 46 | return result 47 | 48 | def preprocess(self, sample, speakers=None): 49 | type = self.__sample_type__(sample) 50 | char_offsets = None 51 | if type == "text": 52 | nlp = download_load_spacy() 53 | char_offsets = [] 54 | sentences = [] 55 | off = 0 56 | s = sent_tokenize(sample) 57 | for sent, sentence in zip(nlp.pipe(s), s): 58 | char_offsets.append([(off + tok.idx, off + tok.idx + len(tok.text) - 1) for tok in sent]) 59 | sentences.append([tok.text for tok in sent]) 60 | off += len(sentence) + 1 61 | char_offsets = flatten(char_offsets) 62 | tokens = flatten(sentences) 63 | eos_len = [len(value) for value in sentences] 64 | eos = [sum(eos_len[0 : (i[0] + 1)]) for i in enumerate(eos_len)] 65 | elif type == "word_tokenized": 66 | nlp = download_load_spacy() 67 | tokens = sample 68 | eos = [idx + 1 for idx, tok in enumerate(tokens) if tok == "."] 69 | if len(eos) == 0 or eos[-1] != len(tokens): 70 | eos.append(len(tokens)) 71 | elif type == "sentence_tokenized": 72 | sentences = sample 73 | tokens = flatten(sentences) 74 | eos_len = [len(value) for value in sentences] 75 | eos = [sum(eos_len[0 : (i[0] + 1)]) for i in enumerate(eos_len)] 76 | if speakers == None: 77 | speakers = ["-"] * len(tokens) 78 | else: 79 | speakers = flatten(speakers) 80 | return tokens, eos, speakers, char_offsets 81 | 82 | # takes length of sequence (int) and eos_indices ([]) 83 | # returns len x len zeros matrix with 1 in pos (start, all possible ends) 84 | def eos_mask(self, input_ids_len, eos_indices): 85 | mask = np.zeros((input_ids_len, input_ids_len)) 86 | prec = 0 87 | for eos_idx in eos_indices: 88 | for i in range(prec, eos_idx + 1): 89 | for j in range(prec, eos_idx + 1): 90 | mask[i][j] = 1 91 | prec = eos_idx 92 | mask = np.triu(mask) 93 | return mask 94 | 95 | @torch.no_grad() 96 | def predict(self, sample, singletons=False, add_gold_clusters=None, predefined_mentions=None, speakers=None): 97 | tokens, eos_indices, speakers, char_offsets = self.preprocess(sample, speakers) # [[w1,w2,w3...], []] 98 | tokenized = self.tokenize(tokens, eos_indices, speakers, predefined_mentions, add_gold_clusters) 99 | 100 | output = self.model( 101 | stage="test", 102 | input_ids=torch.tensor(tokenized["input_ids"]).unsqueeze(0).to(self.device), 103 | attention_mask=torch.tensor(tokenized["attention_mask"]).unsqueeze(0).to(self.device), 104 | eos_mask=torch.tensor(tokenized["eos_mask"]).unsqueeze(0).to(self.device), 105 | tokens=[tokenized["tokens"]], 106 | 
subtoken_map=[tokenized["subtoken_map"]], 107 | new_token_map=[tokenized["new_token_map"]], 108 | singletons=singletons, 109 | add=tokenized["added"], 110 | gold_mentions=( 111 | None 112 | if tokenized["gold_mentions"] == None 113 | else torch.tensor( 114 | self.create_mention_matrix( 115 | len(tokenized["input_ids"]), 116 | tokenized["gold_mentions"], 117 | ) 118 | ) 119 | .unsqueeze(0) 120 | .to(self.device) 121 | ), 122 | ) 123 | 124 | clusters_predicted = original_token_offsets( 125 | clusters=output["pred_dict"]["clusters"], 126 | subtoken_map=tokenized["subtoken_map"], 127 | new_token_map=tokenized["new_token_map"], 128 | ) 129 | result = {} 130 | result["tokens"] = tokens 131 | result["clusters_token_offsets"] = clusters_predicted 132 | result["clusters_char_offsets"] = None 133 | result["clusters_token_text"] = [ 134 | [" ".join(tokens[span[0] : span[1] + 1]) for span in cluster] for cluster in clusters_predicted 135 | ] 136 | result["clusters_char_text"] = None 137 | if char_offsets != None: 138 | result["clusters_char_offsets"] = [ 139 | [(char_offsets[span[0]][0], char_offsets[span[1]][1]) for span in cluster] for cluster in clusters_predicted 140 | ] 141 | 142 | # result["clusters_char_text"] = [ 143 | # [sample[char_offsets[span[0]][0] : char_offsets[span[1]][1] + 1] for span in cluster] 144 | # for cluster in clusters_predicted 145 | # ] 146 | 147 | return result 148 | 149 | def create_mention_matrix(self, input_ids_len, mentions): 150 | matrix = np.zeros((input_ids_len, input_ids_len)) 151 | for start_bpe_idx, end_bpe_idx in mentions: 152 | matrix[start_bpe_idx][end_bpe_idx] = 1 153 | return matrix 154 | 155 | def tokenize(self, tokens, eos_indices, speakers=None, gold_mentions=None, add_gold_clusters=None): 156 | token_to_new_token_map = [] # len() = len(tokens), contains indices of original sequence to new sequence 157 | new_token_map = [] # len() = len(new_tokens), contains indices of new sequence 158 | new_tokens = [] # contains new tokens 159 | last_speaker = None 160 | 161 | for idx, (token, speaker) in enumerate(zip(tokens, speakers)): 162 | if last_speaker != speaker: 163 | new_tokens += ["[SPEAKER_START]", speaker, "[SPEAKER_END]"] 164 | new_token_map += [None, None, None] 165 | last_speaker = speaker 166 | token_to_new_token_map.append(len(new_tokens)) 167 | new_token_map.append(idx) 168 | new_tokens.append(token) 169 | 170 | encoded_text = self.tokenizer(new_tokens, add_special_tokens=True, is_split_into_words=True) 171 | if gold_mentions != None: 172 | gold_mentions = [ 173 | ( 174 | encoded_text.word_to_tokens(token_to_new_token_map[start]).start, 175 | encoded_text.word_to_tokens(token_to_new_token_map[end]).end - 1, 176 | ) 177 | for start, end in gold_mentions 178 | ] 179 | 180 | addedd = None 181 | if add_gold_clusters != None: 182 | addedd = [ 183 | [ 184 | ( 185 | encoded_text.word_to_tokens(token_to_new_token_map[start]).start, 186 | encoded_text.word_to_tokens(token_to_new_token_map[end]).end - 1, 187 | ) 188 | for start, end in cluster 189 | ] 190 | for cluster in add_gold_clusters 191 | ] 192 | 193 | eos_indices = [ 194 | encoded_text.word_to_tokens(token_to_new_token_map[eos - 1]).start 195 | for eos in eos_indices 196 | if encoded_text.word_to_tokens(token_to_new_token_map[eos - 1]) != None 197 | ] 198 | output = { 199 | "tokens": tokens, 200 | "input_ids": encoded_text["input_ids"], 201 | "attention_mask": encoded_text["attention_mask"], 202 | "subtoken_map": encoded_text.word_ids(), 203 | "new_token_map": new_token_map, 204 | "eos_mask": 
self.eos_mask(len(encoded_text["input_ids"]), eos_indices), 205 | "gold_mentions": gold_mentions, 206 | "added": addedd, 207 | } 208 | return output 209 | -------------------------------------------------------------------------------- /maverick/data/datasets.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import hydra.utils 3 | import torch 4 | 5 | from typing import Tuple 6 | from typing import Dict, Union 7 | from torch.utils.data import Dataset 8 | from transformers import AutoTokenizer 9 | from torch.utils.data import Dataset 10 | from datasets import load_from_disk 11 | from datasets import Dataset as dt 12 | 13 | import maverick.common.util as util 14 | 15 | 16 | NULL_ID_FOR_COREF = -1 17 | 18 | 19 | class OntonotesDataset(Dataset): 20 | def __init__(self, name: str, path: str, batch_size, processed_dataset_path, tokenizer, **kwargs): 21 | super().__init__() 22 | self.stage = name 23 | self.path = path 24 | self.batch_size = batch_size 25 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer, use_fast=True, add_prefix_space=True) 26 | self.max_doc_len = kwargs.get("max_doc_len", None) 27 | special_tokens_dict = {"additional_special_tokens": ["[SPEAKER_START]", "[SPEAKER_END]"]} 28 | self.tokenizer.add_special_tokens(special_tokens_dict) 29 | try: 30 | self.set = load_from_disk(hydra.utils.get_original_cwd() + "/" + processed_dataset_path + "/") 31 | except: 32 | self.set = dt.from_pandas(util.ontonotes_to_dataframe(path)) 33 | if self.stage == "train": 34 | if "clusters" not in self.set.column_names: 35 | print("Training set has to include cluster information") 36 | # self.set = self.set.map(self.cut_document_to_length, batched=False) 37 | self.set = self.prepare_data(self.set) 38 | self.set = self.set.map(self.encode, batched=False) 39 | self.set = self.set.remove_columns(column_names=["speakers"]) 40 | if self.stage != "test": 41 | self.set.save_to_disk(hydra.utils.get_original_cwd() + "/" + processed_dataset_path + "/") 42 | 43 | def prepare_data(self, set): 44 | return set.filter(lambda x: len(self.tokenizer(x["tokens"])["input_ids"]) <= self.max_doc_len) 45 | 46 | def cut_document_to_length(self, set_element): 47 | encoded_text = self.tokenizer(set_element["tokens"], add_special_tokens=True, is_split_into_words=True) 48 | if len(encoded_text["input_ids"]) <= self.max_doc_len + 1: 49 | result = set_element 50 | else: 51 | last_index_input_id_in_sentence = encoded_text.token_to_word(self.max_doc_len) 52 | eos_indices = [end for end in set_element["EOS_indices"] if end < last_index_input_id_in_sentence] 53 | last_sentence_end = eos_indices[-3] 54 | result = { 55 | "doc_key": set_element["doc_key"], 56 | "tokens": set_element["tokens"][:last_sentence_end], 57 | "speakers": set_element["speakers"][:last_sentence_end], 58 | "EOS_indices": eos_indices[:-3], 59 | } 60 | new_clusters = [] 61 | for cluster in set_element["clusters"]: 62 | new_cluster = [] 63 | for span in cluster: 64 | if span[1] < last_sentence_end: 65 | new_cluster.append(span) 66 | if len(new_cluster) >= 1: 67 | new_clusters.append(new_cluster) 68 | result["clusters"] = new_clusters 69 | return result 70 | 71 | # missing genre information 72 | def _tokenize(self, tokens, clusters, speakers, eos_indices): 73 | token_to_new_token_map = [] # len() = len(tokens), contains indices of original sequence to new sequence 74 | new_token_map = [] # len() = len(new_tokens), contains indices of new sequence 75 | new_tokens = [] # contains new tokens 76 | 
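# Descriptive note (added): whenever the speaker changes, [SPEAKER_START] <speaker> [SPEAKER_END] markers are inserted before the current token; the marker positions map to None in new_token_map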
last_speaker = None 77 | 78 | for idx, (token, speaker) in enumerate(zip(tokens, speakers)): 79 | if last_speaker != speaker: 80 | new_tokens += ["[SPEAKER_START]", speaker, "[SPEAKER_END]"] 81 | new_token_map += [None, None, None] 82 | last_speaker = speaker 83 | token_to_new_token_map.append(len(new_tokens)) 84 | new_token_map.append(idx) 85 | new_tokens.append(token) 86 | 87 | for cluster in clusters: 88 | for start, end in cluster: 89 | assert tokens[start : end + 1] == new_tokens[token_to_new_token_map[start] : token_to_new_token_map[end] + 1] 90 | 91 | encoded_text = self.tokenizer(new_tokens, add_special_tokens=True, is_split_into_words=True) 92 | clusters = [ 93 | [ 94 | ( 95 | encoded_text.word_to_tokens(token_to_new_token_map[start]).start, 96 | encoded_text.word_to_tokens(token_to_new_token_map[end]).end - 1, 97 | ) 98 | for start, end in cluster 99 | ] 100 | for cluster in clusters 101 | ] 102 | eos_indices = [ 103 | encoded_text.word_to_tokens(token_to_new_token_map[eos - 1]).start 104 | for eos in eos_indices 105 | if encoded_text.word_to_tokens(token_to_new_token_map[eos - 1]) != None 106 | ] 107 | output = { 108 | "tokens": tokens, 109 | "input_ids": encoded_text["input_ids"], 110 | "attention_mask": encoded_text["attention_mask"], 111 | "gold_clusters": clusters, 112 | "subtoken_map": encoded_text.word_ids(), 113 | "new_token_map": new_token_map, 114 | "EOS_indices": eos_indices, 115 | } 116 | return output 117 | 118 | def encode(self, example): 119 | if "clusters" not in example: 120 | example["clusters"] = [] 121 | encoded = self._tokenize( 122 | example["tokens"], 123 | example["clusters"], 124 | example["speakers"], 125 | example["EOS_indices"], 126 | ) # debug when no clusters 127 | encoded["num_clusters"] = len(encoded["gold_clusters"]) if encoded["gold_clusters"] else 0 128 | encoded["max_cluster_size"] = max(len(c) for c in encoded["gold_clusters"]) if encoded["gold_clusters"] else 0 129 | encoded["length"] = len(encoded["input_ids"]) 130 | return encoded 131 | 132 | def __len__(self) -> int: 133 | return self.set.shape[0] 134 | 135 | def __getitem__(self, index) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: 136 | return self.set[index] 137 | 138 | # takes length of sequence (int) and eos_indices ([]) 139 | # returns len x len zeros matrix with 1 in pos (start, all possible ends) 140 | def eos_mask(self, input_ids_len, eos_indices): 141 | mask = np.zeros((input_ids_len, input_ids_len)) 142 | prec = 0 143 | for eos_idx in eos_indices: 144 | for i in range(prec, eos_idx + 2): 145 | for j in range(prec, eos_idx + 2): 146 | mask[i][j] = 1 147 | prec = eos_idx 148 | mask = np.triu(mask) 149 | return mask 150 | 151 | # takes length of sequence (int) and coreferences ([[()]]) 152 | # returns len x len zeros matrix with 1 in pos (start, end) 153 | def create_mention_matrix(self, input_ids_len, coreferences): 154 | matrix = np.zeros((input_ids_len, input_ids_len)) 155 | for cluster in coreferences: 156 | for start_bpe_idx, end_bpe_idx in cluster: 157 | matrix[start_bpe_idx][end_bpe_idx] = 1 158 | return matrix 159 | 160 | # takes length of sequence (int) and coreferences ([[()]]) 161 | # returns len zeros matrix with 1 in start position 162 | def create_start_matrix(self, input_ids_len, coreferences): 163 | matrix = np.zeros((input_ids_len)) 164 | for cluster in coreferences: 165 | for start_bpe_idx, end_bpe_idx in cluster: 166 | matrix[start_bpe_idx] = 1 167 | return matrix 168 | 169 | # pad don't pad the rest, and is slow, think about something 
else 170 | def collate_fn(self, batch): 171 | batch = self.tokenizer.pad(batch) 172 | output = { 173 | "input_ids": torch.tensor(batch["input_ids"]), 174 | "attention_mask": torch.tensor(batch["attention_mask"]), 175 | "eos_mask": torch.tensor(self.eos_mask(len(batch["input_ids"][0]), batch["EOS_indices"][0])).unsqueeze(0), 176 | "gold_mentions": torch.tensor( 177 | self.create_mention_matrix(len(batch["input_ids"][0]), batch["gold_clusters"][0]) 178 | ).unsqueeze(0), 179 | "gold_starts": torch.tensor( 180 | self.create_start_matrix(len(batch["input_ids"][0]), batch["gold_clusters"][0]) 181 | ).unsqueeze(0), 182 | } 183 | 184 | max_num_clusters, max_max_cluster_size = max(batch["num_clusters"]), max(batch["max_cluster_size"]) 185 | if max_num_clusters == 0: 186 | padded_clusters = [] 187 | else: 188 | padded_clusters = [ 189 | pad_clusters(cluster, max_num_clusters, max_max_cluster_size) for cluster in batch["gold_clusters"] 190 | ] 191 | 192 | output["gold_clusters"] = torch.tensor(padded_clusters) 193 | output["tokens"] = batch["tokens"] 194 | output["doc_key"] = batch["doc_key"] 195 | output["subtoken_map"] = batch["subtoken_map"] 196 | output["new_token_map"] = batch["new_token_map"] 197 | output["eos_indices"] = torch.tensor(batch["EOS_indices"]) 198 | output["singletons"] = "litbank" in self.path or "preco" in self.path 199 | return output 200 | 201 | 202 | def pad_clusters_inside(clusters, max_cluster_size): 203 | return [cluster + [(NULL_ID_FOR_COREF, NULL_ID_FOR_COREF)] * (max_cluster_size - len(cluster)) for cluster in clusters] 204 | 205 | 206 | def pad_clusters_outside(clusters, max_num_clusters): 207 | return clusters + [[]] * (max_num_clusters - len(clusters)) 208 | 209 | 210 | def pad_clusters(clusters, max_num_clusters, max_cluster_size): 211 | clusters = pad_clusters_outside(clusters, max_num_clusters) 212 | clusters = pad_clusters_inside(clusters, max_cluster_size) 213 | return clusters 214 | -------------------------------------------------------------------------------- /maverick/models/pl_modules.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import hydra 4 | import pytorch_lightning as pl 5 | import torch 6 | 7 | from torchmetrics import * 8 | import transformers 9 | from transformers import Adafactor 10 | 11 | from maverick.common.metrics import * 12 | from maverick.common.util import * 13 | from maverick.models.model_incr import * 14 | from maverick.models.model_mes import Maverick_mes 15 | from maverick.models.model_s2e import * 16 | 17 | 18 | class BasePLModule(pl.LightningModule): 19 | def __init__(self, *args, **kwargs) -> None: 20 | super().__init__() 21 | self.save_hyperparameters() 22 | try: 23 | self.model = hydra.utils.instantiate(self.hparams.model) 24 | except: 25 | self.hparams.model["_target_"] = "maverick." 
+ self.hparams.model["_target_"] 26 | self.model = hydra.utils.instantiate(self.hparams.model) 27 | self.train_step_predictions = [] 28 | self.train_step_gold = [] 29 | self.validation_step_predictions = [] 30 | self.validation_step_gold = [] 31 | self.test_step_predictions = [] 32 | self.test_step_gold = [] 33 | 34 | def forward(self, batch) -> dict: 35 | output_dict = self.model(batch, "forward") 36 | return output_dict 37 | 38 | def evaluate(self, predictions, golds): 39 | mention_evaluator = OfficialMentionEvaluator() 40 | start_evaluator = OfficialMentionEvaluator() 41 | cluster_mention_evaluator = OfficialMentionEvaluator() 42 | coref_evaluator = OfficialCoNLL2012CorefEvaluator() 43 | result = {} 44 | 45 | for pred, gold in zip(predictions, golds): 46 | if "start_idxs" in pred.keys(): 47 | starts_pred = pred["start_idxs"][0].tolist() 48 | starts_gold = (gold["gold_starts"][0] == 1).nonzero(as_tuple=False).squeeze(-1).tolist() 49 | start_evaluator.update(starts_pred, starts_gold) 50 | 51 | if "mention_idxs" in pred.keys(): 52 | mentions_pred = [tuple(p) for p in pred["mention_idxs"][0].tolist()] 53 | mentions_gold = [tuple(g) for g in (gold["gold_mentions"][0] == 1).nonzero(as_tuple=False).tolist()] 54 | mention_evaluator.update(mentions_pred, mentions_gold) 55 | 56 | if "clusters" in pred.keys(): 57 | pred_clusters = pred["clusters"] 58 | gold_clusters = gold["gold_clusters"] 59 | mention_to_gold_clusters = extract_mentions_to_clusters(gold_clusters) 60 | mention_to_predicted_clusters = extract_mentions_to_clusters(pred_clusters) 61 | 62 | coref_evaluator.update( 63 | pred_clusters, 64 | gold_clusters, 65 | mention_to_predicted_clusters, 66 | mention_to_gold_clusters, 67 | ) 68 | 69 | cluster_mention_evaluator.update( 70 | [item for sublist in gold_clusters for item in sublist], 71 | [item for sublist in pred_clusters for item in sublist], 72 | ) 73 | p, r, f1 = start_evaluator.get_prf() 74 | result.update( 75 | { 76 | "start_f1_score": f1, 77 | "start_precision": p, 78 | "start_recall": r, 79 | } 80 | ) 81 | p, r, f1 = mention_evaluator.get_prf() 82 | result.update({"mention_f1_score": f1, "mention_precision": p, "mention_recall": r}) 83 | p, r, f1 = cluster_mention_evaluator.get_prf() 84 | result.update( 85 | { 86 | "cluster_mention_f1_score": f1, 87 | "cluster_mention_precision": p, 88 | "cluster_mention_recall": r, 89 | } 90 | ) 91 | 92 | for metric in ["muc", "b_cubed", "ceafe", "conll2012"]: 93 | p, r, f1 = coref_evaluator.get_prf(metric) 94 | result.update( 95 | { 96 | metric + "_f1_score": f1, 97 | metric + "_precision": p, 98 | metric + "_recall": r, 99 | } 100 | ) 101 | return result 102 | 103 | def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor: 104 | output = self.model( 105 | stage="train", 106 | input_ids=batch["input_ids"], 107 | attention_mask=batch["attention_mask"], 108 | eos_mask=batch["eos_mask"], 109 | gold_starts=batch["gold_starts"], 110 | gold_mentions=batch["gold_mentions"], 111 | gold_clusters=batch["gold_clusters"], 112 | tokens=batch["tokens"], 113 | subtoken_map=batch["subtoken_map"], 114 | new_token_map=batch["new_token_map"], 115 | singletons=batch["singletons"], 116 | ) 117 | self.log_dict({"train/" + k: v for k, v in output["loss_dict"].items()}, on_step=True) 118 | return output["loss"] 119 | 120 | def validation_step(self, batch: dict, batch_idx: int): 121 | output = self.model( 122 | stage="val", 123 | input_ids=batch["input_ids"], 124 | attention_mask=batch["attention_mask"], 125 | eos_mask=batch["eos_mask"], 126 | 
gold_starts=batch["gold_starts"], 127 | gold_mentions=batch["gold_mentions"], 128 | gold_clusters=batch["gold_clusters"], 129 | tokens=batch["tokens"], 130 | subtoken_map=batch["subtoken_map"], 131 | new_token_map=batch["new_token_map"], 132 | singletons=batch["singletons"], 133 | ) 134 | self.log_dict({"val/" + k: v for k, v in output["loss_dict"].items()}) 135 | output["pred_dict"]["clusters"] = original_token_offsets( 136 | clusters=output["pred_dict"]["clusters"], 137 | subtoken_map=batch["subtoken_map"][0], 138 | new_token_map=batch["new_token_map"][0], 139 | ) 140 | self.validation_step_predictions.append(output["pred_dict"]) 141 | self.validation_step_gold.append( 142 | { 143 | "gold_starts": batch["gold_starts"].cpu(), 144 | "gold_mentions": batch["gold_mentions"].cpu(), 145 | "gold_clusters": original_token_offsets( 146 | clusters=unpad_gold_clusters(batch["gold_clusters"].cpu()), 147 | subtoken_map=batch["subtoken_map"][0], 148 | new_token_map=batch["new_token_map"][0], 149 | ), 150 | } 151 | ) 152 | 153 | def on_validation_epoch_end(self): 154 | self.log_dict( 155 | {"val/" + k: v for k, v in self.evaluate(self.validation_step_predictions, self.validation_step_gold).items()} 156 | ) 157 | self.validation_step_predictions = [] 158 | self.validation_step_gold = [] 159 | 160 | def test_step(self, batch: dict, batch_idx: int) -> Any: 161 | output = self.model( 162 | stage="test", 163 | input_ids=batch["input_ids"], 164 | attention_mask=batch["attention_mask"], 165 | eos_mask=batch["eos_mask"], 166 | eos_indices=batch["eos_indices"], 167 | tokens=batch["tokens"], 168 | subtoken_map=batch["subtoken_map"], 169 | new_token_map=batch["new_token_map"], 170 | singletons=batch["singletons"], 171 | ) 172 | self.log_dict({"test/" + k: v for k, v in output["loss_dict"].items()}) 173 | output["pred_dict"]["clusters"] = original_token_offsets( 174 | clusters=output["pred_dict"]["clusters"], 175 | subtoken_map=batch["subtoken_map"][0], 176 | new_token_map=batch["new_token_map"][0], 177 | ) 178 | self.test_step_predictions.append(output["pred_dict"]) 179 | self.test_step_gold.append( 180 | { 181 | "gold_starts": batch["gold_starts"].cpu(), 182 | "gold_mentions": batch["gold_mentions"].cpu(), 183 | "gold_clusters": original_token_offsets( 184 | clusters=unpad_gold_clusters(batch["gold_clusters"].cpu()), 185 | subtoken_map=batch["subtoken_map"][0], 186 | new_token_map=batch["new_token_map"][0], 187 | ), 188 | } 189 | ) 190 | 191 | def on_test_epoch_end(self): 192 | self.log_dict({"test/" + k: v for k, v in self.evaluate(self.test_step_predictions, self.test_step_gold).items()}) 193 | self.test_step_predictions = [] 194 | self.test_step_gold = [] 195 | 196 | def configure_optimizers(self): 197 | if self.hparams.opt == "RAdam": 198 | opt = hydra.utils.instantiate(self.hparams.RAdam, params=self.parameters()) 199 | return opt 200 | else: 201 | return self.custom_opt() 202 | 203 | def custom_opt(self): 204 | no_decay = ["bias", "LayerNorm.weight"] 205 | head_params = ["representaion", "classifier"] 206 | 207 | model_decay = [ 208 | p 209 | for n, p in self.model.named_parameters() 210 | if not any(hp in n for hp in head_params) and not any(nd in n for nd in no_decay) 211 | ] 212 | model_no_decay = [ 213 | p 214 | for n, p in self.model.named_parameters() 215 | if not any(hp in n for hp in head_params) and any(nd in n for nd in no_decay) 216 | ] 217 | head_decay = [ 218 | p 219 | for n, p in self.model.named_parameters() 220 | if any(hp in n for hp in head_params) and not any(nd in n for nd in no_decay) 
221 | ] 222 | head_no_decay = [ 223 | p for n, p in self.model.named_parameters() if any(hp in n for hp in head_params) and any(nd in n for nd in no_decay) 224 | ] 225 | 226 | head_learning_rate = 3e-4 227 | lr = 2e-5 228 | wd = 0.01 229 | optimizer_grouped_parameters = [ 230 | {"params": model_decay, "lr": lr, "weight_decay": wd}, 231 | {"params": model_no_decay, "lr": lr, "weight_decay": 0.0}, 232 | {"params": head_decay, "lr": head_learning_rate, "weight_decay": wd}, 233 | {"params": head_no_decay, "lr": head_learning_rate, "weight_decay": 0.0}, 234 | ] 235 | optimizer = Adafactor(optimizer_grouped_parameters, scale_parameter=False, relative_step=False) 236 | scheduler = transformers.get_linear_schedule_with_warmup( 237 | optimizer, 238 | num_warmup_steps=self.hparams.lr_scheduler.num_training_steps * 0.1, 239 | num_training_steps=self.hparams.lr_scheduler.num_training_steps, 240 | ) 241 | return { 242 | "optimizer": optimizer, 243 | "lr_scheduler": { 244 | "scheduler": scheduler, 245 | "interval": "step", 246 | "frequency": 1, 247 | }, 248 | } 249 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Maverick Coref 3 |

4 |
5 | 6 | 7 | [![Conference](https://img.shields.io/badge/ACL%202024%20Paper-red)](https://aclanthology.org/2024.acl-long.722.pdf) 8 | [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-CC%20BY--NC%204.0-green.svg)](https://creativecommons.org/licenses/by-nc/4.0/) 9 | [![Pip Package](https://img.shields.io/badge/🐍%20Python%20package-blue)](https://pypi.org/project/maverick-coref/) 10 | [![git](https://img.shields.io/badge/Git%20Repo%20-yellow.svg)](https://github.com/SapienzaNLP/maverick-coref) 11 |
12 | 13 | 14 | This is the official repository for [*Maverick: 15 | Efficient and Accurate Coreference Resolution Defying Recent Trends*](https://aclanthology.org/2024.acl-long.722). 16 | 17 | 18 | # Python Package 19 | The `maverick-coref` Python package provides an easy API to use Maverick models, enabling efficient and accurate coreference resolution with a few lines of code. 20 | 21 | Install the library from [PyPI](https://pypi.org/project/maverick-coref/) 22 | 23 | ```bash 24 | pip install maverick-coref 25 | ``` 26 | or from source 27 | 28 | ```bash 29 | git clone https://github.com/SapienzaNLP/maverick-coref.git 30 | cd maverick-coref 31 | pip install -e . 32 | ``` 33 | 34 | ## Loading a Pretrained Model 35 | Maverick models can be loaded using a Hugging Face model ID or a local checkpoint path: 36 | ```python 37 | from maverick import Maverick 38 | model = Maverick( 39 |     hf_name_or_path="sapienzanlp/maverick-mes-ontonotes",  # Hugging Face model ID or local checkpoint path (this is the default model) 40 |     device="cuda:0",  # "cuda", "cuda:0", or "cpu" 41 | ) 42 | ``` 43 | 44 | ## Available Models 45 | 46 | Available models at [SapienzaNLP huggingface hub](https://huggingface.co/collections/sapienzanlp/maverick-coreference-resolution-66a750a50246fad8d9c7086a): 47 | 48 | | hf_model_name | training dataset | Score | Singletons | 49 | |:-----------------------------------:|:----------------:|:-----:|:----------:| 50 | | ["sapienzanlp/maverick-mes-ontonotes"](https://huggingface.co/sapienzanlp/maverick-mes-ontonotes) | OntoNotes | 83.6 | No | 51 | | ["sapienzanlp/maverick-mes-litbank"](https://huggingface.co/sapienzanlp/maverick-mes-litbank) | LitBank | 78.0 | Yes | 52 | | ["sapienzanlp/maverick-mes-preco"](https://huggingface.co/sapienzanlp/maverick-mes-preco) | PreCo | 87.4 | Yes | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | N.B. Each dataset has different annotation guidelines; choose your model according to your use case. 63 | 64 | ## Inference 65 | ### Inputs 66 | Maverick inputs can be formatted as either: 67 | - plain text: 68 | ```python 69 | text = "Barack Obama is traveling to Rome. The city is sunny and the president plans to visit its most important attractions" 70 | ``` 71 | - word-tokenized text, as a list of tokens: 72 | ```python 73 | word_tokenized = ['Barack', 'Obama', 'is', 'traveling', 'to', 'Rome', '.', 'The', 'city', 'is', 'sunny', 'and', 'the', 'president', 'plans', 'to', 'visit', 'its', 'most', 'important', 'attractions'] 74 | ``` 75 | - sentence-split, word-tokenized text, i.e., OntoNotes-like input, as a list of lists of tokens: 76 | ```python 77 | ontonotes_format = [['Barack', 'Obama', 'is', 'traveling', 'to', 'Rome', '.'], ['The', 'city', 'is', 'sunny', 'and', 'the', 'president', 'plans', 'to', 'visit', 'its', 'most', 'important', 'attractions']] 78 | ``` 79 | 80 | ### Predict 81 | You can use model.predict() to obtain coreference predictions. 82 | For a sample input, the model returns a dictionary containing: 83 | - `tokens`, the word-tokenized version of the input. 84 | - `clusters_token_offsets`, a list of clusters containing mentions' token offsets. 85 | - `clusters_token_text`, a list of clusters containing mentions as plain text.
86 | 87 | Example: 88 | ```python 89 | model.predict(ontonotes_format) 90 | >>> { 91 | 'tokens': ['Barack', 'Obama', 'is', 'traveling', 'to', 'Rome', '.', 'The', 'city', 'is', 'sunny', 'and', 'the', 'president', 'plans', 'to', 'visit', 'its', 'most', 'important', 'attractions'], 92 | 'clusters_token_offsets': [[(5, 5), (7, 8), (17, 17)], [(0, 1), (12, 13)]], 93 | 'clusters_token_text': [['Rome', 'The city', 'its'], ['Barack Obama', 'the president']] 94 | } 95 | ``` 96 | 97 | If you input plain text, the output will also include character-level offsets as `clusters_char_offsets`: 98 | ```python 99 | model.predict(text) 100 | >>> { 101 | 'tokens': [...], 102 | 'clusters_token_offsets': [...], 103 | 'clusters_char_offsets': [[(29, 32), (35, 42), (86, 88)], [(0, 11), (57, 69)]], 104 | 'clusters_token_text': [...] 105 | } 106 | ``` 107 | 108 | ### 🚨Additional Features🚨 109 | Since coreference resolution may serve as a stepping stone for many downstream use cases, this package covers several additional features: 110 | 111 | - **Singletons**, include or exclude the prediction of singletons (i.e., single-mention clusters) by setting `singletons` to `True` or `False`. 112 | *(hint: for accurate singletons use PreCo- or LitBank-based models; OntoNotes does not annotate singletons, so models trained on it are not trained to extract any)* 113 | ```python 114 | # supported input: ontonotes_format 115 | model.predict(ontonotes_format, singletons=True) 116 | {'tokens': [...], 117 | 'clusters_token_offsets': [((5, 5), (7, 8), (17, 17)), ((0, 1), (12, 13)), ((17, 20),)], 118 | 'clusters_char_offsets': None, 119 | 'clusters_token_text': [['Rome', 'The city', 'its'], ['Barack Obama', 'the president'], ['its most important attractions']], 120 | 'clusters_char_text': None 121 | } 122 | ``` 123 | 124 | - **Clustering-only**, predict coreference clusters over predefined mentions by passing the mentions as a list of token offsets. 125 | ```python 126 | # supported input: ontonotes_format 127 | mentions = [(0, 1), (5, 5), (7, 8)] 128 | model.predict(ontonotes_format, predefined_mentions=mentions) 129 | >>> {'tokens': [...], 130 | 'clusters_token_offsets': [((5, 5), (7, 8))], 131 | 'clusters_char_offsets': None, 132 | 'clusters_token_text': [['Rome', 'The city']], 133 | 'clusters_char_text': None} 134 | ``` 135 | 136 | - **Starting from gold clusters**, predict starting from gold clusters by passing the model a list of starting clusters, each given as a list of mention token offsets. 137 | *(Note: since the starting clusters will be the first ones in the token offset outputs, to obtain the coreference predictions **only for the starting clusters** it is enough to take the first N clusters, where N is the number of starting clusters.)* 138 | ```python 139 | # supported input: ontonotes_format 140 | clusters = [[(5, 5), (7, 8)], [(0, 1)]] 141 | model.predict(ontonotes_format, add_gold_clusters=clusters) 142 | >>> {'tokens': [...], 'clusters_token_offsets': [((5, 5), (7, 8), (17, 17)), ((0, 1), (12, 13))], 'clusters_char_offsets': None, 'clusters_token_text': [['Rome', 'The city', 'its'], ['Barack Obama', 'the president']], 'clusters_char_text': None} 143 | ``` 144 | 145 | - **Speaker information**, since OntoNotes models are trained with additional speaker information [(more info here)](https://catalog.ldc.upenn.edu/docs/LDC2013T19/OntoNotes-Release-5.0.pdf), you can specify speaker information when using the OntoNotes format, as in the example below.
146 | 
147 | ```python 
148 | # supported input: ontonotes_format 
149 | speakers = [["Mark", "Mark", "Mark", "Mark", "Mark"], ["John", "John", "John", "John"]] 
150 | model.predict(ontonotes_format, speakers=speakers) 
151 | ``` 
152 | 
153 | # Using the official Training and Evaluation Script 
154 | 
155 | This repository also contains the code to train and evaluate Maverick systems using PyTorch Lightning and Hydra. 
156 | 
157 | **We strongly suggest using the [Python package](https://pypi.org/project/maverick-coref/) directly for easier inference and downstream usage.** 
158 | 
159 | 
160 | 
161 | ## Environment 
162 | To set up the training and evaluation environment, run the bash script setup.sh located at the top level of this repository. The script creates a new conda environment and takes care of all the requirements and of the data preprocessing needed to train and evaluate a model on OntoNotes. 
163 | 
164 | Simply run on the command line: 
165 | ``` 
166 | bash ./setup.sh 
167 | ``` 
168 | N.B. Remember to put the archive *ontonotes-release-5.0_LDC2013T19.tgz* in the folder *data/prepare_ontonotes/* if you want to preprocess OntoNotes with the standard preprocessing proposed by [e2e-coref](https://github.com/kentonl/e2e-coref/). To do this, also download *minimize.py* from the [e2e-coref](https://github.com/kentonl/e2e-coref/) repository. 
169 | The OntoNotes dataset can be downloaded, upon registration, from the following [link](https://catalog.ldc.upenn.edu/LDC2013T19). 
170 | 
171 | 
172 | ## Data 
173 | Official Links: 
174 | - [OntoNotes](https://catalog.ldc.upenn.edu/LDC2013T19) 
175 | - [PreCo](https://drive.google.com/file/d/1q0oMt1Ynitsww9GkuhuwNZNq6SjByu-Y/view) 
176 | - [LitBank](https://github.com/dbamman/litbank/tree/master/coref/conll) 
177 | - [WikiCoref](http://rali.iro.umontreal.ca/rali/?q=en/wikicoref) 
178 | 
179 | Since these datasets usually require a preprocessing step to obtain the OntoNotes-like jsonlines format, we release ready-to-use versions: 
180 | https://drive.google.com/drive/u/3/folders/18dtd1Qt4h7vezlm2G0hF72aqFcAEFCUo. 
181 | 
182 | 
183 | ## Hydra 
184 | This repository uses the [Hydra](https://hydra.cc/) configuration framework. 
185 | 
186 | - In *conf/data/*, each yaml file contains a dataset configuration. 
187 | - *conf/evaluation/* contains the model checkpoint file path and device settings for model evaluation. 
188 | - *conf/logging/* contains details for wandb logging. 
189 | - In *conf/model/*, each yaml file contains a model setup. 
190 | - *conf/train/* contains training configurations. 
191 | - *conf/root.yaml* regulates the overall configuration of the environment. 
192 | 
193 | 
194 | ## Train 
195 | To train a Maverick model, modify *conf/root.yaml* with your custom setup. 
196 | By default, this file contains the settings for training and evaluating on the OntoNotes dataset. 
197 | 
198 | To train a new model, follow the steps in the [Environment](#environment) section and run the following: 
199 | ``` 
200 | conda activate maverick_env 
201 | python maverick/train.py 
202 | ``` 
203 | 
204 | 
205 | ## Evaluate 
206 | To evaluate an existing model, two configuration values need to be set. 
207 | 1. Set the dataset path in conf/root.yaml; by default it is set to OntoNotes. 
208 | 2. Set the model checkpoint path in conf/evaluation/default_evaluation.yaml (a minimal sketch of this file follows). 
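For reference, a minimal evaluation configuration might look like the following sketch (field names assumed from the repository defaults; adjust the paths to your setup):

```yaml
# conf/evaluation/default_evaluation.yaml (illustrative sketch, assumed field names)
checkpoint: "/path/to/weights.ckpt"  # model checkpoint to evaluate
device: "cuda:0"                     # or "cpu"
```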
209 | 
210 | Finally, run the following: 
211 | ``` 
212 | conda activate env_name 
213 | python maverick/evaluate.py 
214 | ``` 
215 | This will directly output the CoNLL-2012 scores and, under the experiments/ folder, an output.jsonlines file containing the model outputs in OntoNotes style. 
216 | 
217 | ### Replicate paper results 
218 | The weights of each model can be found in the [SapienzaNLP huggingface hub](https://huggingface.co/collections/sapienzanlp/maverick-coreference-resolution-66a750a50246fad8d9c7086a). 
219 | To replicate any of the paper results, download the weights.ckpt of a model from its model card files and follow the steps reported in the [Evaluate](#evaluate) section. 
220 | 
221 | E.g., to replicate the state-of-the-art results of *Maverick_mes* on OntoNotes: 
222 | - download the weights from [here](https://huggingface.co/sapienzanlp/maverick-mes-ontonotes/blob/main/weights.ckpt). 
223 | - copy the local path of the weights into conf/evaluation/default_evaluation.yaml. 
224 | - activate the project's conda environment with *conda activate maverick_coref*. 
225 | - run *python maverick/evaluate.py*. 
226 | 
227 | # Citation 
228 | This work was published at the [ACL 2024 main conference](https://aclanthology.org/2024.acl-long.722.pdf). If you use any part of it, please consider citing our paper as follows: 
229 | ```bibtex 
230 | @inproceedings{martinelli-etal-2024-maverick, 
231 |     title = "Maverick: Efficient and Accurate Coreference Resolution Defying Recent Trends", 
232 |     author = "Martinelli, Giuliano and 
233 |       Barba, Edoardo and 
234 |       Navigli, Roberto", 
235 |     booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 
236 |     month = aug, 
237 |     year = "2024", 
238 |     address = "Bangkok, Thailand", 
239 |     publisher = "Association for Computational Linguistics", 
240 |     url = "https://aclanthology.org/2024.acl-long.722", 
241 |     pages = "13380--13394", 
242 | } 
243 | ``` 
244 | 
245 | 
246 | ## License 
247 | 
248 | The data and software are licensed under [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/). 
249 | 250 | 251 | -------------------------------------------------------------------------------- /maverick/models/model_incr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import random 4 | 5 | from transformers import ( 6 | AutoModel, 7 | AutoConfig, 8 | DistilBertForSequenceClassification, 9 | DistilBertConfig, 10 | ) 11 | 12 | from maverick.common.util import * 13 | from maverick.common.constants import * 14 | 15 | 16 | class MentionClusterClassifier(torch.nn.Module): 17 | def __init__(self, model): 18 | super().__init__() 19 | self.model = model 20 | 21 | def forward(self, mention_hidden_states, cluster_hidden_states, attention_mask, labels=None): 22 | # repreated tensor of mention_hs, to append in first position for each possible mention cluster pair 23 | repeated_mention_hs = mention_hidden_states.unsqueeze(0).repeat(cluster_hidden_states.shape[0], 1, 1) 24 | 25 | # mention cluste pairs by contatenating mention vectors to cluster padded matrix 26 | mention_cluster_pairs = torch.cat((repeated_mention_hs, cluster_hidden_states), dim=1) 27 | attention_mask = torch.cat( 28 | ( 29 | torch.ones(cluster_hidden_states.shape[0], 1, device=self.model.device), 30 | attention_mask, 31 | ), 32 | dim=1, 33 | ) 34 | 35 | logits = self.model(inputs_embeds=mention_cluster_pairs, attention_mask=attention_mask).logits 36 | 37 | loss = None 38 | if labels is not None: 39 | loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.unsqueeze(1).to(self.model.device)) 40 | return loss, logits 41 | 42 | 43 | class Maverick_incr(torch.nn.Module): 44 | def __init__(self, *args, **kwargs): 45 | super().__init__() 46 | # document transformer encoder 47 | self.encoder_hf_model_name = kwargs["huggingface_model_name"] 48 | self.encoder = AutoModel.from_pretrained(self.encoder_hf_model_name) 49 | self.encoder_config = AutoConfig.from_pretrained(self.encoder_hf_model_name) 50 | self.encoder.resize_token_embeddings(self.encoder.embeddings.word_embeddings.num_embeddings + 3) 51 | 52 | # freeze 53 | if kwargs["freeze_encoder"]: 54 | for param in self.encoder.parameters(): 55 | param.requires_grad = False 56 | 57 | # type of representation layer in 'Linear, FC, LSTM-left, LSTM-right, Conv1d' 58 | self.representation_layer_type = "FC" # fullyconnected 59 | # span hidden dimension 60 | self.token_hidden_size = self.encoder_config.hidden_size 61 | 62 | # if span representation method is to concatenate start and end, a mention hidden size will be 2*token_hidden_size 63 | self.mention_hidden_size = self.token_hidden_size * 2 64 | 65 | # incremental transformer classifier 66 | self.incremental_model_hidden_size = kwargs.get("incremental_model_hidden_size", 384) # 768/2 67 | self.incremental_model_num_layers = kwargs.get("incremental_model_num_layers", 1) 68 | self.incremental_model_config = DistilBertConfig(num_labels=1, hidden_size=self.incremental_model_hidden_size) 69 | self.incremental_model = DistilBertForSequenceClassification(self.incremental_model_config).to(self.encoder.device) 70 | self.incremental_model.distilbert.transformer.layer = self.incremental_model.distilbert.transformer.layer[ 71 | : self.incremental_model_num_layers 72 | ] 73 | self.incremental_model.distilbert.embeddings.word_embeddings = None 74 | self.incremental_transformer = MentionClusterClassifier(model=self.incremental_model) 75 | 76 | # encodes mentions for incremental clustering 77 | self.incremental_span_encoder = RepresentationLayer( 78 | 
type=self.representation_layer_type, 79 | input_dim=self.mention_hidden_size, 80 | output_dim=self.incremental_model_hidden_size, 81 | hidden_dim=int(self.mention_hidden_size / 2), 82 | ) 83 | 84 | # mention extraction layers 85 | # representation of start token 86 | self.start_token_representation = RepresentationLayer( 87 | type=self.representation_layer_type, 88 | input_dim=self.token_hidden_size, 89 | output_dim=self.token_hidden_size, 90 | hidden_dim=self.token_hidden_size, 91 | ) 92 | 93 | # representation of end token 94 | self.end_token_representation = RepresentationLayer( 95 | type=self.representation_layer_type, 96 | input_dim=self.token_hidden_size, 97 | output_dim=self.token_hidden_size, 98 | hidden_dim=self.token_hidden_size, 99 | ) 100 | 101 | # models probability to be the start of a mention 102 | self.start_token_classifier = RepresentationLayer( 103 | type=self.representation_layer_type, 104 | input_dim=self.token_hidden_size, 105 | output_dim=1, 106 | hidden_dim=self.token_hidden_size, 107 | ) 108 | 109 | # model mention probability from start and end representations 110 | self.start_end_classifier = RepresentationLayer( 111 | type=self.representation_layer_type, 112 | input_dim=self.mention_hidden_size, 113 | output_dim=1, 114 | hidden_dim=self.token_hidden_size, 115 | ) 116 | 117 | # takes last_hidden_states, eos_mask, ground truth and stage 118 | def squad_mention_extraction(self, lhs, eos_mask, gold_mentions, gold_starts, stage): 119 | start_idxs = [] 120 | mention_idxs = [] 121 | start_loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 122 | mention_loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 123 | 124 | for bidx in range(0, lhs.shape[0]): 125 | lhs_batch = lhs[bidx] # SEQ_LEN X HIDD_DIM 126 | eos_mask_batch = eos_mask[bidx] # SEQ_LEN X SEQ_LEN 127 | 128 | # compute start logits 129 | start_logits_batch = self.start_token_classifier(lhs_batch).squeeze(-1) # SEQ_LEN 130 | 131 | if gold_starts != None: 132 | loss = torch.nn.functional.binary_cross_entropy_with_logits(start_logits_batch, gold_starts[bidx]) 133 | 134 | # accumulate loss 135 | start_loss = start_loss + loss 136 | 137 | # compute start positions 138 | start_idxs_batch = ((torch.sigmoid(start_logits_batch) > 0.5)).nonzero(as_tuple=False).squeeze(-1) 139 | 140 | start_idxs.append(start_idxs_batch.detach().clone()) 141 | # in training, use gold starts to learn to extract mentions, inference use predicted ones 142 | if stage == "train": 143 | start_idxs_batch = ( 144 | ((torch.sigmoid(gold_starts[bidx]) > 0.5)).nonzero(as_tuple=False).squeeze(-1) 145 | ) # NUM_GOLD_STARTS 146 | 147 | # contains all possible start end indices pairs, i.e. 
for all starts, all possible ends looking at EOS index 148 | possibles_start_end_idxs = (eos_mask_batch[start_idxs_batch] == 1).nonzero(as_tuple=False) # STARTS x 2 149 | 150 | # this is to have reference respect to original positions 151 | possibles_start_end_idxs[:, 0] = start_idxs_batch[possibles_start_end_idxs[:, 0]] 152 | 153 | possible_start_idxs = possibles_start_end_idxs[:, 0] 154 | possible_end_idxs = possibles_start_end_idxs[:, 1] 155 | 156 | # extract start and end hidden states 157 | starts_hidden_states = lhs_batch[possible_end_idxs] # start 158 | ends_hidden_states = lhs_batch[possible_start_idxs] # end 159 | 160 | # concatenation of start to end representations created using a representation layer 161 | s2e_representations = torch.cat( 162 | ( 163 | self.start_token_representation(starts_hidden_states), 164 | self.end_token_representation(ends_hidden_states), 165 | ), 166 | dim=-1, 167 | ) 168 | 169 | # classification of mentions 170 | s2e_logits = self.start_end_classifier(s2e_representations).squeeze(-1) 171 | 172 | # mention_start_idxs and mention_end_idxs 173 | mention_idxs.append(possibles_start_end_idxs[torch.sigmoid(s2e_logits) > 0.5].detach().clone()) 174 | 175 | if s2e_logits.shape[0] != 0: 176 | if gold_mentions != None: 177 | mention_loss_batch = torch.nn.functional.binary_cross_entropy_with_logits( 178 | s2e_logits, 179 | gold_mentions[bidx][possible_start_idxs, possible_end_idxs], 180 | ) 181 | mention_loss = mention_loss + mention_loss_batch 182 | 183 | return (start_idxs, mention_idxs, start_loss, mention_loss) 184 | 185 | def incremental_span_clustering(self, mentions_hidden_states, mentions_idxs, gold_clusters, stage): 186 | pred_cluster_idxs = [] # cluster_idxs = list of list of tuple of offsets (also output) up to mention_idx 187 | if gold_clusters != None: 188 | gold_cluster_idxs = unpad_gold_clusters(gold_clusters) # gold_cluster_idxs, but padded 189 | 190 | coreference_loss = torch.tensor([0.0], requires_grad=True, device=self.incremental_model.device) 191 | mentions_hidden_states = mentions_hidden_states[0] 192 | idx_to_hs = dict(zip([tuple(m) for m in mentions_idxs.tolist()], mentions_hidden_states)) 193 | 194 | # for each mention 195 | for idx, ( 196 | mention_hidden_states, 197 | (mention_start_idx, mention_end_idx), 198 | ) in enumerate(zip(mentions_hidden_states, mentions_idxs)): 199 | if idx == 0: 200 | # if first create singleton cluster 201 | pred_cluster_idxs.append([(mention_start_idx.item(), mention_end_idx.item())]) 202 | else: 203 | if stage == "train": 204 | # if we are in training, retrieve use gold cluster idx to induce loss. 
205 | cluster_idx, labels = self.new_cluster_idxs_labels( 206 | (mention_start_idx, mention_end_idx), gold_cluster_idxs 207 | ) # can be used using only tensors 208 | else: 209 | cluster_idx, labels = pred_cluster_idxs, None 210 | 211 | # get cluster padded matrix matrix and attention mask (excludes padding) 212 | cluster_hs, cluster_am = self.get_cluster_states_matrix(idx_to_hs, cluster_idx, stage) 213 | 214 | # produce logits for each possible cluster mention pair 215 | mention_cluster_loss, logits = self.incremental_transformer( 216 | mention_hidden_states=mention_hidden_states, 217 | cluster_hidden_states=cluster_hs, 218 | attention_mask=cluster_am, 219 | labels=labels, 220 | ) 221 | 222 | if mention_cluster_loss != None: 223 | coreference_loss = coreference_loss + mention_cluster_loss 224 | 225 | if stage != "train": 226 | # only in inference 227 | num_possible_clustering = torch.sum(torch.sigmoid(logits) > 0.5, dim=0).bool().float() 228 | 229 | if num_possible_clustering == 0: 230 | # if no clustering, create new singleton cluster 231 | pred_cluster_idxs.append([(mention_start_idx.item(), mention_end_idx.item())]) 232 | else: 233 | # otherwise, take most probabile clustering predicted by the model and assign this mention to that cluster 234 | assigned_idx = logits.argmax(axis=0).detach().cpu() 235 | pred_cluster_idxs[assigned_idx.item()].append((mention_start_idx.item(), mention_end_idx.item())) 236 | # normalize loss debug 237 | if gold_clusters != None: 238 | coreference_loss = coreference_loss / (mentions_hidden_states.shape[0] if mentions_hidden_states.shape[0] != 0 else 1) 239 | 240 | # coreference_loss = coreference_loss / (len(gold_cluster_idxs) if len(gold_cluster_idxs) != 0 else 1) 241 | coreferences_pred = [tuple(item) for item in pred_cluster_idxs] # if len(item) > 1] 242 | return coreference_loss, coreferences_pred 243 | 244 | def get_cluster_states_matrix(self, idx_to_hs, cluster_idxs, stage): 245 | # create padded matrix of encoded mentions 246 | max_length = max([len(x) for x in cluster_idxs]) 247 | if stage == "train": 248 | max_length = max_length if max_length < 31 else 30 249 | forward_matrix = torch.zeros( 250 | (len(cluster_idxs), max_length, self.incremental_model_hidden_size), 251 | device=self.encoder.device, 252 | ) 253 | forward_am = torch.zeros((len(cluster_idxs), max_length), device=self.encoder.device) 254 | 255 | for cluster_idx, span_idxs in enumerate(cluster_idxs): 256 | if stage == "train": 257 | if len(span_idxs) > 30: 258 | span_idxs = sorted(span_idxs) 259 | new_idxs = [span_idxs[0]] 260 | new_idxs.extend(random.sample(span_idxs, 28)) 261 | new_idxs.append(span_idxs[-1]) 262 | span_idxs = new_idxs 263 | 264 | hs = torch.stack([idx_to_hs[span_idx] for span_idx in span_idxs]) 265 | 266 | forward_matrix[cluster_idx][: hs.shape[0]] = hs 267 | forward_am[cluster_idx][: hs.shape[0]] = torch.ones((hs.shape[0]), device=self.encoder.device) 268 | 269 | return forward_matrix, forward_am 270 | 271 | # takes the index of the mention (mention_start, mention_end) and gold coreferences, returns filtered indices (up to mention idx) and labels 272 | def new_cluster_idxs_labels(self, mention_idxs, gold_coreference_idxs): 273 | res_coreference_idxs = [] 274 | # list of length number of clusters in gold, and 1.0 where the mention is laying 275 | labels = [ 276 | 1.0 if (mention_idxs[0].item(), mention_idxs[1].item()) in span_idx else 0.0 for span_idx in gold_coreference_idxs 277 | ] 278 | # filter cluster up to the mention you are evaluating 279 | for cluster_idxs in 
gold_coreference_idxs: 280 | idxs = [] 281 | for span_idx in cluster_idxs: 282 | # if span is antecedent to current mention, stay in possible clusters 283 | if span_idx[0] < mention_idxs[0].item() or ( 284 | span_idx[0] == mention_idxs[0].item() and span_idx[1] < mention_idxs[1].item() 285 | ): 286 | idxs.append((span_idx[0], span_idx[1])) 287 | # idxs = sorted(idxs, reverse=True) 288 | res_coreference_idxs.append(idxs) 289 | 290 | labels = torch.tensor( 291 | [lab for lab, idx in zip(labels, res_coreference_idxs) if len(idx) != 0], 292 | device=self.encoder.device, 293 | ) 294 | res_coreference_idxs = [idx for idx in res_coreference_idxs if len(idx) != 0] 295 | return res_coreference_idxs, labels 296 | 297 | def forward( 298 | self, 299 | stage, 300 | input_ids, 301 | attention_mask, 302 | eos_mask, 303 | gold_starts=None, 304 | gold_mentions=None, 305 | gold_clusters=None, 306 | ): 307 | last_hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] # B x S x TH 308 | 309 | lhs = last_hidden_states 310 | 311 | loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 312 | loss_dict = {} 313 | preds = {} 314 | 315 | ( 316 | start_idxs, 317 | mention_idxs, 318 | start_loss, 319 | mention_loss, 320 | ) = self.squad_mention_extraction( 321 | lhs=last_hidden_states, 322 | eos_mask=eos_mask, 323 | gold_mentions=gold_mentions, 324 | gold_starts=gold_starts, 325 | stage=stage, 326 | ) 327 | 328 | loss_dict["start_loss"] = start_loss 329 | preds["start_idxs"] = [start.detach().cpu() for start in start_idxs] 330 | 331 | loss_dict["mention_loss"] = mention_loss 332 | preds["mention_idxs"] = [mention.detach().cpu() for mention in mention_idxs] 333 | 334 | loss = loss + start_loss + mention_loss 335 | 336 | if stage == "train": 337 | mention_idxs = (gold_mentions[0] == 1).nonzero(as_tuple=False) 338 | else: 339 | mention_idxs = mention_idxs[0] 340 | 341 | mention_start_idxs = mention_idxs[:, 0] 342 | mention_end_idxs = mention_idxs[:, 1] 343 | 344 | mentions_start_hidden_states = torch.index_select(lhs, 1, mention_start_idxs) 345 | mentions_end_hidden_states = torch.index_select(lhs, 1, mention_end_idxs) 346 | 347 | mentions_hidden_states = torch.cat((mentions_start_hidden_states, mentions_end_hidden_states), dim=2) 348 | 349 | mentions_hidden_states = self.incremental_span_encoder(mentions_hidden_states) 350 | 351 | coreference_loss, coreferences = self.incremental_span_clustering( 352 | mentions_hidden_states, mention_idxs, gold_clusters, stage 353 | ) 354 | 355 | loss = loss + coreference_loss 356 | loss_dict["coreference_loss"] = coreference_loss 357 | 358 | if stage != "train": 359 | preds["clusters"] = coreferences 360 | 361 | loss_dict["full_loss"] = loss 362 | output = {"pred_dict": preds, "loss_dict": loss_dict, "loss": loss} 363 | 364 | return output 365 | -------------------------------------------------------------------------------- /maverick/models/model_s2e.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from transformers import AutoModel, AutoConfig 5 | 6 | from maverick.common.util import * 7 | from maverick.common.constants import * 8 | 9 | 10 | class Maverick_s2e(torch.nn.Module): 11 | def __init__(self, *args, **kwargs): 12 | super().__init__() 13 | # document transformer encoder 14 | self.encoder_hf_model_name = kwargs["huggingface_model_name"] 15 | self.encoder = AutoModel.from_pretrained(self.encoder_hf_model_name) 16 | 
self.encoder_config = AutoConfig.from_pretrained(self.encoder_hf_model_name) 17 | # self.encoder_config.attention_window = 1024 18 | self.encoder.resize_token_embeddings(self.encoder.embeddings.word_embeddings.num_embeddings + 3) 19 | 20 | # freeze 21 | if kwargs["freeze_encoder"]: 22 | for param in self.encoder.parameters(): 23 | param.requires_grad = False 24 | 25 | # span representation, now is concat_start_end 26 | self.span_representation = kwargs["span_representation"] 27 | # type of representation layer in 'Linear, FC, LSTM-left, LSTM-right, Conv1d' 28 | self.representation_layer_type = "FC" # fullyconnected 29 | # span hidden dimension 30 | self.token_hidden_size = self.encoder_config.hidden_size 31 | 32 | # if span representation method is to concatenate start and end, a mention hidden size will be 2*token_hidden_size 33 | if self.span_representation == "concat_start_end": 34 | self.mention_hidden_size = self.token_hidden_size * 2 35 | 36 | self.antecedent_s2s_classifier = RepresentationLayer( 37 | type=self.representation_layer_type, # fullyconnected 38 | input_dim=self.token_hidden_size, 39 | output_dim=self.token_hidden_size, 40 | hidden_dim=self.token_hidden_size, 41 | ) 42 | self.antecedent_e2e_classifier = RepresentationLayer( 43 | type=self.representation_layer_type, # fullyconnected 44 | input_dim=self.token_hidden_size, 45 | output_dim=self.token_hidden_size, 46 | hidden_dim=self.token_hidden_size, 47 | ) 48 | 49 | self.antecedent_s2e_classifier = RepresentationLayer( 50 | type=self.representation_layer_type, # fullyconnected 51 | input_dim=self.token_hidden_size, 52 | output_dim=self.token_hidden_size, 53 | hidden_dim=self.token_hidden_size, 54 | ) 55 | self.antecedent_e2s_classifier = RepresentationLayer( 56 | type=self.representation_layer_type, # fullyconnected 57 | input_dim=self.token_hidden_size, 58 | output_dim=self.token_hidden_size, 59 | hidden_dim=self.token_hidden_size, 60 | ) 61 | 62 | # mention extraction layers 63 | # representation of start token 64 | self.start_token_representation = RepresentationLayer( 65 | type=self.representation_layer_type, 66 | input_dim=self.token_hidden_size, 67 | output_dim=self.token_hidden_size, 68 | hidden_dim=self.token_hidden_size, 69 | ) 70 | 71 | # representation of end token 72 | self.end_token_representation = RepresentationLayer( 73 | type=self.representation_layer_type, 74 | input_dim=self.token_hidden_size, 75 | output_dim=self.token_hidden_size, 76 | hidden_dim=self.token_hidden_size, 77 | ) 78 | 79 | # models probability to be the start of a mention 80 | self.start_token_classifier = RepresentationLayer( 81 | type=self.representation_layer_type, 82 | input_dim=self.token_hidden_size, 83 | output_dim=1, 84 | hidden_dim=self.token_hidden_size, 85 | ) 86 | 87 | # model mention probability from start and end representations 88 | self.start_end_classifier = RepresentationLayer( 89 | type=self.representation_layer_type, 90 | input_dim=self.mention_hidden_size, 91 | output_dim=1, 92 | hidden_dim=self.token_hidden_size, 93 | ) 94 | 95 | # takes last_hidden_states, eos_mask, ground truth and stage 96 | # check if mention_mask deletes some good gold mentions! 
97 | def squad_mention_extraction(self, lhs, eos_mask, gold_mentions, gold_starts, stage): 98 | start_idxs = [] 99 | mention_idxs = [] 100 | start_loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 101 | mention_loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 102 | 103 | for bidx in range(0, lhs.shape[0]): 104 | lhs_batch = lhs[bidx] # SEQ_LEN X HIDD_DIM 105 | eos_mask_batch = eos_mask[bidx] # SEQ_LEN X SEQ_LEN 106 | 107 | # compute start logits 108 | start_logits_batch = self.start_token_classifier(lhs_batch).squeeze(-1) # SEQ_LEN 109 | 110 | if gold_starts != None: 111 | loss = torch.nn.functional.binary_cross_entropy_with_logits(start_logits_batch, gold_starts[bidx]) 112 | 113 | # accumulate loss 114 | start_loss = start_loss + loss 115 | 116 | # compute start positions 117 | start_idxs_batch = ((torch.sigmoid(start_logits_batch) > 0.5)).nonzero(as_tuple=False).squeeze(-1) 118 | 119 | start_idxs.append(start_idxs_batch.detach().clone()) 120 | # in training, use gold starts to learn to extract mentions, inference use predicted ones 121 | if stage == "train": 122 | start_idxs_batch = ( 123 | ((torch.sigmoid(gold_starts[bidx]) > 0.5)).nonzero(as_tuple=False).squeeze(-1) 124 | ) # NUM_GOLD_STARTS 125 | 126 | # contains all possible start end indices pairs, i.e. for all starts, all possible ends looking at EOS index 127 | possibles_start_end_idxs = (eos_mask_batch[start_idxs_batch] == 1).nonzero(as_tuple=False) # STARTS x 2 128 | 129 | # this is to have reference respect to original positions 130 | possibles_start_end_idxs[:, 0] = start_idxs_batch[possibles_start_end_idxs[:, 0]] 131 | 132 | possible_start_idxs = possibles_start_end_idxs[:, 0] 133 | possible_end_idxs = possibles_start_end_idxs[:, 1] 134 | 135 | # extract start and end hidden states 136 | starts_hidden_states = lhs_batch[possible_end_idxs] # start 137 | ends_hidden_states = lhs_batch[possible_start_idxs] # end 138 | 139 | # concatenation of start to end representations created using a representation layer 140 | s2e_representations = torch.cat( 141 | ( 142 | self.start_token_representation(starts_hidden_states), 143 | self.end_token_representation(ends_hidden_states), 144 | ), 145 | dim=-1, 146 | ) 147 | 148 | # classification of mentions 149 | s2e_logits = self.start_end_classifier(s2e_representations).squeeze(-1) 150 | 151 | # mention_start_idxs and mention_end_idxs 152 | mention_idxs.append(possibles_start_end_idxs[torch.sigmoid(s2e_logits) > 0.5].detach().clone()) 153 | 154 | if s2e_logits.shape[0] != 0: 155 | if gold_mentions != None: 156 | mention_loss_batch = torch.nn.functional.binary_cross_entropy_with_logits( 157 | s2e_logits, 158 | gold_mentions[bidx][possible_start_idxs, possible_end_idxs], 159 | ) 160 | mention_loss = mention_loss + mention_loss_batch 161 | 162 | return (start_idxs, mention_idxs, start_loss, mention_loss) 163 | 164 | def create_mention_to_antecedent_singletons(self, span_starts, span_ends, coref_logits): 165 | span_starts = span_starts.unsqueeze(0) 166 | span_ends = span_ends.unsqueeze(0) 167 | bs, n_spans, _ = coref_logits.shape 168 | 169 | no_ant = 1 - torch.sum(torch.sigmoid(coref_logits) > 0.5, dim=-1).bool().float() 170 | # [batch_size, max_k, max_k + 1] 171 | coref_logits = torch.cat((coref_logits, no_ant.unsqueeze(-1)), dim=-1) 172 | 173 | span_starts = span_starts.detach().cpu() 174 | span_ends = span_ends.detach().cpu() 175 | max_antecedents = coref_logits.argmax(axis=-1).detach().cpu() 176 | doc_indices = np.nonzero(max_antecedents < 
n_spans)[:, 0] 177 | # indices where antecedent is not null. 178 | mention_indices = np.nonzero(max_antecedents < n_spans)[:, 1] 179 | 180 | antecedent_indices = max_antecedents[max_antecedents < n_spans] 181 | span_indices = np.stack([span_starts.detach().cpu(), span_ends.detach().cpu()], axis=-1) 182 | 183 | mentions = span_indices[doc_indices, mention_indices] 184 | antecedents = span_indices[doc_indices, antecedent_indices] 185 | non_mentions = np.nonzero(max_antecedents == n_spans)[:, 1] 186 | 187 | sing_indices = np.zeros_like(len(np.setdiff1d(non_mentions, antecedent_indices))) 188 | singletons = span_indices[sing_indices, np.setdiff1d(non_mentions, antecedent_indices)] 189 | 190 | # mention_to_antecedent = np.stack([mentions, antecedents], axis=1) 191 | 192 | if len(mentions.shape) == 1 and len(antecedents.shape) == 1: 193 | mention_to_antecedent = np.stack([mentions, antecedents], axis=0) 194 | else: 195 | mention_to_antecedent = np.stack([mentions, antecedents], axis=1) 196 | 197 | if len(mentions.shape) == 1: 198 | mention_to_antecedent = [mention_to_antecedent] 199 | 200 | if len(singletons.shape) == 1: 201 | singletons = [singletons] 202 | 203 | return doc_indices, mention_to_antecedent, singletons 204 | 205 | def _get_cluster_labels_after_pruning(self, span_starts, span_ends, all_clusters): 206 | """ 207 | :param span_starts: [batch_size, max_k] 208 | :param span_ends: [batch_size, max_k] 209 | :param all_clusters: [batch_size, max_cluster_size, max_clusters_num, 2] 210 | :return: [batch_size, max_k, max_k + 1] - [b, i, j] == 1 if i is antecedent of j 211 | """ 212 | span_starts = span_starts.unsqueeze(0) 213 | span_ends = span_ends.unsqueeze(0) 214 | batch_size, max_k = span_starts.size() 215 | new_cluster_labels = torch.zeros((batch_size, max_k, max_k), device="cpu") 216 | all_clusters_cpu = all_clusters.cpu().numpy() 217 | for b, (starts, ends, gold_clusters) in enumerate( 218 | zip(span_starts.cpu().tolist(), span_ends.cpu().tolist(), all_clusters_cpu) 219 | ): 220 | gold_clusters = self.extract_clusters(gold_clusters) 221 | mention_to_gold_clusters = extract_mentions_to_clusters(gold_clusters) 222 | gold_mentions = set(mention_to_gold_clusters.keys()) 223 | for i, (start, end) in enumerate(zip(starts, ends)): 224 | if (start, end) not in gold_mentions: 225 | continue 226 | for j, (a_start, a_end) in enumerate(list(zip(starts, ends))[:i]): 227 | if (a_start, a_end) in mention_to_gold_clusters[(start, end)]: 228 | new_cluster_labels[b, i, j] = 1 229 | new_cluster_labels = new_cluster_labels.to(self.encoder.device) 230 | return new_cluster_labels 231 | 232 | def s2e_mention_cluster(self, mention_start_reps, mention_end_reps, mention_start_idxs, mention_end_idxs, gold, stage): 233 | coref_logits = self._calc_coref_logits(mention_start_reps, mention_end_reps) 234 | coreference_loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 235 | 236 | coref_logits = torch.stack([matrix.tril().fill_diagonal_(0) for matrix in coref_logits]) 237 | if stage == "train": 238 | labels = self._get_cluster_labels_after_pruning(mention_start_idxs, mention_end_idxs, gold) 239 | coreference_loss = torch.nn.functional.binary_cross_entropy_with_logits(coref_logits, labels) 240 | 241 | doc, m2a, singletons = self.create_mention_to_antecedent_singletons(mention_start_idxs, mention_end_idxs, coref_logits) 242 | coreferences = self.create_clusters(m2a, singletons) 243 | return coreference_loss, coreferences 244 | 245 | def create_clusters(self, m2a, singletons): 246 | # Note: 
mention_to_antecedent is a numpy array 247 | 248 | clusters, mention_to_cluster = [], {} 249 | for mention, antecedent in m2a: 250 | mention, antecedent = tuple(mention), tuple(antecedent) 251 | if antecedent in mention_to_cluster: 252 | cluster_idx = mention_to_cluster[antecedent] 253 | if mention not in clusters[cluster_idx]: 254 | clusters[cluster_idx].append(mention) 255 | mention_to_cluster[mention] = cluster_idx 256 | elif mention in mention_to_cluster: 257 | cluster_idx = mention_to_cluster[mention] 258 | if antecedent not in clusters[cluster_idx]: 259 | clusters[cluster_idx].append(antecedent) 260 | mention_to_cluster[antecedent] = cluster_idx 261 | else: 262 | cluster_idx = len(clusters) 263 | mention_to_cluster[mention] = cluster_idx 264 | mention_to_cluster[antecedent] = cluster_idx 265 | clusters.append([antecedent, mention]) 266 | 267 | clusters = [tuple(cluster) for cluster in clusters] 268 | if len(singletons) != 0: 269 | clust = [] 270 | while len(clusters) != 0 or len(singletons) != 0: 271 | if len(singletons) == 0: 272 | clust.append(clusters[0]) 273 | clusters = clusters[1:] 274 | elif len(clusters) == 0: 275 | clust.append(tuple([tuple(singletons[0])])) 276 | singletons = singletons[1:] 277 | elif singletons[0][0] < sorted(clusters[0], key=lambda x: x[0])[0][0]: 278 | clust.append(tuple([tuple(singletons[0])])) 279 | singletons = singletons[1:] 280 | else: 281 | clust.append(clusters[0]) 282 | clusters = clusters[1:] 283 | return clust 284 | return clusters 285 | 286 | def _calc_coref_logits(self, top_k_start_coref_reps, top_k_end_coref_reps): 287 | # s2s 288 | temp = self.antecedent_s2s_classifier(top_k_start_coref_reps) # [batch_size, max_k, dim] 289 | top_k_s2s_coref_logits = torch.matmul(temp, top_k_start_coref_reps.permute([0, 2, 1])) # [batch_size, max_k, max_k] 290 | 291 | # e2e 292 | temp = self.antecedent_e2e_classifier(top_k_end_coref_reps) # [batch_size, max_k, dim] 293 | top_k_e2e_coref_logits = torch.matmul(temp, top_k_end_coref_reps.permute([0, 2, 1])) # [batch_size, max_k, max_k] 294 | 295 | # s2e 296 | temp = self.antecedent_s2e_classifier(top_k_start_coref_reps) # [batch_size, max_k, dim] 297 | top_k_s2e_coref_logits = torch.matmul(temp, top_k_end_coref_reps.permute([0, 2, 1])) # [batch_size, max_k, max_k] 298 | 299 | # e2s 300 | temp = self.antecedent_e2s_classifier(top_k_end_coref_reps) # [batch_size, max_k, dim] 301 | top_k_e2s_coref_logits = torch.matmul(temp, top_k_start_coref_reps.permute([0, 2, 1])) # [batch_size, max_k, max_k] 302 | 303 | # sum all terms 304 | coref_logits = ( 305 | top_k_s2e_coref_logits + top_k_e2s_coref_logits + top_k_s2s_coref_logits + top_k_e2e_coref_logits 306 | ) # [batch_size, max_k, max_k] 307 | return coref_logits 308 | 309 | def extract_clusters(gold_clusters): 310 | gold_clusters = [tuple(tuple(m) for m in cluster if (-1) not in m) for cluster in gold_clusters] 311 | gold_clusters = [cluster for cluster in gold_clusters if len(cluster) > 0] 312 | return gold_clusters 313 | 314 | def forward( 315 | self, 316 | stage, 317 | input_ids, 318 | attention_mask, 319 | eos_mask, 320 | gold_starts=None, 321 | gold_mentions=None, 322 | gold_clusters=None, 323 | ): 324 | last_hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask)["last_hidden_state"] # B x S x TH 325 | 326 | lhs = last_hidden_states 327 | 328 | loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 329 | loss_dict = {} 330 | preds = {} 331 | 332 | ( 333 | start_idxs, 334 | mention_idxs, 335 | start_loss, 336 | 
mention_loss, 337 | ) = self.squad_mention_extraction( 338 | lhs=last_hidden_states, 339 | eos_mask=eos_mask, 340 | gold_mentions=gold_mentions, 341 | gold_starts=gold_starts, 342 | stage=stage, 343 | ) 344 | 345 | loss_dict["start_loss"] = start_loss 346 | preds["start_idxs"] = [start.detach().cpu() for start in start_idxs] 347 | 348 | loss_dict["mention_loss"] = mention_loss 349 | preds["mention_idxs"] = [mention.detach().cpu() for mention in mention_idxs] 350 | 351 | loss = loss + start_loss + mention_loss 352 | 353 | if stage == "train": 354 | mention_idxs = (gold_mentions[0] == 1).nonzero(as_tuple=False) 355 | else: 356 | mention_idxs = mention_idxs[0] 357 | 358 | mention_start_idxs = mention_idxs[:, 0] 359 | mention_end_idxs = mention_idxs[:, 1] 360 | 361 | mentions_start_hidden_states = torch.index_select(lhs, 1, mention_start_idxs) 362 | mentions_end_hidden_states = torch.index_select(lhs, 1, mention_end_idxs) 363 | 364 | coreference_loss, coreferences = self.s2e_mention_cluster( 365 | mentions_start_hidden_states, 366 | mentions_end_hidden_states, 367 | mention_start_idxs, 368 | mention_end_idxs, 369 | gold_clusters, 370 | stage, 371 | ) 372 | 373 | loss = loss + coreference_loss 374 | loss_dict["coreference_loss"] = coreference_loss 375 | 376 | if stage != "train": 377 | preds["clusters"] = coreferences 378 | 379 | loss_dict["full_loss"] = loss 380 | output = {"pred_dict": preds, "loss_dict": loss_dict, "loss": loss} 381 | 382 | return output 383 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial-ShareAlike 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. 
This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International 58 | Public License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial-ShareAlike 4.0 International Public License 63 | ("Public License"). To the extent this Public License may be 64 | interpreted as a contract, You are granted the Licensed Rights in 65 | consideration of Your acceptance of these terms and conditions, and the 66 | Licensor grants You such rights in consideration of benefits the 67 | Licensor receives from making the Licensed Material available under 68 | these terms and conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. BY-NC-SA Compatible License means a license listed at 88 | creativecommons.org/compatiblelicenses, approved by Creative 89 | Commons as essentially the equivalent of this Public License. 90 | 91 | d. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | 99 | e. 
Effective Technological Measures means those measures that, in the 100 | absence of proper authority, may not be circumvented under laws 101 | fulfilling obligations under Article 11 of the WIPO Copyright 102 | Treaty adopted on December 20, 1996, and/or similar international 103 | agreements. 104 | 105 | f. Exceptions and Limitations means fair use, fair dealing, and/or 106 | any other exception or limitation to Copyright and Similar Rights 107 | that applies to Your use of the Licensed Material. 108 | 109 | g. License Elements means the license attributes listed in the name 110 | of a Creative Commons Public License. The License Elements of this 111 | Public License are Attribution, NonCommercial, and ShareAlike. 112 | 113 | h. Licensed Material means the artistic or literary work, database, 114 | or other material to which the Licensor applied this Public 115 | License. 116 | 117 | i. Licensed Rights means the rights granted to You subject to the 118 | terms and conditions of this Public License, which are limited to 119 | all Copyright and Similar Rights that apply to Your use of the 120 | Licensed Material and that the Licensor has authority to license. 121 | 122 | j. Licensor means the individual(s) or entity(ies) granting rights 123 | under this Public License. 124 | 125 | k. NonCommercial means not primarily intended for or directed towards 126 | commercial advantage or monetary compensation. For purposes of 127 | this Public License, the exchange of the Licensed Material for 128 | other material subject to Copyright and Similar Rights by digital 129 | file-sharing or similar means is NonCommercial provided there is 130 | no payment of monetary compensation in connection with the 131 | exchange. 132 | 133 | l. Share means to provide material to the public by any means or 134 | process that requires permission under the Licensed Rights, such 135 | as reproduction, public display, public performance, distribution, 136 | dissemination, communication, or importation, and to make material 137 | available to the public including in ways that members of the 138 | public may access the material from a place and at a time 139 | individually chosen by them. 140 | 141 | m. Sui Generis Database Rights means rights other than copyright 142 | resulting from Directive 96/9/EC of the European Parliament and of 143 | the Council of 11 March 1996 on the legal protection of databases, 144 | as amended and/or succeeded, as well as other essentially 145 | equivalent rights anywhere in the world. 146 | 147 | n. You means the individual or entity exercising the Licensed Rights 148 | under this Public License. Your has a corresponding meaning. 149 | 150 | 151 | Section 2 -- Scope. 152 | 153 | a. License grant. 154 | 155 | 1. Subject to the terms and conditions of this Public License, 156 | the Licensor hereby grants You a worldwide, royalty-free, 157 | non-sublicensable, non-exclusive, irrevocable license to 158 | exercise the Licensed Rights in the Licensed Material to: 159 | 160 | a. reproduce and Share the Licensed Material, in whole or 161 | in part, for NonCommercial purposes only; and 162 | 163 | b. produce, reproduce, and Share Adapted Material for 164 | NonCommercial purposes only. 165 | 166 | 2. Exceptions and Limitations. For the avoidance of doubt, where 167 | Exceptions and Limitations apply to Your use, this Public 168 | License does not apply, and You do not need to comply with 169 | its terms and conditions. 170 | 171 | 3. Term. 
The term of this Public License is specified in Section 172 | 6(a). 173 | 174 | 4. Media and formats; technical modifications allowed. The 175 | Licensor authorizes You to exercise the Licensed Rights in 176 | all media and formats whether now known or hereafter created, 177 | and to make technical modifications necessary to do so. The 178 | Licensor waives and/or agrees not to assert any right or 179 | authority to forbid You from making technical modifications 180 | necessary to exercise the Licensed Rights, including 181 | technical modifications necessary to circumvent Effective 182 | Technological Measures. For purposes of this Public License, 183 | simply making modifications authorized by this Section 2(a) 184 | (4) never produces Adapted Material. 185 | 186 | 5. Downstream recipients. 187 | 188 | a. Offer from the Licensor -- Licensed Material. Every 189 | recipient of the Licensed Material automatically 190 | receives an offer from the Licensor to exercise the 191 | Licensed Rights under the terms and conditions of this 192 | Public License. 193 | 194 | b. Additional offer from the Licensor -- Adapted Material. 195 | Every recipient of Adapted Material from You 196 | automatically receives an offer from the Licensor to 197 | exercise the Licensed Rights in the Adapted Material 198 | under the conditions of the Adapter's License You apply. 199 | 200 | c. No downstream restrictions. You may not offer or impose 201 | any additional or different terms or conditions on, or 202 | apply any Effective Technological Measures to, the 203 | Licensed Material if doing so restricts exercise of the 204 | Licensed Rights by any recipient of the Licensed 205 | Material. 206 | 207 | 6. No endorsement. Nothing in this Public License constitutes or 208 | may be construed as permission to assert or imply that You 209 | are, or that Your use of the Licensed Material is, connected 210 | with, or sponsored, endorsed, or granted official status by, 211 | the Licensor or others designated to receive attribution as 212 | provided in Section 3(a)(1)(A)(i). 213 | 214 | b. Other rights. 215 | 216 | 1. Moral rights, such as the right of integrity, are not 217 | licensed under this Public License, nor are publicity, 218 | privacy, and/or other similar personality rights; however, to 219 | the extent possible, the Licensor waives and/or agrees not to 220 | assert any such rights held by the Licensor to the limited 221 | extent necessary to allow You to exercise the Licensed 222 | Rights, but not otherwise. 223 | 224 | 2. Patent and trademark rights are not licensed under this 225 | Public License. 226 | 227 | 3. To the extent possible, the Licensor waives any right to 228 | collect royalties from You for the exercise of the Licensed 229 | Rights, whether directly or through a collecting society 230 | under any voluntary or waivable statutory or compulsory 231 | licensing scheme. In all other cases the Licensor expressly 232 | reserves any right to collect such royalties, including when 233 | the Licensed Material is used other than for NonCommercial 234 | purposes. 235 | 236 | 237 | Section 3 -- License Conditions. 238 | 239 | Your exercise of the Licensed Rights is expressly made subject to the 240 | following conditions. 241 | 242 | a. Attribution. 243 | 244 | 1. If You Share the Licensed Material (including in modified 245 | form), You must: 246 | 247 | a. retain the following if it is supplied by the Licensor 248 | with the Licensed Material: 249 | 250 | i. 
identification of the creator(s) of the Licensed 251 | Material and any others designated to receive 252 | attribution, in any reasonable manner requested by 253 | the Licensor (including by pseudonym if 254 | designated); 255 | 256 | ii. a copyright notice; 257 | 258 | iii. a notice that refers to this Public License; 259 | 260 | iv. a notice that refers to the disclaimer of 261 | warranties; 262 | 263 | v. a URI or hyperlink to the Licensed Material to the 264 | extent reasonably practicable; 265 | 266 | b. indicate if You modified the Licensed Material and 267 | retain an indication of any previous modifications; and 268 | 269 | c. indicate the Licensed Material is licensed under this 270 | Public License, and include the text of, or the URI or 271 | hyperlink to, this Public License. 272 | 273 | 2. You may satisfy the conditions in Section 3(a)(1) in any 274 | reasonable manner based on the medium, means, and context in 275 | which You Share the Licensed Material. For example, it may be 276 | reasonable to satisfy the conditions by providing a URI or 277 | hyperlink to a resource that includes the required 278 | information. 279 | 3. If requested by the Licensor, You must remove any of the 280 | information required by Section 3(a)(1)(A) to the extent 281 | reasonably practicable. 282 | 283 | b. ShareAlike. 284 | 285 | In addition to the conditions in Section 3(a), if You Share 286 | Adapted Material You produce, the following conditions also apply. 287 | 288 | 1. The Adapter's License You apply must be a Creative Commons 289 | license with the same License Elements, this version or 290 | later, or a BY-NC-SA Compatible License. 291 | 292 | 2. You must include the text of, or the URI or hyperlink to, the 293 | Adapter's License You apply. You may satisfy this condition 294 | in any reasonable manner based on the medium, means, and 295 | context in which You Share Adapted Material. 296 | 297 | 3. You may not offer or impose any additional or different terms 298 | or conditions on, or apply any Effective Technological 299 | Measures to, Adapted Material that restrict exercise of the 300 | rights granted under the Adapter's License You apply. 301 | 302 | 303 | Section 4 -- Sui Generis Database Rights. 304 | 305 | Where the Licensed Rights include Sui Generis Database Rights that 306 | apply to Your use of the Licensed Material: 307 | 308 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 309 | to extract, reuse, reproduce, and Share all or a substantial 310 | portion of the contents of the database for NonCommercial purposes 311 | only; 312 | 313 | b. if You include all or a substantial portion of the database 314 | contents in a database in which You have Sui Generis Database 315 | Rights, then the database in which You have Sui Generis Database 316 | Rights (but not its individual contents) is Adapted Material, 317 | including for purposes of Section 3(b); and 318 | 319 | c. You must comply with the conditions in Section 3(a) if You Share 320 | all or a substantial portion of the contents of the database. 321 | 322 | For the avoidance of doubt, this Section 4 supplements and does not 323 | replace Your obligations under this Public License where the Licensed 324 | Rights include other Copyright and Similar Rights. 325 | 326 | 327 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 328 | 329 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 330 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 331 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 332 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 333 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 334 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 335 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 336 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 337 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 338 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 339 | 340 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 341 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 342 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 343 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 344 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 345 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 346 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 347 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 348 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 349 | 350 | c. The disclaimer of warranties and limitation of liability provided 351 | above shall be interpreted in a manner that, to the extent 352 | possible, most closely approximates an absolute disclaimer and 353 | waiver of all liability. 354 | 355 | 356 | Section 6 -- Term and Termination. 357 | 358 | a. This Public License applies for the term of the Copyright and 359 | Similar Rights licensed here. However, if You fail to comply with 360 | this Public License, then Your rights under this Public License 361 | terminate automatically. 362 | 363 | b. Where Your right to use the Licensed Material has terminated under 364 | Section 6(a), it reinstates: 365 | 366 | 1. automatically as of the date the violation is cured, provided 367 | it is cured within 30 days of Your discovery of the 368 | violation; or 369 | 370 | 2. upon express reinstatement by the Licensor. 371 | 372 | For the avoidance of doubt, this Section 6(b) does not affect any 373 | right the Licensor may have to seek remedies for Your violations 374 | of this Public License. 375 | 376 | c. For the avoidance of doubt, the Licensor may also offer the 377 | Licensed Material under separate terms or conditions or stop 378 | distributing the Licensed Material at any time; however, doing so 379 | will not terminate this Public License. 380 | 381 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 382 | License. 383 | 384 | 385 | Section 7 -- Other Terms and Conditions. 386 | 387 | a. The Licensor shall not be bound by any additional or different 388 | terms or conditions communicated by You unless expressly agreed. 389 | 390 | b. Any arrangements, understandings, or agreements regarding the 391 | Licensed Material not stated herein are separate from and 392 | independent of the terms and conditions of this Public License. 393 | 394 | 395 | Section 8 -- Interpretation. 396 | 397 | a. For the avoidance of doubt, this Public License does not, and 398 | shall not be interpreted to, reduce, limit, restrict, or impose 399 | conditions on any use of the Licensed Material that could lawfully 400 | be made without permission under this Public License. 401 | 402 | b. 
To the extent possible, if any provision of this Public License is 403 | deemed unenforceable, it shall be automatically reformed to the 404 | minimum extent necessary to make it enforceable. If the provision 405 | cannot be reformed, it shall be severed from this Public License 406 | without affecting the enforceability of the remaining terms and 407 | conditions. 408 | 409 | c. No term or condition of this Public License will be waived and no 410 | failure to comply consented to unless expressly agreed to by the 411 | Licensor. 412 | 413 | d. Nothing in this Public License constitutes or may be interpreted 414 | as a limitation upon, or waiver of, any privileges and immunities 415 | that apply to the Licensor or You, including from the legal 416 | processes of any jurisdiction or authority. 417 | 418 | ======================================================================= 419 | 420 | Creative Commons is not a party to its public 421 | licenses. Notwithstanding, Creative Commons may elect to apply one of 422 | its public licenses to material it publishes and in those instances 423 | will be considered the “Licensor.” The text of the Creative Commons 424 | public licenses is dedicated to the public domain under the CC0 Public 425 | Domain Dedication. Except for the limited purpose of indicating that 426 | material is shared under a Creative Commons public license or as 427 | otherwise permitted by the Creative Commons policies published at 428 | creativecommons.org/policies, Creative Commons does not authorize the 429 | use of the trademark "Creative Commons" or any other trademark or logo 430 | of Creative Commons without its prior written consent including, 431 | without limitation, in connection with any unauthorized modifications 432 | to any of its public licenses or any other arrangements, 433 | understandings, or agreements concerning use of licensed material. For 434 | the avoidance of doubt, this paragraph does not form part of the 435 | public licenses. 436 | 437 | Creative Commons may be contacted at creativecommons.org. 
-------------------------------------------------------------------------------- /maverick/models/model_mes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | import numpy as np 4 | 5 | from torch.nn import init 6 | from torch import nn 7 | from transformers import AutoModel, AutoConfig 8 | 9 | from maverick.common.util import * 10 | from maverick.common.constants import * 11 | 12 | 13 | class Maverick_mes(torch.nn.Module): 14 | def __init__(self, *args, **kwargs): 15 | super().__init__() 16 | # document transformer encoder 17 | self.encoder_hf_model_name = kwargs["huggingface_model_name"] 18 | self.encoder = AutoModel.from_pretrained(self.encoder_hf_model_name) 19 | self.encoder_config = AutoConfig.from_pretrained(self.encoder_hf_model_name) 20 | self.encoder.resize_token_embeddings(self.encoder.embeddings.word_embeddings.num_embeddings + 3) 21 | 22 | # freeze 23 | if kwargs["freeze_encoder"]: 24 | for param in self.encoder.parameters(): 25 | param.requires_grad = False 26 | 27 | # span representation, now is concat_start_end 28 | self.span_representation = kwargs["span_representation"] 29 | # type of representation layer in 'Linear, FC, LSTM-left, LSTM-right, Conv1d' 30 | self.representation_layer_type = "FC" # fullyconnected 31 | # span hidden dimension 32 | self.token_hidden_size = self.encoder_config.hidden_size 33 | 34 | # if span representation method is to concatenate start and end, a mention hidden size will be 2*token_hidden_size 35 | if self.span_representation == "concat_start_end": 36 | self.mention_hidden_size = self.token_hidden_size * 2 37 | 38 | self.num_cats = len(CATEGORIES) + 1 # +1 for ALL 39 | self.all_cats_size = self.token_hidden_size * self.num_cats 40 | 41 | self.coref_start_all_mlps = RepresentationLayer( 42 | type=self.representation_layer_type, # fullyconnected 43 | input_dim=self.token_hidden_size, 44 | output_dim=self.all_cats_size, 45 | hidden_dim=self.mention_hidden_size, 46 | ) 47 | 48 | self.coref_end_all_mlps = RepresentationLayer( 49 | type=self.representation_layer_type, # fullyconnected 50 | input_dim=self.token_hidden_size, 51 | output_dim=self.all_cats_size, 52 | hidden_dim=self.mention_hidden_size, 53 | ) 54 | 55 | self.antecedent_s2s_all_weights = nn.Parameter( 56 | torch.empty((self.num_cats, self.token_hidden_size, self.token_hidden_size)) 57 | ) 58 | self.antecedent_e2e_all_weights = nn.Parameter( 59 | torch.empty((self.num_cats, self.token_hidden_size, self.token_hidden_size)) 60 | ) 61 | self.antecedent_s2e_all_weights = nn.Parameter( 62 | torch.empty((self.num_cats, self.token_hidden_size, self.token_hidden_size)) 63 | ) 64 | self.antecedent_e2s_all_weights = nn.Parameter( 65 | torch.empty((self.num_cats, self.token_hidden_size, self.token_hidden_size)) 66 | ) 67 | 68 | self.antecedent_s2s_all_biases = nn.Parameter(torch.empty((self.num_cats, self.token_hidden_size))) 69 | self.antecedent_e2e_all_biases = nn.Parameter(torch.empty((self.num_cats, self.token_hidden_size))) 70 | self.antecedent_s2e_all_biases = nn.Parameter(torch.empty((self.num_cats, self.token_hidden_size))) 71 | self.antecedent_e2s_all_biases = nn.Parameter(torch.empty((self.num_cats, self.token_hidden_size))) 72 | 73 | # mention extraction layers 74 | # representation of start token 75 | self.start_token_representation = RepresentationLayer( 76 | type=self.representation_layer_type, 77 | input_dim=self.token_hidden_size, 78 | output_dim=self.token_hidden_size, 79 | hidden_dim=self.token_hidden_size, 80 | 
) 81 | 82 | # representation of end token 83 | self.end_token_representation = RepresentationLayer( 84 | type=self.representation_layer_type, 85 | input_dim=self.token_hidden_size, 86 | output_dim=self.token_hidden_size, 87 | hidden_dim=self.token_hidden_size, 88 | ) 89 | 90 | # models probability to be the start of a mention 91 | self.start_token_classifier = RepresentationLayer( 92 | type=self.representation_layer_type, 93 | input_dim=self.token_hidden_size, 94 | output_dim=1, 95 | hidden_dim=self.token_hidden_size, 96 | ) 97 | 98 | # model mention probability from start and end representations 99 | self.start_end_classifier = RepresentationLayer( 100 | type=self.representation_layer_type, 101 | input_dim=self.mention_hidden_size, 102 | output_dim=1, 103 | hidden_dim=self.token_hidden_size, 104 | ) 105 | self.reset_parameters() 106 | 107 | def reset_parameters(self) -> None: 108 | W = [ 109 | self.antecedent_s2s_all_weights, 110 | self.antecedent_e2e_all_weights, 111 | self.antecedent_s2e_all_weights, 112 | self.antecedent_e2s_all_weights, 113 | ] 114 | 115 | B = [ 116 | self.antecedent_s2s_all_biases, 117 | self.antecedent_e2e_all_biases, 118 | self.antecedent_s2e_all_biases, 119 | self.antecedent_e2s_all_biases, 120 | ] 121 | 122 | for w, b in zip(W, B): 123 | init.kaiming_uniform_(w, a=math.sqrt(5)) 124 | fan_in, _ = init._calculate_fan_in_and_fan_out(w) 125 | bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 126 | init.uniform_(b, -bound, bound) 127 | 128 | # takes last_hidden_states, eos_mask, ground truth and stage 129 | def eos_mention_extraction(self, lhs, eos_mask, gold_mentions, gold_starts, stage): 130 | start_idxs = [] 131 | mention_idxs = [] 132 | start_loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 133 | mention_loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 134 | 135 | for bidx in range(0, lhs.shape[0]): 136 | lhs_batch = lhs[bidx] # SEQ_LEN X HIDD_DIM 137 | eos_mask_batch = eos_mask[bidx] # SEQ_LEN X SEQ_LEN 138 | 139 | # compute start logits 140 | start_logits_batch = self.start_token_classifier(lhs_batch).squeeze(-1) # SEQ_LEN 141 | 142 | if gold_starts != None: 143 | loss = torch.nn.functional.binary_cross_entropy_with_logits(start_logits_batch, gold_starts[bidx]) 144 | 145 | # accumulate loss 146 | start_loss = start_loss + loss 147 | 148 | # compute start positions 149 | start_idxs_batch = ((torch.sigmoid(start_logits_batch) > 0.5)).nonzero(as_tuple=False).squeeze(-1) 150 | 151 | start_idxs.append(start_idxs_batch.detach().clone()) 152 | # in training, use gold starts to learn to extract mentions, inference use predicted ones 153 | if stage == "train": 154 | start_idxs_batch = ( 155 | ((torch.sigmoid(gold_starts[bidx]) > 0.5)).nonzero(as_tuple=False).squeeze(-1) 156 | ) # NUM_GOLD_STARTS 157 | 158 | # contains all possible start end indices pairs, i.e. 
for all starts, all possible ends looking at EOS index 159 | possibles_start_end_idxs = (eos_mask_batch[start_idxs_batch] == 1).nonzero(as_tuple=False) # STARTS x 2 160 | 161 | # this is to have reference respect to original positions 162 | possibles_start_end_idxs[:, 0] = start_idxs_batch[possibles_start_end_idxs[:, 0]] 163 | 164 | possible_start_idxs = possibles_start_end_idxs[:, 0] 165 | possible_end_idxs = possibles_start_end_idxs[:, 1] 166 | 167 | # extract start and end hidden states 168 | starts_hidden_states = lhs_batch[possible_end_idxs] # start 169 | ends_hidden_states = lhs_batch[possible_start_idxs] # end 170 | 171 | # concatenation of start to end representations created using a representation layer 172 | s2e_representations = torch.cat( 173 | ( 174 | self.start_token_representation(starts_hidden_states), 175 | self.end_token_representation(ends_hidden_states), 176 | ), 177 | dim=-1, 178 | ) 179 | 180 | # classification of mentions 181 | s2e_logits = self.start_end_classifier(s2e_representations).squeeze(-1) 182 | 183 | # mention_start_idxs and mention_end_idxs 184 | mention_idxs.append(possibles_start_end_idxs[torch.sigmoid(s2e_logits) > 0.5].detach().clone()) 185 | 186 | if s2e_logits.shape[0] != 0 and stage != "test": 187 | if gold_mentions != None: 188 | mention_loss_batch = torch.nn.functional.binary_cross_entropy_with_logits( 189 | s2e_logits, 190 | gold_mentions[bidx][possible_start_idxs, possible_end_idxs], 191 | ) 192 | mention_loss = mention_loss + mention_loss_batch 193 | 194 | return (start_idxs, mention_idxs, start_loss, mention_loss) 195 | 196 | def _get_cluster_labels_after_pruning(self, span_starts, span_ends, all_clusters): 197 | """ 198 | :param span_starts: [batch_size, max_k] 199 | :param span_ends: [batch_size, max_k] 200 | :param all_clusters: [batch_size, max_cluster_size, max_clusters_num, 2] 201 | :return: [batch_size, max_k, max_k + 1] - [b, i, j] == 1 if i is antecedent of j 202 | """ 203 | span_starts = span_starts.unsqueeze(0) 204 | span_ends = span_ends.unsqueeze(0) 205 | batch_size, max_k = span_starts.size() 206 | new_cluster_labels = torch.zeros((batch_size, max_k, max_k), device="cpu") 207 | all_clusters_cpu = all_clusters.cpu().numpy() 208 | for b, (starts, ends, gold_clusters) in enumerate( 209 | zip(span_starts.cpu().tolist(), span_ends.cpu().tolist(), all_clusters_cpu) 210 | ): 211 | gold_clusters = self.extract_clusters(gold_clusters) 212 | mention_to_gold_clusters = extract_mentions_to_clusters(gold_clusters) 213 | gold_mentions = set(mention_to_gold_clusters.keys()) 214 | for i, (start, end) in enumerate(zip(starts, ends)): 215 | if (start, end) not in gold_mentions: 216 | continue 217 | for j, (a_start, a_end) in enumerate(list(zip(starts, ends))[:i]): 218 | if (a_start, a_end) in mention_to_gold_clusters[(start, end)]: 219 | new_cluster_labels[b, i, j] = 1 220 | new_cluster_labels = new_cluster_labels.to(self.encoder.device) 221 | return new_cluster_labels 222 | 223 | def _get_all_labels(self, clusters_labels, categories_masks): 224 | 225 | categories_labels = clusters_labels.unsqueeze(1).repeat(1, self.num_cats, 1, 1) * categories_masks 226 | all_labels = categories_labels 227 | 228 | return all_labels 229 | 230 | def mes_span_clustering( 231 | self, mention_start_reps, mention_end_reps, mention_start_idxs, mention_end_idxs, gold, stage, mask, add, sing 232 | ): 233 | if mention_start_reps[0].shape[0] == 0: 234 | return torch.tensor([0.0], requires_grad=True, device=self.encoder.device), [] 235 | coref_logits = 
self._calc_coref_logits(mention_start_reps, mention_end_reps) 236 | coref_logits = coref_logits[0] * mask[0] 237 | coreference_loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 238 | 239 | coref_logits = torch.stack([matrix.tril().fill_diagonal_(0) for matrix in coref_logits]).unsqueeze(0) 240 | if stage == "train": 241 | labels = self._get_cluster_labels_after_pruning(mention_start_idxs, mention_end_idxs, gold) 242 | all_labels = self._get_all_labels(labels, mask) 243 | coreference_loss = torch.nn.functional.binary_cross_entropy_with_logits(coref_logits, all_labels) 244 | 245 | coref_logits = coref_logits.sum(dim=1) 246 | doc, m2a, singletons = self.create_mention_to_antecedent_singletons(mention_start_idxs, mention_end_idxs, coref_logits) 247 | if not sing: 248 | singletons = [] 249 | coreferences = self.create_clusters(m2a, singletons, add) 250 | return coreference_loss, coreferences 251 | 252 | def transpose_for_scores(self, x): 253 | new_x_shape = x.size()[:-1] + (self.num_cats, self.token_hidden_size) 254 | x = x.view(*new_x_shape) 255 | return x.permute(0, 2, 1, 3) # bnkf/bnlg 256 | 257 | def _calc_coref_logits(self, top_k_start_coref_reps, top_k_end_coref_reps): 258 | all_starts = self.transpose_for_scores(self.coref_start_all_mlps(top_k_start_coref_reps)) 259 | all_ends = self.transpose_for_scores(self.coref_end_all_mlps(top_k_end_coref_reps)) 260 | 261 | logits = ( 262 | torch.einsum("bnkf, nfg, bnlg -> bnkl", all_starts, self.antecedent_s2s_all_weights, all_starts) 263 | + torch.einsum("bnkf, nfg, bnlg -> bnkl", all_ends, self.antecedent_e2e_all_weights, all_ends) 264 | + torch.einsum("bnkf, nfg, bnlg -> bnkl", all_starts, self.antecedent_s2e_all_weights, all_ends) 265 | + torch.einsum("bnkf, nfg, bnlg -> bnkl", all_ends, self.antecedent_e2s_all_weights, all_starts) 266 | ) 267 | 268 | biases = ( 269 | torch.einsum("bnkf, nf -> bnk", all_starts, self.antecedent_s2s_all_biases).unsqueeze(-2) 270 | + torch.einsum("bnkf, nf -> bnk", all_ends, self.antecedent_e2e_all_biases).unsqueeze(-2) 271 | + torch.einsum("bnkf, nf -> bnk", all_ends, self.antecedent_s2e_all_biases).unsqueeze(-2) 272 | + torch.einsum("bnkf, nf -> bnk", all_starts, self.antecedent_e2s_all_biases).unsqueeze(-2) 273 | ) 274 | 275 | return logits + biases 276 | 277 | def _get_categories_labels(self, tokens, subtoken_map, new_token_map, span_starts, span_ends): 278 | max_k = span_starts.shape[0] 279 | 280 | doc_spans = [] 281 | for start, end in zip(span_starts, span_ends): 282 | token_indices = [new_token_map[0][idx] for idx in set(subtoken_map[0][start : end + 1]) - {None}] 283 | span = {tokens[0][idx].lower() for idx in token_indices if idx is not None} 284 | pronoun_id = get_pronoun_id(span) 285 | doc_spans.append((span - STOPWORDS, pronoun_id)) 286 | 287 | categories_labels = np.zeros((max_k, max_k)) - 1 288 | for i in range(max_k): 289 | for j in list(range(max_k))[:i]: 290 | categories_labels[i, j] = get_category_id(doc_spans[i], doc_spans[j]) 291 | 292 | categories_labels = torch.tensor(categories_labels, device=self.encoder.device).unsqueeze(0) 293 | categories_masks = [categories_labels == cat_id for cat_id in range(self.num_cats - 1)] + [categories_labels != -1] 294 | categories_masks = torch.stack(categories_masks, dim=1).int() 295 | return categories_labels, categories_masks 296 | 297 | def create_mention_to_antecedent_singletons(self, span_starts, span_ends, coref_logits): 298 | span_starts = span_starts.unsqueeze(0) 299 | span_ends = span_ends.unsqueeze(0) 300 | bs, n_spans, _ = 
coref_logits.shape 301 | # long distance regularization 302 | # a = torch.sigmoid(coref_logits) 303 | # m1 = torch.arange(coref_logits.shape[1], device=a.device).unsqueeze(0).repeat(coref_logits.shape[1], 1) 304 | # m2 = m1.transpose(0, 1) 305 | # m = (torch.ones_like(coref_logits[0], device=a.device) + (m2 - m1 - 1) / 1000).tril().fill_diagonal_(0).unsqueeze(0) 306 | # no_ant = 1 - torch.sum(a - a * m > 0.5, dim=-1).bool().float() 307 | # coref_logits = torch.cat((coref_logits - coref_logits * m, no_ant.unsqueeze(-1)), dim=-1) 308 | no_ant = 1 - torch.sum(torch.sigmoid(coref_logits) > 0.5, dim=-1).bool().float() 309 | # [batch_size, max_k, max_k + 1] 310 | coref_logits = torch.cat((coref_logits, no_ant.unsqueeze(-1)), dim=-1) 311 | 312 | span_starts = span_starts.detach().cpu() 313 | span_ends = span_ends.detach().cpu() 314 | max_antecedents = coref_logits.argmax(axis=-1).detach().cpu() 315 | doc_indices = np.nonzero(max_antecedents < n_spans)[:, 0] 316 | # indices where antecedent is not null. 317 | mention_indices = np.nonzero(max_antecedents < n_spans)[:, 1] 318 | 319 | antecedent_indices = max_antecedents[max_antecedents < n_spans] 320 | span_indices = np.stack([span_starts.detach().cpu(), span_ends.detach().cpu()], axis=-1) 321 | 322 | mentions = span_indices[doc_indices, mention_indices] 323 | antecedents = span_indices[doc_indices, antecedent_indices] 324 | non_mentions = np.nonzero(max_antecedents == n_spans)[:, 1] 325 | 326 | sing_indices = np.zeros_like(len(np.setdiff1d(non_mentions, antecedent_indices))) 327 | singletons = span_indices[sing_indices, np.setdiff1d(non_mentions, antecedent_indices)] 328 | 329 | # mention_to_antecedent = np.stack([mentions, antecedents], axis=1) 330 | 331 | if len(mentions.shape) == 1 and len(antecedents.shape) == 1: 332 | mention_to_antecedent = np.stack([mentions, antecedents], axis=0) 333 | else: 334 | mention_to_antecedent = np.stack([mentions, antecedents], axis=1) 335 | 336 | if len(mentions.shape) == 1: 337 | mention_to_antecedent = [mention_to_antecedent] 338 | 339 | if len(singletons.shape) == 1: 340 | singletons = [singletons] 341 | 342 | return doc_indices, mention_to_antecedent, singletons 343 | 344 | def create_clusters(self, m2a, singletons, add): 345 | # Note: mention_to_antecedent is a numpy array 346 | if add != None: 347 | clusters = add 348 | mention_to_cluster = {m: i for i, c in enumerate(clusters) for m in c} 349 | else: 350 | clusters, mention_to_cluster = [], {} 351 | for mention, antecedent in m2a: 352 | mention, antecedent = tuple(mention), tuple(antecedent) 353 | if antecedent in mention_to_cluster: 354 | cluster_idx = mention_to_cluster[antecedent] 355 | if mention not in clusters[cluster_idx]: 356 | clusters[cluster_idx].append(mention) 357 | mention_to_cluster[mention] = cluster_idx 358 | elif mention in mention_to_cluster: 359 | cluster_idx = mention_to_cluster[mention] 360 | if antecedent not in clusters[cluster_idx]: 361 | clusters[cluster_idx].append(antecedent) 362 | mention_to_cluster[antecedent] = cluster_idx 363 | else: 364 | cluster_idx = len(clusters) 365 | mention_to_cluster[mention] = cluster_idx 366 | mention_to_cluster[antecedent] = cluster_idx 367 | clusters.append([antecedent, mention]) 368 | 369 | clusters = [tuple(cluster) for cluster in clusters] 370 | # maybe order stuff? 
371 | if len(singletons) != 0: 372 | clust = [] 373 | while len(clusters) != 0 or len(singletons) != 0: 374 | if len(singletons) == 0: 375 | clust.append(clusters[0]) 376 | clusters = clusters[1:] 377 | elif len(clusters) == 0: 378 | clust.append(tuple([tuple(singletons[0])])) 379 | singletons = singletons[1:] 380 | elif singletons[0][0] < sorted(clusters[0], key=lambda x: x[0])[0][0]: 381 | clust.append(tuple([tuple(singletons[0])])) 382 | singletons = singletons[1:] 383 | else: 384 | clust.append(clusters[0]) 385 | clusters = clusters[1:] 386 | return clust 387 | return clusters 388 | 389 | def extract_clusters(self, gold_clusters): 390 | gold_clusters = [tuple(tuple(m) for m in cluster if (-1) not in m) for cluster in gold_clusters] 391 | gold_clusters = [cluster for cluster in gold_clusters if len(cluster) > 0] 392 | return gold_clusters 393 | 394 | def forward( 395 | self, 396 | stage, 397 | input_ids, 398 | attention_mask, 399 | eos_mask, 400 | gold_starts=None, 401 | gold_mentions=None, 402 | gold_clusters=None, 403 | tokens=None, 404 | subtoken_map=None, 405 | new_token_map=None, 406 | add=None, 407 | singletons=False, 408 | longdoc=None, 409 | ): 410 | 411 | loss = torch.tensor([0.0], requires_grad=True, device=self.encoder.device) 412 | loss_dict = {} 413 | preds = {} 414 | if longdoc == None: 415 | last_hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[ 416 | "last_hidden_state" 417 | ] # B x S x TH 418 | 419 | lhs = last_hidden_states 420 | 421 | ( 422 | start_idxs, 423 | mention_idxs, 424 | start_loss, 425 | mention_loss, 426 | ) = self.eos_mention_extraction( 427 | lhs=last_hidden_states, 428 | eos_mask=eos_mask, 429 | gold_mentions=gold_mentions, 430 | gold_starts=gold_starts, 431 | stage=stage, 432 | ) 433 | 434 | loss_dict["start_loss"] = start_loss 435 | preds["start_idxs"] = [start.detach().cpu() for start in start_idxs] 436 | 437 | loss_dict["mention_loss"] = mention_loss 438 | preds["mention_idxs"] = [mention.detach().cpu() for mention in mention_idxs] 439 | 440 | loss = loss + start_loss + mention_loss 441 | 442 | if stage == "train": 443 | mention_idxs = (gold_mentions[0] == 1).nonzero(as_tuple=False) 444 | elif stage == "test" and gold_mentions != None: 445 | mention_idxs = (gold_mentions[0] == 1).nonzero(as_tuple=False) 446 | else: 447 | mention_idxs = mention_idxs[0] 448 | mention_start_idxs = mention_idxs[:, 0] 449 | mention_end_idxs = mention_idxs[:, 1] 450 | mentions_start_hidden_states = torch.index_select(lhs, 1, mention_start_idxs) 451 | mentions_end_hidden_states = torch.index_select(lhs, 1, mention_end_idxs) 452 | 453 | elif longdoc == "naive": 454 | mention_idxs, mentions_start_hidden_states, mentions_end_hidden_states = self.naive_long_doc_forward( 455 | stage, input_ids, attention_mask, eos_mask, gold_starts, gold_mentions 456 | ) 457 | mention_idxs = mention_idxs[0] 458 | elif longdoc == "proper": 459 | pass 460 | else: 461 | raise Exception("unsupported") 462 | 463 | mention_start_idxs = mention_idxs[:, 0] 464 | mention_end_idxs = mention_idxs[:, 1] 465 | 466 | _, categories_masks = self._get_categories_labels( 467 | tokens, subtoken_map, new_token_map, mention_start_idxs, mention_end_idxs 468 | ) 469 | 470 | coreference_loss, coreferences = self.mes_span_clustering( 471 | mentions_start_hidden_states, 472 | mentions_end_hidden_states, 473 | mention_start_idxs, 474 | mention_end_idxs, 475 | gold_clusters, 476 | stage, 477 | categories_masks, 478 | add, 479 | singletons, 480 | ) 481 | 482 | loss = loss + coreference_loss 
483 | loss_dict["coreference_loss"] = coreference_loss 484 | 485 | if stage != "train": 486 | preds["clusters"] = coreferences 487 | 488 | loss_dict["full_loss"] = loss 489 | output = {"pred_dict": preds, "loss_dict": loss_dict, "loss": loss} 490 | 491 | return output 492 | 493 | def naive_long_doc_forward( 494 | self, 495 | stage, 496 | input_ids, 497 | attention_mask, 498 | eos_mask, 499 | gold_starts=None, 500 | gold_mentions=None, 501 | max_seq_len=2000, 502 | ): 503 | length = input_ids.shape[1] 504 | slices_seq_index = [ 505 | (eos_mask[0][step] == 1).nonzero(as_tuple=False)[-1][0].item() for step in range(max_seq_len, length, max_seq_len) 506 | ] 507 | star = (eos_mask[0][0] == 1).nonzero(as_tuple=False)[-1][0].item() 508 | if len(slices_seq_index) == 0 or slices_seq_index[-1] != length: 509 | slices_seq_index.append(length) 510 | prev = 0 511 | start_idxs = [] 512 | mention_idxs = [] 513 | index_start_hidden_states = [] 514 | index_end_hidden_states = [] 515 | for index in slices_seq_index: 516 | if prev != 0: 517 | index_input_ids = torch.cat((input_ids[0][0:star], input_ids[0][prev:index]), dim=0).unsqueeze(0) 518 | index_attention_mask = torch.cat((attention_mask[0][0:star], attention_mask[0][prev:index]), dim=0).unsqueeze(0) 519 | index_eos_mask = torch.cat((eos_mask[0][0:star], eos_mask[0][prev:index]), dim=0).unsqueeze(0) 520 | index_eos_mask = torch.cat((index_eos_mask[0][:, 0:star], index_eos_mask[0][:, prev:index]), dim=1).unsqueeze(0) 521 | if gold_mentions != None: 522 | index_gold_mentions = torch.cat((gold_mentions[0][0:star], gold_mentions[0][prev:index]), dim=0).unsqueeze(0) 523 | index_gold_mentions = torch.cat( 524 | (index_gold_mentions[0][:, 0:star], index_gold_mentions[0][:, prev:index]), dim=1 525 | ).unsqueeze(0) 526 | index_gold_starts = torch.cat((gold_starts[0][0:star], input_ids[0][prev:index]), dim=0).unsqueeze(0) 527 | else: 528 | index_gold_mentions = gold_mentions 529 | index_gold_starts = gold_starts 530 | else: 531 | index_input_ids = input_ids[0][prev:index].unsqueeze(0) 532 | index_attention_mask = attention_mask[0][prev:index].unsqueeze(0) 533 | index_eos_mask = eos_mask[0][prev:index].unsqueeze(0) 534 | index_eos_mask = index_eos_mask[0][:, prev:index].unsqueeze(0) 535 | if gold_mentions != None: 536 | index_gold_mentions = gold_mentions[0][prev:index].unsqueeze(0) 537 | index_gold_mentions = index_gold_mentions[0][:, prev:index].unsqueeze(0) 538 | index_gold_starts = gold_starts[0][prev:index].unsqueeze(0) 539 | else: 540 | index_gold_mentions = gold_mentions 541 | index_gold_starts = gold_starts 542 | 543 | index_last_hidden_states = self.encoder(input_ids=index_input_ids, attention_mask=index_attention_mask)[ 544 | "last_hidden_state" 545 | ] # B x S x TH 546 | 547 | ( 548 | index_start_idxs, 549 | index_mention_idxs, 550 | _, 551 | _, 552 | ) = self.eos_mention_extraction( 553 | lhs=index_last_hidden_states, 554 | eos_mask=index_eos_mask, 555 | gold_mentions=index_gold_mentions, 556 | gold_starts=index_gold_starts, 557 | stage=stage, 558 | ) 559 | 560 | if index_gold_mentions != None: 561 | index_mention_idxs = [(index_gold_mentions[0] == 1).nonzero(as_tuple=False)] 562 | if prev != 0: 563 | start_idxs.append(torch.add(index_start_idxs[0][index_start_idxs[0] > star], prev - star)) 564 | mention_idxs.append(torch.add(index_mention_idxs[0][index_mention_idxs[0][:, 0] > star], prev - star)) 565 | index_start_hidden_states.append( 566 | torch.index_select( 567 | index_last_hidden_states, 1, index_mention_idxs[0][index_mention_idxs[0][:, 0] > 
star][:, 0] 568 | ) 569 | ) 570 | index_end_hidden_states.append( 571 | torch.index_select( 572 | index_last_hidden_states, 1, index_mention_idxs[0][index_mention_idxs[0][:, 0] > star][:, 1] 573 | ) 574 | ) 575 | 576 | else: 577 | start_idxs.append(index_start_idxs[0]) 578 | idxs = index_mention_idxs[0] 579 | mention_idxs.append(idxs) 580 | index_start_hidden_states.append(torch.index_select(index_last_hidden_states, 1, idxs[:, 0])) 581 | index_end_hidden_states.append(torch.index_select(index_last_hidden_states, 1, idxs[:, 1])) 582 | 583 | prev = index 584 | 585 | mentions_start_hidden_states = torch.cat(index_start_hidden_states, dim=1) 586 | mentions_end_hidden_states = torch.cat(index_end_hidden_states, dim=1) 587 | mention_idxs = [torch.cat(mention_idxs, dim=0)] 588 | return mention_idxs, mentions_start_hidden_states, mentions_end_hidden_states 589 | --------------------------------------------------------------------------------
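
The file above implements the mention-expert-style (mes) coreference head. Its core scoring step, `_calc_coref_logits`, projects each mention's start and end token states into one representation per category and combines them with per-category bilinear weights and biases. Below is a minimal, standalone sketch of that einsum pattern with toy dimensions (only the start-to-start term is shown); it is illustrative only and not the repository code.

import torch

# toy sizes: 1 document, 4 candidate mentions, hidden size 8, 3 categories
batch, k, hidden, num_cats = 1, 4, 8, 3

# per-category start representations, shaped [batch, num_cats, k, hidden]
starts = torch.randn(batch, num_cats, k, hidden)

# one bilinear weight matrix and one bias vector per category (s2s term only)
w_s2s = torch.randn(num_cats, hidden, hidden)
b_s2s = torch.randn(num_cats, hidden)

# pairwise scores: logits[b, n, i, j] = starts[b, n, i] @ w_s2s[n] @ starts[b, n, j]
logits = torch.einsum("bnkf, nfg, bnlg -> bnkl", starts, w_s2s, starts)

# bias contribution, broadcast over the antecedent axis
logits = logits + torch.einsum("bnkf, nf -> bnk", starts, b_s2s).unsqueeze(-2)

print(logits.shape)  # torch.Size([1, 3, 4, 4]) -> [batch, num_cats, mention, antecedent]

In the full model the analogous end-to-end, start-to-end, and end-to-start terms are added on top, the per-category scores are masked by `categories_masks`, and the category axis is summed before antecedent selection.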
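
For completeness, a hypothetical instantiation of `Maverick_mes`, based only on the keyword arguments read in its `__init__` (`huggingface_model_name`, `freeze_encoder`, `span_representation`). The checkpoint name is just an example, and the sketch assumes the `maverick` package and its dependencies are installed.

from maverick.models.model_mes import Maverick_mes

model = Maverick_mes(
    huggingface_model_name="microsoft/deberta-v3-base",  # example encoder; any compatible HF checkpoint
    freeze_encoder=False,
    span_representation="concat_start_end",
)
print(sum(p.numel() for p in model.parameters()))

At inference time the `forward` method additionally expects the tokenized document (`input_ids`, `attention_mask`), an `eos_mask` used to restrict candidate span ends, and the token maps produced by the data pipeline, so this sketch only covers model construction.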