├── expred ├── __init__.py ├── models │ ├── __init__.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── mtl_evidence_classifier.py │ │ ├── mtl_token_identifier.py │ │ └── mtl_pipeline_utils.py │ ├── losses.py │ ├── model_utils.py │ └── mlp_mtl.py ├── params.py ├── eraser_utils.py ├── rationale_tokenization.py ├── bert_rational_feature.py ├── preprocessing.py ├── tokenizer.py ├── eraser_benchmark.py ├── utils.py └── train.py ├── test_bogus.py ├── requirements.txt ├── .idea ├── vcs.xml ├── other.xml ├── misc.xml ├── interpretation_by_design.iml ├── modules.xml ├── deployment.xml ├── runConfigurations │ └── mtl_pipeline_movies_regression_test.xml └── inspectionProfiles │ └── Project_Default.xml ├── scripts ├── run_sweep_wrapper.sh ├── run_fever_final.sh ├── run_fever_gru_final.sh ├── run_fever_rnr_final.sh ├── run_movies_rnr_final.sh ├── run_multirc_gru_final.sh ├── run_fever_b.sh ├── run_movies_gru_final.sh ├── run_multirc_rnr_final_a.sh ├── run_multirc_rnr_final_b.sh ├── run_fever_a.sh ├── run_multirc_rnr_sweep_b.sh ├── run_movies_rnr_sweep.sh ├── run_multirc_rnr_sweep.sh ├── run_multirc_rnr_sweep_a.sh ├── run_probe.sh ├── run_fever_firsts.sh ├── run_multirc_firsts.sh ├── run_gru_firsts.sh ├── run_fever_gru_sweep.sh ├── run_multirc_gru_sweep.sh ├── run_fever_rnr_sweep.sh ├── run_rnr_firsts.sh ├── run_sweep.sh ├── run_b.sh ├── run_a.sh ├── run_multirc_rnr_sweep_a_and_freeze.sh ├── run_multirc_rnr_sweep_c.sh └── run_multirc_rnr_sweep_node05.sh ├── test ├── test_train.py └── test_params │ └── movies_expred.json ├── params ├── movies_expred.json ├── multirc_expred.json └── fever_expred.json ├── README.md ├── .github └── workflows │ └── ci.yml ├── .gitignore └── sample_rationales_dataset.py /expred/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /expred/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /expred/models/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /test_bogus.py: -------------------------------------------------------------------------------- 1 | def test_test(): 2 | assert True 3 | -------------------------------------------------------------------------------- /expred/params.py: -------------------------------------------------------------------------------- 1 | class MTLParams(): 2 | dim_cls_linear: int 3 | num_labels: int 4 | dim_exp_gru: int -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==2.5.1 2 | torch==1.3.0 3 | torchvision==0.4.2 4 | gensim==3.7.1 5 | scikit-learn==0.20.3 6 | wandb==0.10.24 -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /scripts/run_sweep_wrapper.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dataset=$1 4 | structure=$2 5 | portion=$3 6 | . 
.venv/bin/activate 7 | for i in 0 1 2 3; do 8 | ./run_sweep.sh $i ${dataset} ${structure} ${portion}& 9 | done 10 | deactivate 11 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | 8 | 10 | -------------------------------------------------------------------------------- /test/test_train.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from expred.train import main 3 | 4 | class TestTrain(unittest.TestCase): 5 | def test_runs_for_one_epoch(self): 6 | # todo fix path / data problem 7 | args = [ 8 | "--data_dir", "/home/mreimer/datasets/eraser/movies_debug", 9 | "--output_dir", "output/", 10 | "--conf", "test_params/movies_expred.json", 11 | "--batch_size", "4"] 12 | main(args) 13 | 14 | 15 | if __name__ == '__main__': 16 | unittest.main() 17 | -------------------------------------------------------------------------------- /scripts/run_fever_final.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 2 ) 4 | gpu_id=0,1 5 | batch_size=16 6 | num_epochs=10 7 | dataset='fever' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_fever_gru_final.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 5. 
) 4 | gpu_id=2 5 | batch_size=4 6 | num_epochs=10 7 | dataset='fever' 8 | exp_structure='gru' 9 | benchmark_split='test' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_fever_rnr_final.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.2 ) 4 | gpu_id=1 5 | batch_size=16 6 | num_epochs=10 7 | dataset='fever' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_movies_rnr_final.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.1 ) 4 | gpu_id=0 5 | batch_size=4 6 | num_epochs=10 7 | dataset='movies' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_multirc_gru_final.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 2. 
) 4 | gpu_id=3 5 | batch_size=4 6 | num_epochs=10 7 | dataset='multirc' 8 | exp_structure='gru' 9 | benchmark_split='test' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_fever_b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.1 0.2 0.5 1 2 5 ) 4 | gpu_id=1 5 | batch_size=16 6 | num_epochs=10 7 | dataset='fever' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | train_on_portion='0.1' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_movies_gru_final.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 100. 5. ) 4 | gpu_id=1 5 | batch_size=4 6 | num_epochs=10 7 | dataset='movies' 8 | exp_structure='gru' 9 | benchmark_split='test' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_multirc_rnr_final_a.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.05 ) 4 | gpu_id=0 5 | batch_size=4 6 | num_epochs=10 7 | dataset='multirc' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_multirc_rnr_final_b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.1 ) 4 | gpu_id=0 5 | batch_size=4 6 | num_epochs=10 7 | dataset='multirc' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark 
--exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_fever_a.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.001 0.002 0.005 0.02 0.03 0.05 ) 4 | gpu_id=0 5 | batch_size=16 6 | num_epochs=10 7 | dataset='fever' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | train_on_portion='0.1' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_multirc_rnr_sweep_b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.1 0.2 0.5 1. 2. ) 4 | gpu_id=1 5 | batch_size=16 6 | num_epochs=10 7 | dataset='multirc' 8 | exp_structure='rnr' 9 | benchmark_split='val' 10 | train_on_portion='0.4' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /.idea/interpretation_by_design.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 13 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /scripts/run_movies_rnr_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 0.1 0.2 0.5 1. 2. 5. ) 4 | gpu_id=0 5 | batch_size=16 6 | num_epochs=10 7 | dataset='movies' 8 | exp_structure='rnr' 9 | benchmark_split='val' 10 | train_on_portion='0' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /scripts/run_multirc_rnr_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 0.1 0.2 0.5 1. 2. 5. 
) 4 | gpu_id=0 5 | batch_size=16 6 | num_epochs=10 7 | dataset='multirc' 8 | exp_structure='rnr' 9 | benchmark_split='val' 10 | train_on_portion='0.4' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | -------------------------------------------------------------------------------- /.idea/deployment.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 18 | -------------------------------------------------------------------------------- /scripts/run_multirc_rnr_sweep_a.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 ) 4 | #0.1, 0.2, 0.5, 1., 2., 5. ) 5 | gpu_id=0 6 | batch_size=16 7 | num_epochs=10 8 | dataset='multirc' 9 | exp_structure='rnr' 10 | benchmark_split='val' 11 | train_on_portion='0.4' 12 | 13 | for par_lambda in ${lambdas[@]}; do 14 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 15 | done 16 | -------------------------------------------------------------------------------- /scripts/run_probe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 1. ) 4 | gpu_id=0 5 | train_first=( exp cls ) 6 | batch_size=16 7 | num_epochs=2 8 | dataset='movies' 9 | exp_structure='rnr' 10 | benchmark_split='test' 11 | train_on_portion='0.1' 12 | 13 | for phase in ${train_first[@]}; do 14 | for par_lambda in ${lambdas[@]}; do 15 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion} --train_${phase}_first; 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /scripts/run_fever_firsts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 1. 
) 4 | gpu_id=0 5 | train_first=( exp cls ) 6 | batch_size=16 7 | num_epochs=2 8 | dataset='movies' 9 | exp_structure='rnr' 10 | benchmark_split='test' 11 | train_on_portion='0.1' 12 | 13 | for phase in ${train_first[@]}; do 14 | for par_lambda in ${lambdas[@]}; do 15 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion} --train_${phase}_first; 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /scripts/run_multirc_firsts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 1. ) 4 | gpu_id=0 5 | train_first=( exp cls ) 6 | batch_size=16 7 | num_epochs=2 8 | dataset='movies' 9 | exp_structure='rnr' 10 | benchmark_split='test' 11 | train_on_portion='0.1' 12 | 13 | for phase in ${train_first[@]}; do 14 | for par_lambda in ${lambdas[@]}; do 15 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion} --train_${phase}_first; 16 | done 17 | done 18 | -------------------------------------------------------------------------------- /scripts/run_gru_firsts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | par_lambda=1. 4 | #gpu_id=0 5 | #train_first=( exp cls ) 6 | gpu_id=$1 7 | phase=$2 8 | #batch_size=16 9 | batch_size=$3 10 | num_epochs=10 11 | datasets=( fever multirc ) 12 | exp_structure='gru' 13 | benchmark_split='test' 14 | train_on_portion='0' 15 | 16 | for dataset in ${datasets[@]}; do 17 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --load_phase1 --merge_evidences --benchmark_split ${benchmark_split} --do_train --start_from_phase1 --train_on_portion ${train_on_portion} --train_${phase}_first; 18 | done 19 | -------------------------------------------------------------------------------- /scripts/run_fever_gru_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=$1 4 | l0='0.1,0.2,0.5' 5 | l1='1,2,5' 6 | l2='10,20,50' 7 | l3='100,200,500' 8 | lambdas=( $l0 $l1 $l2 $l3 ) 9 | IFS=',' read -r -a lambdas<<<${lambdas[$gpu_id]} 10 | batch_size=5 11 | num_epochs=10 12 | dataset='movies' 13 | exp_structure='gru' 14 | benchmark_split='val' 15 | train_on_portion='0' 16 | 17 | for par_lambda in ${lambdas[@]}; do 18 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 19 | done 20 | -------------------------------------------------------------------------------- /scripts/run_multirc_gru_sweep.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=$1 4 | l0='0.1,0.2,0.5' 5 | l1='1,2,5' 6 | l2='10,20,50' 7 | l3='100,200,500' 8 | lambdas=( $l0 $l1 $l2 $l3 ) 9 | IFS=',' read -r -a lambdas<<<${lambdas[$gpu_id]} 10 | batch_size=5 11 | num_epochs=10 12 | dataset='movies' 13 | exp_structure='gru' 14 | benchmark_split='val' 15 | train_on_portion='0' 16 | 17 | for par_lambda in ${lambdas[@]}; do 18 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 19 | done 20 | -------------------------------------------------------------------------------- /scripts/run_fever_rnr_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=$1 4 | l0='0.001,0.002,0.005' 5 | l1='0.01,0.02,0.05' 6 | l2='0.1,0.2,0.5' 7 | l3='1.,2.,5.' 8 | lambdas=( $l0 $l1 $l2 $l3 ) 9 | IFS=',' read -r -a lambdas<<<${lambdas[$gpu_id]} 10 | batch_size=5 11 | num_epochs=10 12 | dataset='fever' 13 | exp_structure='rnr' 14 | benchmark_split='val' 15 | train_on_portion='0.1' 16 | 17 | for par_lambda in ${lambdas[@]}; do 18 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 19 | done 20 | -------------------------------------------------------------------------------- /scripts/run_rnr_firsts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | par_lambda=1. 
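# run_rnr_firsts.sh: with par_lambda fixed to 1 and the rnr explanation structure, each dataset listed below is trained with the sub-task named in train_first (currently only 'cls') trained first, starting from phase 1 (--start_from_phase1 --train_${phase}_first).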
4 | gpu_id=1 5 | #train_first=( exp cls ) 6 | train_first=( cls ) 7 | batch_size=100 8 | num_epochs=10 9 | datasets=( movies fever multirc ) 10 | #datasets=( movies ) 11 | exp_structure='rnr' 12 | benchmark_split='test' 13 | train_on_portion='0' 14 | 15 | for phase in ${train_first[@]}; do 16 | for dataset in ${datasets[@]}; do 17 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --start_from_phase1 --train_on_portion ${train_on_portion} --train_${phase}_first; 18 | done 19 | done 20 | -------------------------------------------------------------------------------- /scripts/run_sweep.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_id=$1 4 | dataset=$2 5 | exp_structure=$3 6 | train_on_portion=$4 7 | if [[ $exp_structure == 'gru' ]]; then 8 | l0='0.1,0.2,0.5' 9 | l1='1,2,5' 10 | l2='10,20,50' 11 | l3='100,200,500' 12 | else 13 | l0='0.001,0.002,0.005' 14 | l1='0.01,0.02,0.05' 15 | l2='0.1,0.2,0.5' 16 | l3='1.,2.,5' 17 | fi 18 | lambdas=( $l0 $l1 $l2 $l3 ) 19 | IFS=',' read -r -a lambdas<<<${lambdas[$gpu_id]} 20 | batch_size=4 21 | num_epochs=10 22 | benchmark_split='val' 23 | 24 | for par_lambda in ${lambdas[@]}; do 25 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 26 | done 27 | -------------------------------------------------------------------------------- /params/movies_expred.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "dim_cls_linear": 256, 4 | "dim_exp_gru": 128, 5 | "bert_vocab": "bert-base-uncased", 6 | "bert_dir": "bert-base-uncased", 7 | "exp_structure": "gru", 8 | "rebalance_approach": "resampling", 9 | "merge_evidences": 1, 10 | "classes": [ 11 | "NEG", 12 | "POS" 13 | ], 14 | "mtl_token_identifier": { 15 | "par_lambda": 5.0, 16 | "batch_size": 16, 17 | "epochs": 1, 18 | "patience": 3, 19 | "warmup_steps": 50, 20 | "lr": 1e-5, 21 | "use_half_precision": 0, 22 | "sampling_method": "whole_document" 23 | }, 24 | "evidence_classifier": { 25 | "batch_size": 16, 26 | "warmup_steps": 50, 27 | "epochs": 1, 28 | "patience": 3, 29 | "lr": 1e-5, 30 | "use_half_precision": 0 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /params/multirc_expred.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "dim_cls_linear": 256, 4 | "dim_exp_gru": 128, 5 | "bert_vocab": "bert-base-uncased", 6 | "bert_dir": "bert-base-uncased", 7 | "exp_structure": "gru", 8 | "rebalance_approach": "resampling", 9 | "merge_evidences": 1, 10 | "classes": [ 11 | "True", 12 | "False" 13 | ], 14 | "mtl_token_identifier": { 15 | "par_lambda": 20.0, 16 | "batch_size": 16, 17 | "epochs": 10, 18 | "patience": 3, 19 | "warmup_steps": 50, 20 | "lr": 1e-5, 21 | "use_half_precision": 0, 22 | "sampling_method": "whole_document" 23 | }, 24 | "evidence_classifier": { 25 | "batch_size": 16, 26 | "warmup_steps": 50, 27 | "epochs": 10, 28 | "patience": 3, 29 | 
"lr": 1e-5, 30 | "use_half_precision": 0 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /params/fever_expred.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "dim_cls_linear": 256, 4 | "dim_exp_gru": 128, 5 | "bert_vocab": "bert-base-uncased", 6 | "bert_dir": "bert-base-uncased", 7 | "exp_structure": "gru", 8 | "rebalance_approach": "resampling", 9 | "merge_evidences": 1, 10 | "classes": [ 11 | "SUPPORTS", 12 | "REFUTES" 13 | ], 14 | "mtl_token_identifier": { 15 | "par_lambda": 2.0, 16 | "batch_size": 16, 17 | "epochs": 10, 18 | "patience": 3, 19 | "warmup_steps": 50, 20 | "lr": 1e-5, 21 | "use_half_precision": 0, 22 | "sampling_method": "whole_document" 23 | }, 24 | "evidence_classifier": { 25 | "batch_size": 16, 26 | "warmup_steps": 50, 27 | "epochs": 10, 28 | "patience": 3, 29 | "lr": 1e-5, 30 | "use_half_precision": 0 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /test/test_params/movies_expred.json: -------------------------------------------------------------------------------- 1 | { 2 | "max_length": 512, 3 | "dim_cls_linear": 256, 4 | "dim_exp_gru": 128, 5 | "bert_vocab": "bert-base-uncased", 6 | "bert_dir": "bert-base-uncased", 7 | "exp_structure": "gru", 8 | "rebalance_approach": "resampling", 9 | "merge_evidences": 1, 10 | "classes": [ 11 | "NEG", 12 | "POS" 13 | ], 14 | "mtl_token_identifier": { 15 | "par_lambda": 5.0, 16 | "batch_size": 16, 17 | "epochs": 1, 18 | "patience": 3, 19 | "warmup_steps": 50, 20 | "lr": 1e-5, 21 | "use_half_precision": 0, 22 | "sampling_method": "whole_document" 23 | }, 24 | "evidence_classifier": { 25 | "batch_size": 16, 26 | "warmup_steps": 50, 27 | "epochs": 1, 28 | "patience": 3, 29 | "lr": 1e-5, 30 | "use_half_precision": 0 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ExPred 2 | 3 | This is the implementation of the paper [Explain and Predict, and then Predict Again](https://dl.acm.org/doi/abs/10.1145/3437963.3441758) (accepted in WSDM2021). This code is implemented based on the pipeline model of the [Eraserbenchmark](http://www.eraserbenchmark.com/). All data used by the model can be found from the Eraser Benchmark, too. 4 | 5 | ## Usage: 6 | 1. Install the required packages from the ```requirements.txt``` by ```pip install -r requirements.txt``` 7 | 2. The implementation entry is under ```expred/train```. To run the training, simply copy and paste the following commands: 8 | ``` export PYTHONPATH=$PYTHONPATH:./ && python expred/train.py --data_dir /dir/to/your/datasets/{movies,fever,multirc} --output_dir /dir/to/your/trained_data --conf ./params/{movies,fever,multirc}_expred.json``` 9 | 10 | Not that depending on your hardware you may have to change the `batch_size` in the config file. -------------------------------------------------------------------------------- /scripts/run_b.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 10. 20. 50. 
) 4 | gpu_id=1 5 | batch_size=16 6 | num_epochs=10 7 | dataset='movies' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | 11 | lambdas=( 2 ) 12 | gpu_id=1 13 | dataset='fever' 14 | 15 | for par_lambda in ${lambdas[@]}; do 16 |   python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints; 17 | done 18 | 19 | dataset='multirc' 20 | 21 | for par_lambda in ${lambdas[@]}; do 22 |   python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints; 23 | done 24 | -------------------------------------------------------------------------------- /expred/eraser_utils.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | 3 | from expred.models.pipeline.pipeline_utils import SentenceEvidence 4 | 5 | def get_docids(ann): 6 |     ret = [] 7 |     for ev_group in ann.evidences: 8 |         for ev in ev_group: 9 |             ret.append(ev.docid) 10 |     return ret 11 | 12 | def extract_doc_ids_from_annotations(anns): 13 |     ret = set() 14 |     for ann in anns: 15 |         ret |= set(get_docids(ann)) 16 |     return ret 17 | 18 | def chain_sentence_evidences(sentences): 19 |     kls = list(chain.from_iterable(s.kls for s in sentences)) 20 |     document = list(chain.from_iterable(s.sentence for s in sentences)) 21 |     assert len(kls) == len(document) 22 |     return SentenceEvidence(kls=kls, 23 |                             ann_id=sentences[0].ann_id, 24 |                             sentence=document, 25 |                             docid=sentences[0].docid, 26 |                             index=sentences[0].index, 27 |                             query=sentences[0].query, 28 |                             has_evidence=any(map(lambda s: s.has_evidence, sentences))) -------------------------------------------------------------------------------- /expred/models/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | 5 | 6 | def resampling_rebalanced_crossentropy(seq_reduction: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: 7 |     """ 8 |     Returns the loss function with the given seq_reduction strategy. 9 | 10 |     The loss of token $i$ in sequence $S$ is its binary cross-entropy weighted by the inverse frequency of its target label, i.e. $(|S| / |S_{t^i}|) \cdot BCE(p^i, t^i)$, where 11 |     * $p^i$ and $t^i$ are the predicted and target labels of token $i$ respectively 12 |     * $|S|$ is the length of $S$ and $|S_{t^i}|$ is the number of tokens in $S$ whose target label equals $t^i$ 13 |     :param seq_reduction: either 'none' or 'mean'.
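    (Worked example of the weighting implemented below: in a 10-token sequence with 2 rationale tokens, each rationale token is weighted by 10/2 = 5 and each non-rationale token by 10/8 = 1.25, so both classes carry the same total weight before reduction.)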
14 | :return: 15 | """ 16 | def loss(y_pred, y_true): 17 | prior_pos = torch.mean(y_true, dim=-1, keepdims=True) # percentage of positive tokens (rational) 18 | prior_neg = torch.mean(1 - y_true, dim=-1, keepdim=True) # vice versa 19 | eps = 1e-10 20 | weight = y_true / (prior_pos + eps) + (1 - y_true) / (prior_neg + eps) 21 | ret = -weight * (y_true * (torch.log(y_pred + eps)) + (1 - y_true) * (torch.log(1 - y_pred + eps))) 22 | if seq_reduction == 'mean': 23 | return torch.mean(ret, dim=-1) 24 | elif seq_reduction == 'none': 25 | return ret 26 | 27 | return loss 28 | -------------------------------------------------------------------------------- /scripts/run_a.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 10. 20. 50. ) 4 | gpu_id=0 5 | batch_size=16 6 | num_epochs=10 7 | dataset='movies' 8 | exp_structure='rnr' 9 | benchmark_split='test' 10 | 11 | for par_lambda in ${lambdas[@]}; do 12 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints; 13 | done 14 | 15 | #lambdas=( 2 ) 16 | #gpu_id=1 17 | #dataset='fever' 18 | 19 | #for par_lambda in ${lambdas[@]}; do 20 | # python learn_to_interpret.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split}; 21 | #done 22 | 23 | #lambdas=( 2 ) 24 | #gpu_id=1 25 | #dataset='multirc' 26 | 27 | #for par_lambda in ${lambdas[@]}; do 28 | # python learn_to_interpret.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split}; 29 | #done 30 | -------------------------------------------------------------------------------- /scripts/run_multirc_rnr_sweep_a_and_freeze.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 ) 4 | gpu_id=1 5 | batch_size=16 6 | num_epochs=10 7 | dataset='multirc' 8 | exp_structure='rnr' 9 | benchmark_split='val' 10 | train_on_portion='0.4' 11 | 12 | for par_lambda in ${lambdas[@]}; do 13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}; 14 | done 15 | 16 | par_lambda=1. 
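# Second stage of this script: with par_lambda fixed to 1, rerun every dataset and explanation structure listed below while freezing either the classification or the explanation head (--freeze_cls / --freeze_exp).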
17 | datasets=( movies fever multirc ) 18 | exp_structures=( gru rnr ) 19 | freezes=( cls exp ) 20 | train_on_portion='0' 21 | 22 | for dataset in ${datasets[@]}; do 23 | for exp_structure in ${exp_structures[@]}; do 24 | for freeze in ${freezes[@]}; do 25 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion} --freeze_${freeze}; 26 | done 27 | done 28 | done 29 | 30 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ pytorch ] 9 | pull_request: 10 | branches: [ pytorch ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | env: 20 | GIT_TRACE: 1 21 | GIT_CURL_VERBOSE: 1 22 | - name: Set up Python 3.9 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: 3.9 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install flake8 pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with flake8 32 | run: | 33 | # stop the build if there are Python syntax errors or undefined names 34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 37 | - name: Test with pytest 38 | run: | 39 | pytest 40 | -------------------------------------------------------------------------------- /.idea/runConfigurations/mtl_pipeline_movies_regression_test.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 24 | -------------------------------------------------------------------------------- /scripts/run_multirc_rnr_sweep_c.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 0.1 0.2 0.5 1. 2. 5. ) 4 | batch_size=6 5 | num_epochs=10 6 | dataset='multirc' 7 | exp_structure='rnr' 8 | benchmark_split='test' 9 | train_on_portion='0.4' 10 | 11 | python bert_as_tfkeras_layer.py --par_lambda 0.5 --gpu_id 0 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}& 12 | python bert_as_tfkeras_layer.py --par_lambda 1. --gpu_id 1 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}& 13 | python bert_as_tfkeras_layer.py --par_lambda 2. 
--gpu_id 2 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}& 14 | python bert_as_tfkeras_layer.py --par_lambda 5. --gpu_id 3 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}& 15 | -------------------------------------------------------------------------------- /scripts/run_multirc_rnr_sweep_node05.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 0.1 0.2 0.5 1. 2. 5. ) 4 | batch_size=6 5 | num_epochs=10 6 | dataset='multirc' 7 | exp_structure='rnr' 8 | benchmark_split='test' 9 | train_on_portion='0.4' 10 | 11 | python bert_as_tfkeras_layer.py --par_lambda 0.001 --gpu_id 0 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}& 12 | python bert_as_tfkeras_layer.py --par_lambda 0.002 --gpu_id 1 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}& 13 | python bert_as_tfkeras_layer.py --par_lambda 0.005 --gpu_id 2 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}& 14 | python bert_as_tfkeras_layer.py --par_lambda 0.01 --gpu_id 3 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}& 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | model_components 2 | *deprecated* 3 | pipeline_outputs 4 | *.swp 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # pyenv 80 | .python-version 81 | 82 | # celery beat schedule file 83 | celerybeat-schedule 84 | 85 | # SageMath parsed files 86 | *.sage.py 87 | 88 | # Environments 89 | .env 90 | .venv 91 | env/ 92 | venv/ 93 | ENV/ 94 | env.bak/ 95 | venv.bak/ 96 | 97 | # Spyder project settings 98 | .spyderproject 99 | .spyproject 100 | 101 | # Rope project settings 102 | .ropeproject 103 | 104 | # mkdocs documentation 105 | /site 106 | 107 | # mypy 108 | .mypy_cache/ 109 | 110 | rationale_benchmark/data/esnli_previous 111 | data/esnli_previous 112 | esnli_union/ 113 | 114 | misc/* 115 | -------------------------------------------------------------------------------- /expred/rationale_tokenization.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers.tokenization_bert import WordpieceTokenizer, BertTokenizer, BasicTokenizer, whitespace_tokenize 3 | from transformers.file_utils import http_get 4 | 5 | 6 | def printable_text(text): 7 | if isinstance(text, str): 8 | return text 9 | elif isinstance(text, bytes): 10 | return text.decode("utf-8", "ignore") 11 | else: 12 | raise ValueError("Unsupported string type: %s" % (type(text))) 13 | 14 | 15 | def convert_to_unicode(text): 16 | if isinstance(text, str): 17 | return text 18 | elif isinstance(text, bytes): 19 | return text.decode("utf-8", "ignore") 20 | else: 21 | raise ValueError("Unsupported string type: %s" % (type(text))) 22 | 23 | 24 | class BasicRationalTokenizer(BasicTokenizer): 25 | def __init__(self, do_lower_case=True): 26 | self.do_lower_case = do_lower_case 27 | 28 | def _is_token_rational(self, token_idx, evidences): 29 | if evidences is None: 30 | return 0 31 | for ev in evidences: 32 | if token_idx >= ev.start_token and token_idx < ev.end_token: 33 | return 1 34 | return 0 35 | 36 | def tokenize(self, text, evidences): 37 | text = convert_to_unicode(text) 38 | text = self._clean_text(text) 39 | orig_tokens = whitespace_tokenize(text) 40 | split_tokens = [] 41 | split_rations = [] 42 | for token_idx, token in enumerate(orig_tokens): 43 | if self.do_lower_case: 44 | token = token.lower() 45 | token = self._run_strip_accents(token) 46 | sub_tokens = self._run_split_on_punc(token) 47 | sub_tokens = ' '.join(sub_tokens).strip().split() 48 | if len(sub_tokens) > 0: 49 | split_tokens.extend(sub_tokens) 50 | ration = self._is_token_rational(token_idx, evidences) 51 | split_rations.extend([ration] * len(sub_tokens)) 52 | return zip(split_tokens, split_rations) 53 | 54 | 55 | class FullRationaleTokenizer(BertTokenizer): # Test passed :) 56 | def __init__(self, do_lower_case=True): 57 | if not os.path.isfile('bert-base-uncased-vocab.txt'): 58 | http_get("https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt", 59 | open("bert-base-uncased-vocab.txt", 'wb+')) 60 | super(FullRationaleTokenizer, 
self).__init__('bert-base-uncased-vocab.txt') 61 | self.basic_rational_tokenizer = BasicRationalTokenizer(do_lower_case=do_lower_case) 62 | self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token="[UNK]") 63 | 64 | def tokenize(self, text, evidences=None): 65 | split_tokens = [] 66 | split_rations = [] 67 | for token, ration in self.basic_rational_tokenizer.tokenize(text, evidences): 68 | for sub_token in self.wordpiece_tokenizer.tokenize(token): 69 | split_tokens.append(sub_token) 70 | split_rations.append(ration) 71 | return list(zip(split_tokens, split_rations)) 72 | 73 | 74 | -------------------------------------------------------------------------------- /expred/bert_rational_feature.py: -------------------------------------------------------------------------------- 1 | # bert_rational_feature.py 2 | from tqdm import tqdm_notebook 3 | import expred.rationale_tokenization as tokenization 4 | import logging 5 | 6 | IRRATIONAL = 0 7 | RATIONAL = 1 8 | 9 | logger = logging.getLogger('feature converter') 10 | logger.setLevel(logging.INFO) 11 | 12 | 13 | class InputRationalExample(object): 14 | def __init__(self, guid, text_a, text_b=None, label=None, evidences=None): 15 | self.guid = guid 16 | self.text_a = text_a 17 | self.text_b = text_b 18 | self.label = label 19 | self.evidences = evidences 20 | 21 | 22 | class InputRationalFeatures(object): 23 | def __init__(self, 24 | input_ids, 25 | input_mask, 26 | segment_ids, 27 | label_id, 28 | rations=None, 29 | is_real_example=True): 30 | self.rations = rations 31 | self.input_ids = input_ids 32 | self.input_mask = input_mask 33 | self.segment_ids = segment_ids 34 | self.label_id = label_id 35 | self.is_real_example = is_real_example 36 | 37 | 38 | def convert_single_rational_example(ex_index, example, label_list, max_seq_length, tokenizer): 39 | 40 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 41 | while True: 42 | total_length = len(tokens_a) + len(tokens_b) 43 | if total_length <= max_length: 44 | break 45 | if len(tokens_a) > len(tokens_b): 46 | tokens_a.pop() 47 | else: 48 | tokens_b.pop() 49 | 50 | label_map = {} 51 | for (i, label) in enumerate(label_list): 52 | label_map[label] = i 53 | 54 | tokens_a = tokenizer.tokenize(example.text_a) 55 | tokens_b = None # no tokens_b in our tasks 56 | if example.text_b: 57 | tokens_b = tokenizer.tokenize(example.text_b, example.evidences) 58 | 59 | if tokens_b: 60 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3) 61 | if len(tokens_a) > max_seq_length - 2: 62 | tokens_a = tokens_a[0:(max_seq_length - 2)] 63 | 64 | tokens = [] 65 | segment_ids = [] 66 | rations = [] 67 | tokens.append("[CLS]") 68 | segment_ids.append(0) 69 | rations.append(IRRATIONAL) 70 | 71 | for token, ration in tokens_a: 72 | tokens.append(token) 73 | segment_ids.append(0) 74 | rations.append(ration) 75 | tokens.append("[SEP]") 76 | segment_ids.append(0) 77 | rations.append(IRRATIONAL) 78 | 79 | if tokens_b: 80 | for token, ration in tokens_b: 81 | tokens.append(token) 82 | segment_ids.append(1) 83 | rations.append(ration) 84 | tokens.append("[SEP]") 85 | segment_ids.append(1) 86 | rations.append(IRRATIONAL) 87 | 88 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 89 | input_mask = [1] * len(input_ids) 90 | 91 | while len(input_ids) < max_seq_length: 92 | input_ids.append(0) 93 | input_mask.append(0) 94 | segment_ids.append(0) 95 | rations.append(IRRATIONAL) 96 | 97 | assert len(input_ids) == max_seq_length 98 | assert len(input_mask) == max_seq_length 99 | assert len(segment_ids) 
== max_seq_length 100 | assert len(rations) == max_seq_length 101 | 102 | label_id = label_map[example.label] 103 | 104 | if ex_index < 5: 105 | logger.info("*** Example ***") 106 | logger.info("guid: %s" % (example.guid)) 107 | logger.info("tokens: %s" % " ".join([tokenization.printable_text(x) for x in tokens])) 108 | logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) 109 | logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) 110 | logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids])) 111 | logger.info("rations: %s" % " ".join([str(x) for x in rations])) 112 | logger.info('') 113 | logger.info("label: %s (id = %d)" % (example.label, label_id)) 114 | 115 | feature = InputRationalFeatures( 116 | input_ids=input_ids, 117 | input_mask=input_mask, 118 | segment_ids=segment_ids, 119 | label_id=label_id, 120 | rations=rations, 121 | is_real_example=True) 122 | return feature 123 | 124 | 125 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer): 126 | features = [] 127 | for (ex_index, example) in enumerate(tqdm_notebook(examples, desc="Converting examples to features")): 128 | if ex_index % 10000 == 0: 129 | logger.info("Writing example %d of %d" % (ex_index, len(examples))) 130 | feature = convert_single_rational_example(ex_index, example, label_list, 131 | max_seq_length, tokenizer) 132 | features.append(feature) 133 | return features 134 | -------------------------------------------------------------------------------- /sample_rationales_dataset.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from functools import reduce 4 | 5 | import argparse 6 | import json 7 | import logging 8 | import random 9 | import shutil 10 | import sys 11 | from math import floor 12 | from pathlib import Path 13 | 14 | logger = logging.getLogger(__name__) 15 | logger.setLevel(logging.INFO) 16 | 17 | 18 | def is_valid_file(parser, arg): 19 | path = Path(arg) 20 | if not path.exists(): 21 | parser.error("The file %s does not exist!" 
% arg) 22 | else: 23 | return path 24 | 25 | 26 | class Range(object): 27 | def __init__(self, start, end): 28 | self.start = start 29 | self.end = end 30 | 31 | def __eq__(self, other): 32 | return self.start <= other <= self.end 33 | 34 | def __contains__(self, item): 35 | return self.__eq__(item) 36 | 37 | def __iter__(self): 38 | yield self 39 | 40 | def __str__(self): 41 | return '[{0},{1}]'.format(self.start, self.end) 42 | 43 | 44 | copy_splits = ['val', 'test'] 45 | reduce_splits = ['train'] 46 | 47 | def create_output_folder(dataset_dir: Path, keep_rationals_fraction: float, prefix:str): 48 | dataset_name = dataset_dir.stem 49 | datasets_root = dataset_dir.parent 50 | parts = ([prefix] if prefix != '' else []) + [dataset_name, str(keep_rationals_fraction)] 51 | output_path: Path = datasets_root / '_'.join(parts) 52 | 53 | if output_path.exists(): 54 | raise IOError(f'Output folder {output_path} already exits!') 55 | 56 | output_path.mkdir(parents=True, exist_ok=False) 57 | 58 | return output_path 59 | 60 | def copy_split(split, input_dir : Path, output_path : Path): 61 | file_name = f'{split}.jsonl' 62 | shutil.copy(input_dir / file_name, output_path / file_name) 63 | 64 | def sample_split(split, input_dir : Path, output_path : Path, keep_rationals_fraction :int): 65 | file_name = f'{split}.jsonl' 66 | with open(input_dir / file_name, 'r') as f: 67 | annotations = [json.loads(line) for line in f] 68 | 69 | n = len(annotations) 70 | n_keep_rationales = floor(n * keep_rationals_fraction) 71 | indices_to_keep = random.sample(range(n), n_keep_rationales) 72 | 73 | def remove_rationales(annotation): 74 | annotation['evidences'] = [] 75 | return annotation 76 | 77 | reduces_annotations = [ 78 | ann if i in indices_to_keep else remove_rationales(ann) 79 | for i, ann in enumerate(annotations) 80 | ] 81 | 82 | with open(output_path / file_name, 'w') as f: 83 | dumped_annotations = map(json.dumps, reduces_annotations) 84 | f.write('\n'.join(dumped_annotations)) 85 | 86 | def sample_dataset(dataset_dir: Path, output_path : Path, keep_rationals_fraction: float,): 87 | logger.info(f'Sampling for keep fraction {keep_rationals_fraction} ') 88 | 89 | logger.info('Copying documents') 90 | # copy documents 91 | if (dataset_dir / 'docs').exists(): 92 | shutil.copytree(dataset_dir / 'docs', output_path / 'docs') 93 | elif (dataset_dir / 'docs.jsonl').exists(): 94 | shutil.copy(dataset_dir / 'docs.jsonl', output_path / 'docs.jsonl') 95 | else: 96 | raise ValueError(f'No documents found in {dataset_dir}') 97 | 98 | logger.info(f'Copying unchanged split {copy_splits}') 99 | 100 | # copy 101 | for split in copy_splits: 102 | copy_split(split, input_dir=dataset_dir, output_path=output_path) 103 | 104 | logger.info(f'Sampling and writing results for {reduce_splits}') 105 | 106 | # sample and copy 107 | for split in reduce_splits: 108 | sample_split(split, dataset_dir, output_path, keep_rationals_fraction) 109 | 110 | return output_path 111 | 112 | def main(args): 113 | parser = argparse.ArgumentParser( 114 | 'Takes an eraser dataset and samples a given fraction of the annotations to keep the rationales and removes the rest') 115 | parser.add_argument('--dataset_dir', type=lambda x: is_valid_file(parser, x)) 116 | parser.add_argument('--prefix', type=str, default='') 117 | parser.add_argument('--keep_rationals_fractions', nargs='+', type=float, choices=Range(0, 1)) 118 | 119 | args = parser.parse_args(args) 120 | 121 | logger.info( 122 | f'Running sampling process for fraction 
{args.keep_rationals_fractions} for dataset {args.dataset_dir}' 123 | ) 124 | 125 | keep_rationals_fractions = [1]+sorted(args.keep_rationals_fractions, reverse=True) 126 | 127 | relative_fractions = [keep_rationals_fractions[i]/keep_rationals_fractions[i-1] for i in range(1, len(keep_rationals_fractions))] 128 | input_dir = args.dataset_dir 129 | 130 | for fraction, rel_fraction in zip(keep_rationals_fractions[1:], relative_fractions): 131 | output_path = create_output_folder(args.dataset_dir, fraction, args.prefix) 132 | input_dir = sample_dataset(input_dir, output_path, rel_fraction) 133 | 134 | 135 | if __name__ == '__main__': 136 | main(sys.argv[1:]) 137 | -------------------------------------------------------------------------------- /expred/preprocessing.py: -------------------------------------------------------------------------------- 1 | # preprocessing.py 2 | from itertools import chain 3 | 4 | from copy import deepcopy 5 | from expred.bert_rational_feature import InputRationalExample, convert_examples_to_features 6 | # from config import * 7 | from expred.utils import Evidence 8 | import logging 9 | 10 | IRRATIONAL = 0 11 | RATIONAL = 1 12 | 13 | logger = logging.getLogger('preprocessing.py') 14 | logger.setLevel(logging.INFO) 15 | 16 | 17 | def load_bert_features(data, docs, label_list, max_seq_length, merge_evidences, tokenizer): 18 | input_examples = [] 19 | for ann in data: 20 | text_a = ann.query 21 | label = ann.classification 22 | if not merge_evidences: 23 | for ev_group in ann.evidences: 24 | doc_ids = list(set([ev.docid for ev in ev_group])) 25 | sentences = chain.from_iterable(docs[doc_id] for doc_id in doc_ids) 26 | flattened_tokens = chain(*sentences) 27 | text_b = ' '.join(flattened_tokens) 28 | evidences = ev_group 29 | input_examples.append(InputRationalExample(guid=None, 30 | text_a=text_a, 31 | text_b=text_b, 32 | label=label, 33 | evidences=evidences)) 34 | if merge_evidences: 35 | docids_to_offsets = dict() 36 | latest_offset = 0 37 | example_evidences = [] 38 | text_b_tokens = [] 39 | for ev_group in ann.evidences: 40 | for ev in ev_group: 41 | if ev.docid in docids_to_offsets: 42 | offset = docids_to_offsets[ev.docid] 43 | else: 44 | tokens = list(chain.from_iterable(docs[ev.docid])) 45 | docids_to_offsets[ev.docid] = latest_offset 46 | offset = latest_offset 47 | latest_offset += len(tokens) 48 | text_b_tokens += tokens 49 | example_ev = Evidence(text=ev.text, 50 | docid=ev.docid, 51 | start_token=offset + ev.start_token, 52 | end_token=offset + ev.end_token, 53 | start_sentence=ev.start_sentence, 54 | end_sentence=ev.end_sentence) 55 | example_evidences.append(deepcopy(example_ev)) 56 | input_examples.append(InputRationalExample(guid=None, 57 | text_b=' '.join(text_b_tokens), 58 | text_a=text_a, 59 | label=label, 60 | evidences=example_evidences)) 61 | # print(input_examples[-1].text_b, input_examples[-1].text_a, input_examples[-1].evi) 62 | 63 | features = convert_examples_to_features(input_examples, label_list, max_seq_length, tokenizer) 64 | return features 65 | 66 | 67 | # def convert_bert_features(features, with_label_id, with_rations, exp_output='gru'): 68 | # feature_names = "input_ids input_mask segment_ids".split() 69 | # 70 | # input_ids, input_masks, segment_ids = \ 71 | # list(map(lambda x: [getattr(f, x) for f in features], feature_names)) 72 | # 73 | # rets = [input_ids, input_masks, segment_ids] 74 | # 75 | # if with_rations: 76 | # feature_names.append('rations') 77 | # rations = [getattr(f, 'rations') for f in features] 78 | # rations 
= np.array(rations).reshape([-1, MAX_SEQ_LENGTH, 1]) 79 | # if exp_output == 'interval': 80 | # rations = np.concatenate([np.zeros((rations.shape[0], 1, 1)), 81 | # rations, 82 | # np.zeros((rations.shape[0], 1, 1))], axis=-2) 83 | # rations = rations[:, 1:, :] - rations[:, :-1, :] 84 | # rations_start = (rations > 0)[:, :-1, :].astype(np.int32) 85 | # rations_end = (rations < 0)[:, 1:, :].astype(np.int32) 86 | # rations = np.concatenate((rations_start, rations_end), axis=-1) 87 | # rets.append(rations) 88 | # else: 89 | # rets.append(None) 90 | # 91 | # if with_label_id: 92 | # feature_names.append('label_id') 93 | # label_id = [getattr(f, 'label_id') for f in features] 94 | # labels = np.array(label_id).reshape(-1, 1) 95 | # rets.append(labels) 96 | # else: 97 | # rets.append(None) 98 | # return rets 99 | 100 | 101 | # def preprocess(data, docs, label_list, dataset_name, max_seq_length, exp_output, merge_evidences, tokenizer): 102 | # features = load_bert_features(data, docs, label_list, max_seq_length, merge_evidences, tokenizer) 103 | # 104 | # with_rations = ('cls' not in dataset_name) 105 | # with_lable_id = ('seq' not in dataset_name) 106 | # 107 | # return convert_bert_features(features, with_lable_id, with_rations, exp_output) 108 | -------------------------------------------------------------------------------- /expred/tokenizer.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import List, Dict, Tuple 3 | 4 | import os 5 | import torch 6 | from transformers import BertTokenizer, logger 7 | 8 | from expred.utils import Evidence, Annotation 9 | 10 | 11 | class BertTokenizerWithMapping(BertTokenizer): 12 | def __init__(self, *args, **kwargs): 13 | super(BertTokenizerWithMapping, self).__init__(*args, **kwargs) 14 | 15 | def tokenize_doc(self, doc: List[List[str]], 16 | special_token_map: Dict[str, int]) -> \ 17 | Tuple[List[List[str]], List[List[Tuple[int, int]]]]: 18 | """ Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words""" 19 | sents = [] 20 | sent_token_spans = [] 21 | for sent in doc: 22 | tokens = [] 23 | spans = [] 24 | start = 0 25 | for w in sent: 26 | if w in special_token_map: 27 | tokens.append(w) 28 | else: 29 | tokens.extend(super(BertTokenizerWithMapping, self).tokenize(w)) 30 | end = len(tokens) 31 | spans.append((start, end)) 32 | start = end 33 | sents.append(tokens) 34 | sent_token_spans.append(spans) 35 | return sents, sent_token_spans 36 | 37 | def encode_doc(self, 38 | doc: List[List[str]], 39 | special_token_map) -> List[List[int]]: 40 | # return [list(chain.from_iterable(special_token_map.get(w, tokenizer.encode(w, add_special_tokens=False)) 41 | # for w in s)) for s in doc] 42 | return [[special_token_map.get(w, self.convert_tokens_to_ids(w)) 43 | for w in s] 44 | for s in doc] 45 | 46 | def _encode_docs_maybe_load_from_cache(self, documents, cache_fname): 47 | if os.path.exists(cache_fname): 48 | logger.info(f'Loading interned documents from {cache_fname}') 49 | (encoded_docs, encoded_doc_token_slides) = torch.load(cache_fname) 50 | else: 51 | tokenizer = self 52 | logger.info(f'Interning documents') 53 | special_token_map = { 54 | 'SEP': tokenizer.sep_token_id, 55 | '[SEP]': tokenizer.sep_token_id, 56 | '[sep]': tokenizer.sep_token_id, 57 | 'UNK': tokenizer.unk_token_id, 58 | '[UNK]': tokenizer.unk_token_id, 59 | '[unk]': tokenizer.unk_token_id, 60 | 'PAD': tokenizer.unk_token_id, 61 | '[PAD]': tokenizer.unk_token_id, 62 | 
'[pad]': tokenizer.unk_token_id, 63 | } 64 | encoded_docs = {} 65 | encoded_doc_token_slides = {} 66 | for d, doc in documents.items(): 67 | tokenized_doc, w_slices = self.tokenize_doc(doc, special_token_map=special_token_map) 68 | encoded_docs[d] = self.encode_doc(tokenized_doc, special_token_map=special_token_map) 69 | encoded_doc_token_slides[d] = w_slices 70 | torch.save((encoded_docs, encoded_doc_token_slides), cache_fname) 71 | return encoded_docs, encoded_doc_token_slides 72 | 73 | def encode_docs(self, documents, cache_dir): 74 | cache_fname = os.path.join(cache_dir, 'preprocessed.pkl') 75 | encoded_docs, encoded_doc_token_slides = self._encode_docs_maybe_load_from_cache(documents, cache_fname) 76 | return encoded_docs, encoded_doc_token_slides 77 | 78 | def encode_annotations(self, annotations): 79 | ret = [] 80 | for ann in annotations: 81 | ev_groups = [] 82 | for ev_group in ann.evidences: 83 | evs = [] 84 | for ev in ev_group: 85 | text = list(chain.from_iterable(self.tokenize(w) 86 | for w in ev.text.split())) 87 | if len(text) == 0: 88 | continue 89 | text = self.encode(text, add_special_tokens=False) 90 | evs.append(Evidence(text=tuple(text), 91 | docid=ev.docid, 92 | start_token=ev.start_token, 93 | end_token=ev.end_token, 94 | start_sentence=ev.start_sentence, 95 | end_sentence=ev.end_sentence)) 96 | ev_groups.append(tuple(evs)) 97 | query = list(chain.from_iterable(self.tokenize(w) 98 | for w in ann.query.split())) 99 | if len(query) > 0: 100 | query = self.encode(query, add_special_tokens=False) 101 | else: 102 | query = [] 103 | ret.append(Annotation(annotation_id=ann.annotation_id, 104 | query=tuple(query), 105 | evidences=frozenset(ev_groups), 106 | classification=ann.classification, 107 | query_type=ann.query_type, 108 | docids=ann.docids)) 109 | return ret 110 | -------------------------------------------------------------------------------- /expred/models/model_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Dict, List, Set 3 | 4 | import numpy as np 5 | from gensim.models import KeyedVectors 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn.utils.rnn import pad_sequence, PackedSequence, pack_padded_sequence, pad_packed_sequence 10 | 11 | 12 | @dataclass(eq=True, frozen=True) 13 | class PaddedSequence: 14 | """A utility class for padding variable length sequences mean for RNN input 15 | This class is in the style of PackedSequence from the PyTorch RNN Utils, 16 | but is somewhat more manual in approach. It provides the ability to generate masks 17 | for outputs of the same input dimensions. 18 | The constructor should never be called directly and should only be called via 19 | the autopad classmethod. 20 | 21 | We'd love to delete this, but we pad_sequence, pack_padded_sequence, and 22 | pad_packed_sequence all require shuffling around tuples of information, and some 23 | convenience methods using these are nice to have. 
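    For example (illustrative): autopad([torch.tensor([1, 2, 3]), torch.tensor([4])], batch_first=True)
    yields data == [[1, 2, 3], [4, 0, 0]] and batch_sizes == [3, 1]; mask(on=1, off=0) then marks
    exactly the non-padding positions of that layout.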
24 | """ 25 | 26 | data: torch.Tensor 27 | batch_sizes: torch.Tensor 28 | batch_first: bool = False 29 | 30 | @classmethod 31 | def autopad(cls, data, batch_first: bool = False, padding_value=0, device=None) -> 'PaddedSequence': 32 | # handle tensors of size 0 (single item) 33 | data_ = [] 34 | for d in data: 35 | if len(d.size()) == 0: 36 | d = d.unsqueeze(0) 37 | data_.append(d) 38 | padded = pad_sequence(data_, batch_first=batch_first, padding_value=padding_value) 39 | if batch_first: 40 | batch_lengths = torch.LongTensor([len(x) for x in data_]) 41 | if any([x == 0 for x in batch_lengths]): 42 | raise ValueError( 43 | "Found a 0 length batch element, this can't possibly be right: {}".format(batch_lengths)) 44 | else: 45 | # TODO actually test this codepath 46 | batch_lengths = torch.LongTensor([len(x) for x in data]) 47 | return PaddedSequence(padded, batch_lengths, batch_first).to(device=device) 48 | 49 | #@classmethod 50 | #def autopad(cls, data, len_queries, max_length, batch_first, device): 51 | 52 | 53 | def pack_other(self, data: torch.Tensor): 54 | return pack_padded_sequence(data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False) 55 | 56 | @classmethod 57 | def from_packed_sequence(cls, ps: PackedSequence, batch_first: bool, padding_value=0) -> 'PaddedSequence': 58 | padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value) 59 | return PaddedSequence(padded, batch_sizes, batch_first) 60 | 61 | def cuda(self) -> 'PaddedSequence': 62 | return PaddedSequence(self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first) 63 | 64 | def to(self, dtype=None, device=None, copy=False, non_blocking=False) -> 'PaddedSequence': 65 | # TODO make to() support all of the torch.Tensor to() variants 66 | return PaddedSequence( 67 | self.data.to(dtype=dtype, device=device, copy=copy, non_blocking=non_blocking), 68 | self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking), 69 | batch_first=self.batch_first) 70 | 71 | def mask(self, on=int(0), off=int(0), device='cpu', size=None, dtype=None) -> torch.Tensor: 72 | if size is None: 73 | size = self.data.size() 74 | out_tensor = torch.zeros(*size, dtype=dtype) 75 | # TODO this can be done more efficiently 76 | out_tensor.fill_(off) 77 | # note to self: these are probably less efficient than explicilty populating the off values instead of the on values. 
78 | if self.batch_first: 79 | for i, bl in enumerate(self.batch_sizes): 80 | out_tensor[i, :bl] = on 81 | else: 82 | for i, bl in enumerate(self.batch_sizes): 83 | out_tensor[:bl, i] = on 84 | return out_tensor.to(device) 85 | 86 | def unpad(self, other: torch.Tensor) -> List[torch.Tensor]: 87 | out = [] 88 | for o, bl in zip(other, self.batch_sizes): 89 | out.append(torch.cat((o[:bl], torch.zeros(max(0, bl-len(o)))))) 90 | return out 91 | 92 | def flip(self) -> 'PaddedSequence': 93 | return PaddedSequence(self.data.transpose(0, 1), not self.batch_first, self.padding_value) 94 | 95 | 96 | def extract_embeddings(vocab: Set[str], embedding_file: str, unk_token: str = 'UNK', pad_token: str = 'PAD') -> ( 97 | nn.Embedding, Dict[str, int], List[str]): 98 | vocab = vocab | set([unk_token, pad_token]) 99 | if embedding_file.endswith('.bin'): 100 | WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True) 101 | 102 | word_to_vector = dict() 103 | WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()]) 104 | 105 | if unk_token not in WVs: 106 | mean_vector = np.mean(WV_matrix, axis=0) 107 | word_to_vector[unk_token] = mean_vector 108 | if pad_token not in WVs: 109 | word_to_vector[pad_token] = np.zeros(WVs.vector_size) 110 | 111 | for v in vocab: 112 | if v in WVs: 113 | word_to_vector[v] = WVs[v] 114 | 115 | interner = dict() 116 | deinterner = list() 117 | vectors = [] 118 | count = 0 119 | for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})): 120 | vector = word_to_vector[word] 121 | vectors.append(np.array(vector)) 122 | interner[word] = count 123 | deinterner.append(word) 124 | count += 1 125 | vectors = torch.FloatTensor(np.array(vectors)) 126 | embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token]) 127 | embedding.weight.requires_grad = False 128 | return embedding, interner, deinterner 129 | elif embedding_file.endswith('.txt'): 130 | word_to_vector = dict() 131 | vector = [] 132 | with open(embedding_file, 'r') as inf: 133 | for line in inf: 134 | contents = line.strip().split() 135 | word = contents[0] 136 | vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0) 137 | word_to_vector[word] = vector 138 | embed_size = vector.size() 139 | if unk_token not in word_to_vector: 140 | mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0) 141 | word_to_vector[unk_token] = mean_vector.unsqueeze(0) 142 | if pad_token not in word_to_vector: 143 | word_to_vector[pad_token] = torch.zeros(embed_size) 144 | interner = dict() 145 | deinterner = list() 146 | vectors = [] 147 | count = 0 148 | for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})): 149 | vector = word_to_vector[word] 150 | vectors.append(vector) 151 | interner[word] = count 152 | deinterner.append(word) 153 | count += 1 154 | vectors = torch.cat(vectors, dim=0) 155 | embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token]) 156 | embedding.weight.requires_grad = False 157 | return embedding, interner, deinterner 158 | else: 159 | raise ValueError("Unable to open embeddings file {}".format(embedding_file)) 160 | -------------------------------------------------------------------------------- /expred/eraser_benchmark.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from itertools import chain 4 | from copy import deepcopy 5 | 6 | from expred.bert_rational_feature import InputRationalExample, 
convert_examples_to_features 7 | from expred.eraser_utils import extract_doc_ids_from_annotations 8 | from expred.utils import Annotation 9 | 10 | 11 | def remove_rations(sentence, annotation): 12 | sentence = sentence.lower().split() 13 | rationales = annotation['rationales'][0]['hard_rationale_predictions'] 14 | rationales = [{'end_token': 0, 'start_token': 0}] \ 15 | + sorted(rationales, key=lambda x: x['start_token']) \ 16 | + [{'start_token': len(sentence), 'end_token': len(sentence)}] 17 | ret = [] 18 | for rat_id, rat in enumerate(rationales[:-1]): 19 | ret += ['.'] * (rat['end_token'] - rat['start_token']) \ 20 | + sentence[rat['end_token'] 21 | : rationales[rat_id + 1]['start_token']] 22 | return ' '.join(ret) 23 | 24 | 25 | def extract_rations(sentence, rationale): 26 | sentence = sentence.lower().split() 27 | rationales = rationale['rationales'][0]['hard_rationale_predictions'] 28 | rationales = [{'end_token': 0, 'start_token': 0}] \ 29 | + sorted(rationales, key=lambda x: x['start_token']) \ 30 | + [{'start_token': len(sentence), 'end_token': len(sentence)}] 31 | ret = [] 32 | for rat_id, rat in enumerate(rationales[:-1]): 33 | ret += sentence[rat['start_token']: rat['end_token']] \ 34 | + ['.'] * (rationales[rat_id + 1] 35 | ['start_token'] - rat['end_token']) 36 | return ' '.join(ret) 37 | 38 | 39 | def ce_load_bert_features(rationales, docs, label_list, decorate, max_seq_length, gpu_id, tokenizer=None): 40 | input_examples = [] 41 | for r_idx, rational in enumerate(rationales): 42 | text_a = rational['query'] 43 | docids = rational['docids'] 44 | sentences = chain.from_iterable(docs[docid] for docid in docids) 45 | flattened_tokens = chain(*sentences) 46 | text_b = ' '.join(flattened_tokens) 47 | text_b = decorate(text_b, rational) 48 | label = rational['classification'] 49 | evidences = None 50 | input_examples.append(InputRationalExample(guid=None, 51 | text_a=text_a, 52 | text_b=text_b, 53 | label=label, 54 | evidences=evidences)) 55 | features = convert_examples_to_features(input_examples, label_list, max_seq_length, tokenizer) 56 | return features 57 | 58 | 59 | # def ce_preprocess(rationales, docs, label_list, dataset_name, decorate, max_seq_length, exp_output, gpu_id, tokenizer): 60 | # features = ce_load_bert_features(rationales, docs, label_list, decorate, max_seq_length, gpu_id, tokenizer) 61 | # 62 | # with_rations = ('cls' not in dataset_name) 63 | # with_lable_id = ('seq' not in dataset_name) 64 | # 65 | # return convert_bert_features(features, with_lable_id, with_rations, exp_output) 66 | 67 | 68 | # def get_cls_score(model, rationales, docs, label_list, dataset, decorate, max_seq_length, exp_output, gpu_id, tokenizer): 69 | # rets = ce_preprocess(rationales, docs, label_list, dataset, decorate, max_seq_length, exp_output, gpu_id, tokenizer) 70 | # _input_ids, _input_masks, _segment_ids, _rations, _labels = rets 71 | # 72 | # _inputs = [_input_ids, _input_masks, _segment_ids] 73 | # _pred = model.predict(_inputs) 74 | # return (np.hstack([1 - _pred[0], _pred[0]])) 75 | 76 | 77 | def add_cls_scores(res, cls, c, s, label_list): 78 | res['classification_scores'] = {label_list[0]: cls[0], label_list[1]: cls[1]} 79 | res['comprehensiveness_classification_scores'] = {label_list[0]: c[0], label_list[1]: c[1]} 80 | res['sufficiency_classification_scores'] = {label_list[0]: s[0], label_list[1]: s[1]} 81 | return res 82 | 83 | 84 | def pred_to_exp_mask(exp_pred, count, threshold): 85 | if count is None: 86 | return (np.array(exp_pred).astype(np.float) >= 
threshold).astype(np.int32) 87 | temp = [(i, p) for i, p in enumerate(exp_pred)] 88 | temp = sorted(temp, key=lambda x: x[1], reverse=True) 89 | ret = np.zeros_like(exp_pred).astype(np.int32) 90 | for i, _ in temp[:count]: 91 | ret[i] = 1 92 | return ret 93 | 94 | 95 | def rational_bits_to_ev_generator(token_list, raw_input_or_docid, exp_pred, hard_selection_count=None, 96 | hard_selection_threshold=0.5): 97 | in_rationale = False 98 | if not isinstance(raw_input_or_docid, Annotation): 99 | docid = raw_input_or_docid 100 | else: 101 | docid = list(extract_doc_ids_from_annotations([raw_input_or_docid]))[0] 102 | ev = {'docid': docid, 103 | 'start_token': -1, 'end_token': -1, 'text': ''} 104 | exp_masks = pred_to_exp_mask( 105 | exp_pred, hard_selection_count, hard_selection_threshold) 106 | for i, p in enumerate(exp_masks): 107 | if p == 0 and in_rationale: # leave rational zone 108 | in_rationale = False 109 | ev['end_token'] = i 110 | ev['text'] = ' '.join( 111 | token_list[ev['start_token']: ev['end_token']]) 112 | yield deepcopy(ev) 113 | elif p == 1 and not in_rationale: # enter rational zone 114 | in_rationale = True 115 | ev['start_token'] = i 116 | if in_rationale: # the final non-padding token is rational 117 | ev['end_token'] = len(exp_pred) 118 | ev['text'] = ' '.join(token_list[ev['start_token']: ev['end_token']]) 119 | yield deepcopy(ev) 120 | 121 | 122 | # [SEP] == 102 123 | # [CLS] == 101 124 | # [PAD] == 0 125 | def extract_texts(tokens, exps=None, text_a=True, text_b=False): 126 | if tokens[0] == 101: 127 | endp_text_a = tokens.index(102) 128 | if text_b: 129 | endp_text_b = endp_text_a + 1 + \ 130 | tokens[endp_text_a + 1:].index(102) 131 | else: 132 | endp_text_a = tokens.index('[SEP]') 133 | if text_b: 134 | endp_text_b = endp_text_a + 1 + \ 135 | tokens[endp_text_a + 1:].index('[SEP]') 136 | ret_token = [] 137 | if text_a: 138 | ret_token += tokens[1: endp_text_a] 139 | if text_b: 140 | ret_token += tokens[endp_text_a + 1: endp_text_b] 141 | if exps is None: 142 | return ret_token 143 | else: 144 | ret_exps = [] 145 | if text_a: 146 | ret_exps += exps[1: endp_text_a] 147 | if text_b: 148 | ret_exps += exps[endp_text_a + 1: endp_text_b] 149 | return ret_token, ret_exps 150 | 151 | 152 | def rnr_matrix_to_rational_mask(rnr_matrix): 153 | start_logits, end_logits = rnr_matrix[:, :1], rnr_matrix[:, 1:] 154 | starts = np.round(start_logits).reshape((-1, 1)) 155 | ends = np.triu(end_logits) 156 | ends = starts * ends 157 | ends_args = np.argmax(ends, axis=1) 158 | ends = np.zeros_like(ends) 159 | for i in range(len(ends_args)): 160 | ends[i, ends_args[i]] = 1 161 | ends = starts * ends 162 | ends = np.sum(ends, axis=0, keepdims=True) 163 | rational_mask = np.cumsum(starts.reshape((1, -1)), axis=1) - np.cumsum(ends, axis=1) + ends 164 | return rational_mask 165 | 166 | 167 | # def pred_to_results(raw_input, input_ids, pred, 168 | # hard_selection_count, hard_selection_threshold, 169 | # vocab, docs, label_list, 170 | # exp_output): 171 | # cls_pred, exp_pred = pred 172 | # if exp_output == 'rnr': 173 | # exp_pred = rnr_matrix_to_rational_mask(exp_pred) 174 | # exp_pred = exp_pred.reshape((-1,)).tolist() 175 | # docid = list(raw_input.evidences)[0][0].docid 176 | # raw_sentence = ' '.join(list(chain.from_iterable(docs[docid]))) 177 | # raw_sentence = re.sub('\x12', '', raw_sentence) 178 | # raw_sentence = raw_sentence.lower().split() 179 | # token_ids, exp_pred = extract_texts(input_ids, exp_pred, text_a=False, text_b=True) 180 | # token_list, exp_pred = 
convert_subtoken_ids_to_tokens(token_ids, vocab, exp_pred, raw_sentence) 181 | # result = {'annotation_id': raw_input.annotation_id, 'query': raw_input.query} 182 | # ev_groups = [] 183 | # result['docids'] = [docid] 184 | # result['rationales'] = [{'docid': docid}] 185 | # for ev in rational_bits_to_ev_generator(token_list, raw_input, exp_pred, hard_selection_count, 186 | # hard_selection_threshold): 187 | # ev_groups.append(ev) 188 | # result['rationales'][-1]['hard_rationale_predictions'] = ev_groups 189 | # if exp_output != 'rnr': 190 | # result['rationales'][-1]['soft_rationale_predictions'] = exp_pred + [0] * (len(raw_sentence) - len(token_list)) 191 | # result['classification'] = label_list[int(round(cls_pred[0]))] 192 | # return result 193 | -------------------------------------------------------------------------------- /expred/models/mlp_mtl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import BertModel, BertTokenizer 5 | from expred.params import MTLParams 6 | from typing import Any, List 7 | from expred.models.model_utils import PaddedSequence 8 | 9 | 10 | class BertMTL(nn.Module): 11 | def __init__(self, 12 | bert_dir: str, 13 | tokenizer: BertTokenizer, 14 | mtl_params: MTLParams, 15 | max_length: int=512, 16 | use_half_precision=True): 17 | super(BertMTL, self).__init__() 18 | bare_bert = BertModel.from_pretrained(bert_dir) 19 | if use_half_precision: 20 | import apex 21 | bare_bert = bare_bert.half() 22 | self.bare_bert = bare_bert 23 | self.pad_token_id = tokenizer.pad_token_id 24 | self.cls_token_id = tokenizer.cls_token_id 25 | self.sep_token_id = tokenizer.sep_token_id 26 | self.max_length = max_length 27 | 28 | class ExpHead(nn.Module): 29 | def __init__(self, input_size, hidden_size): 30 | super(ExpHead, self).__init__() 31 | self.exp_gru = nn.GRU(input_size, hidden_size) 32 | self.exp_linear = nn.Linear(hidden_size, 1, bias=True) 33 | self.exp_act = nn.Sigmoid() 34 | 35 | def forward(self, x): 36 | return self.exp_act(self.exp_linear(self.exp_gru(x)[0])) 37 | 38 | #self.exp_head = lambda x: exp_act(exp_linear(exp_gru(x)[0])) 39 | self.exp_head = ExpHead(self.bare_bert.config.hidden_size, mtl_params.dim_exp_gru) 40 | self.cls_head = nn.Sequential( 41 | nn.Dropout(0.1), 42 | nn.Linear(self.bare_bert.config.hidden_size, mtl_params.dim_cls_linear, bias=True), 43 | nn.Tanh(), 44 | nn.Linear(mtl_params.dim_cls_linear, mtl_params.num_labels, bias=True), 45 | nn.Softmax(dim=-1) 46 | ) 47 | for layer in self.cls_head: 48 | if type(layer) == nn.Linear: 49 | nn.init.xavier_normal_(layer.weight) 50 | nn.init.xavier_normal_(self.exp_head.exp_linear.weight) 51 | 52 | def forward(self, 53 | query: List[torch.tensor], 54 | docids: List[Any], 55 | document_batch: List[torch.tensor]): 56 | #input_ids, token_type_ids=None, attention_mask=None, labels=None): 57 | assert len(query) == len(document_batch) 58 | #print(next(self.cls_head.parameters()).device) 59 | target_device = next(self.parameters()).device 60 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device) 61 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device) 62 | input_tensors = [] 63 | for q, d in zip(query, document_batch): 64 | if len(q) + len(d) + 2 > self.max_length: 65 | d = d[:(self.max_length - len(q) - 2)] 66 | input_tensors.append(torch.cat([cls_token, q, sep_token, d])) 67 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, 
padding_value=self.pad_token_id, 68 | device=target_device) 69 | attention_mask = bert_input.mask(on=1., off=0., device=target_device) 70 | exp_output, cls_output = self.bare_bert(bert_input.data, attention_mask=attention_mask) 71 | exp_output = self.exp_head(exp_output).squeeze() * attention_mask 72 | cls_output = self.cls_head(cls_output) 73 | assert torch.all(cls_output == cls_output) 74 | assert torch.all(exp_output == exp_output) 75 | return cls_output, exp_output, attention_mask 76 | 77 | 78 | class BertClassifier(nn.Module): 79 | """Thin wrapper around BertForSequenceClassification""" 80 | def __init__(self, 81 | bert_dir: str, 82 | pad_token_id: int, 83 | cls_token_id: int, 84 | sep_token_id: int, 85 | num_labels: int, 86 | mtl_params: MTLParams, 87 | max_length: int=512, 88 | use_half_precision=True): 89 | super(BertClassifier, self).__init__() 90 | bert = BertModel.from_pretrained(bert_dir, num_labels=num_labels) 91 | if use_half_precision: 92 | import apex 93 | bert = bert.half() 94 | self.bert = bert 95 | self.cls_head = nn.Sequential( 96 | nn.Dropout(0.1), 97 | nn.Linear(bert.config.hidden_size, mtl_params.dim_cls_linear, bias=True), 98 | nn.Tanh(), 99 | nn.Linear(mtl_params.dim_cls_linear, mtl_params.num_labels, bias=True), 100 | nn.Softmax(dim=-1) 101 | ) 102 | for layer in self.cls_head: 103 | if type(layer) == nn.Linear: 104 | nn.init.xavier_normal_(layer.weight) 105 | self.pad_token_id = pad_token_id 106 | self.cls_token_id = cls_token_id 107 | self.sep_token_id = sep_token_id 108 | self.max_length = max_length 109 | 110 | def forward(self, 111 | query: List[torch.tensor], 112 | docids: List[Any], 113 | document_batch: List[torch.tensor]): 114 | assert len(query) == len(document_batch) 115 | # note about device management: 116 | # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module) 117 | # we want to keep these conf on the input device (assuming CPU) for as long as possible for cheap memory access 118 | target_device = next(self.parameters()).device 119 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device) 120 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device) 121 | input_tensors = [] 122 | position_ids = [] 123 | for q, d in zip(query, document_batch): 124 | if len(q) + len(d) + 2 > self.max_length: 125 | d = d[:(self.max_length - len(q) - 2)] 126 | input_tensors.append(torch.cat([cls_token, q, sep_token, d])) 127 | position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))) 128 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device) 129 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device) 130 | _, classes = self.bert(bert_input.data, attention_mask=bert_input.mask(on=1., off=0., device=target_device), position_ids=positions.data) 131 | classes = self.cls_head(classes) 132 | assert torch.all(classes == classes) # for nans 133 | return classes 134 | 135 | 136 | class BertClassifier2(nn.Module): 137 | """Thin wrapper around BertForSequenceClassification""" 138 | def __init__(self, 139 | bert_dir: str, 140 | pad_token_id: int, 141 | cls_token_id: int, 142 | sep_token_id: int, 143 | num_labels: int, 144 | mtl_params: MTLParams, 145 | max_length: int=512, 146 | use_half_precision=True): 147 | super(BertClassifier2, self).__init__() 148 | bert = 
BertModel.from_pretrained(bert_dir, num_labels=num_labels) 149 | if use_half_precision: 150 | import apex 151 | bert = bert.half() 152 | self.bert = bert 153 | self.cls_head = nn.Sequential( 154 | #nn.Dropout(0.1), 155 | nn.Linear(bert.config.hidden_size, mtl_params.dim_cls_linear, bias=True), 156 | nn.Tanh(), 157 | nn.Linear(mtl_params.dim_cls_linear, mtl_params.num_labels, bias=True), 158 | nn.Softmax(dim=-1) 159 | ) 160 | for layer in self.cls_head: 161 | if type(layer) == nn.Linear: 162 | nn.init.xavier_normal_(layer.weight) 163 | self.pad_token_id = pad_token_id 164 | self.cls_token_id = cls_token_id 165 | self.sep_token_id = sep_token_id 166 | self.max_length = max_length 167 | 168 | def forward(self, 169 | query: List[torch.tensor], 170 | docids: List[Any], 171 | document_batch: List[torch.tensor]): 172 | assert len(query) == len(document_batch) 173 | # note about device management: 174 | # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module) 175 | # we want to keep these conf on the input device (assuming CPU) for as long as possible for cheap memory access 176 | target_device = next(self.parameters()).device 177 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device) 178 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device) 179 | input_tensors = [] 180 | position_ids = [] 181 | for q, d in zip(query, document_batch): 182 | if len(q) + len(d) + 2 > self.max_length: 183 | d = d[:(self.max_length - len(q) - 2)] 184 | input_tensors.append(torch.cat([cls_token, q, sep_token, d])) 185 | position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1)))) 186 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device) 187 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device) 188 | _, classes = self.bert(bert_input.data, attention_mask=bert_input.mask(on=1., off=0., device=target_device), position_ids=positions.data) 189 | classes = self.cls_head(classes) 190 | assert torch.all(classes == classes) # for nans 191 | return classes -------------------------------------------------------------------------------- /expred/models/pipeline/mtl_evidence_classifier.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | 5 | from collections import OrderedDict 6 | 7 | import wandb 8 | from typing import Dict, List, Tuple, Any 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from sklearn.metrics import accuracy_score, classification_report 14 | 15 | # from expred.utils import Annotation 16 | 17 | from expred.models.pipeline.pipeline_utils import SentenceEvidence 18 | from expred.models.pipeline.mtl_pipeline_utils import ( 19 | mask_annotations_to_evidence_classification, 20 | make_mtl_classification_preds_epoch 21 | ) 22 | 23 | 24 | def train_mtl_evidence_classifier(evidence_classifier: nn.Module, 25 | save_dir: str, 26 | train: Tuple[List[Tuple[str, SentenceEvidence]], Any], 27 | val: Tuple[List[Tuple[str, SentenceEvidence]], Any], 28 | documents: Dict[str, List[List[int]]], 29 | model_pars: dict, 30 | class_interner: Dict[str, int], 31 | optimizer=None, 32 | scheduler=None, 33 | tensorize_model_inputs: bool = True) -> Tuple[nn.Module, dict]: 34 | """ 35 | 36 | :param evidence_classifier: 37 | :param save_dir: 38 | :param 
train: 39 | :param val: 40 | :param documents: 41 | :param model_pars: 42 | :param class_interner: 43 | :param optimizer: 44 | :param scheduler: 45 | :param tensorize_model_inputs: 46 | :return: 47 | """ 48 | logging.info( 49 | f'Beginning training evidence classifier with {len(train[0])} annotations, {len(val[0])} for validation') 50 | # set up output directories 51 | evidence_classifier_output_dir = os.path.join(save_dir, 'evidence_classifier') 52 | os.makedirs(save_dir, exist_ok=True) 53 | os.makedirs(evidence_classifier_output_dir, exist_ok=True) 54 | model_save_file = os.path.join(evidence_classifier_output_dir, 'evidence_classifier.pt') 55 | epoch_save_file = os.path.join(evidence_classifier_output_dir, 'evidence_classifier_epoch_data.pt') 56 | 57 | # set up training (optimizer, loss, patience, ...) 58 | device = next(evidence_classifier.parameters()).device 59 | if optimizer is None: 60 | optimizer = torch.optim.Adam(evidence_classifier.parameters(), lr=model_pars['evidence_classifier']['lr']) 61 | criterion = nn.BCELoss(reduction='none') 62 | batch_size = model_pars['evidence_classifier']['batch_size'] 63 | epochs = model_pars['evidence_classifier']['epochs'] 64 | patience = model_pars['evidence_classifier']['patience'] 65 | max_grad_norm = model_pars['evidence_classifier'].get('max_grad_norm', None) 66 | 67 | # mask out the hard prediction (token 0) and convert to [SentenceEvidence...] 68 | evidence_train_data = mask_annotations_to_evidence_classification(train, class_interner) 69 | evidence_val_data = mask_annotations_to_evidence_classification(val, class_interner) 70 | 71 | class_labels = [k for k, v in sorted(class_interner.items())] 72 | 73 | results = { 74 | 'train_loss': [], 75 | 'train_f1': [], 76 | 'train_acc': [], 77 | 'val_loss': [], 78 | 'val_f1': [], 79 | 'val_acc': [], 80 | } 81 | best_epoch = -1 82 | best_val_loss = float('inf') 83 | best_model_state_dict = None 84 | start_epoch = 0 85 | epoch_data = {} 86 | if os.path.exists(epoch_save_file): 87 | logging.info(f'Restoring model from {model_save_file}') 88 | evidence_classifier.load_state_dict(torch.load(model_save_file)) 89 | epoch_data = torch.load(epoch_save_file) 90 | start_epoch = epoch_data['epoch'] + 1 91 | # handle finishing because patience was exceeded or we didn't get the best final epoch 92 | if bool(epoch_data.get('done', 0)): 93 | start_epoch = epochs 94 | results = epoch_data['results'] 95 | best_epoch = start_epoch 96 | best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_classifier.state_dict().items()}) 97 | logging.info(f'Restoring training from epoch {start_epoch}') 98 | logging.info(f'Training evidence classifier from epoch {start_epoch} until epoch {epochs}') 99 | optimizer.zero_grad() 100 | for epoch in range(start_epoch, epochs): 101 | epoch_train_data = random.sample(evidence_train_data, k=len(evidence_train_data)) 102 | epoch_val_data = random.sample(evidence_val_data, k=len(evidence_val_data)) 103 | epoch_train_loss = 0 104 | evidence_classifier.train() 105 | logging.info( 106 | f'Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples') 107 | for batch_start in range(0, len(epoch_train_data), batch_size): 108 | batch_elements = epoch_train_data[batch_start:min(batch_start + batch_size, len(epoch_train_data))] 109 | targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements]) 110 | ids = [(s.ann_id, s.docid, s.index) for s in batch_elements] 111 | targets = [[i == target for i in 
range(len(class_interner))] for target in targets] 112 | targets = torch.tensor(targets, dtype=torch.float, device=device) 113 | if tensorize_model_inputs: 114 | queries = [torch.tensor(q, dtype=torch.long) for q in queries] 115 | sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] 116 | preds = evidence_classifier(queries, ids, sentences) 117 | loss = criterion(preds, targets.to(device=preds.device)).sum() 118 | epoch_train_loss += loss.item() 119 | loss = loss / len(preds) # accumulate entire loss above 120 | loss.backward() 121 | assert loss == loss # for nans 122 | if max_grad_norm: 123 | torch.nn.utils.clip_grad_norm_(evidence_classifier.parameters(), max_grad_norm) 124 | optimizer.step() 125 | if scheduler: 126 | scheduler.step() 127 | optimizer.zero_grad() 128 | epoch_train_loss /= len(epoch_train_data) 129 | assert epoch_train_loss == epoch_train_loss # for nans 130 | results['train_loss'].append(epoch_train_loss) 131 | logging.info(f'Epoch {epoch} training loss {epoch_train_loss}') 132 | 133 | with torch.no_grad(): 134 | evidence_classifier.eval() 135 | epoch_train_loss, \ 136 | epoch_train_soft_pred, \ 137 | epoch_train_hard_pred, \ 138 | epoch_train_truth = make_mtl_classification_preds_epoch( 139 | classifier=evidence_classifier, 140 | data=epoch_train_data, 141 | class_interner=class_interner, 142 | batch_size=batch_size, 143 | device=device, 144 | criterion=criterion, 145 | tensorize_model_inputs=tensorize_model_inputs) 146 | results['train_f1'].append( 147 | classification_report(epoch_train_truth, epoch_train_hard_pred, target_names=class_labels, 148 | labels=list(range(len(class_labels))), output_dict=True)) 149 | results['train_acc'].append(accuracy_score(epoch_train_truth, epoch_train_hard_pred)) 150 | epoch_val_loss, \ 151 | epoch_val_soft_pred, \ 152 | epoch_val_hard_pred, \ 153 | epoch_val_truth = make_mtl_classification_preds_epoch( 154 | classifier=evidence_classifier, 155 | data=epoch_val_data, 156 | class_interner=class_interner, 157 | batch_size=batch_size, 158 | device=device, 159 | criterion=criterion, 160 | tensorize_model_inputs=tensorize_model_inputs) 161 | results['val_loss'].append(epoch_val_loss) 162 | results['val_f1'].append( 163 | classification_report(epoch_val_truth, epoch_val_hard_pred, target_names=class_labels, 164 | labels=list(range(len(class_labels))), output_dict=True)) 165 | results['val_acc'].append(accuracy_score(epoch_val_truth, epoch_val_hard_pred)) 166 | assert epoch_val_loss == epoch_val_loss # for nans 167 | logging.info(f'Epoch {epoch} val loss {epoch_val_loss}') 168 | logging.info(f'Epoch {epoch} val acc {results["val_acc"][-1]}') 169 | logging.info(f'Epoch {epoch} val f1 {results["val_f1"][-1]}') 170 | 171 | epoch_metrics = {metric: values[-1] for metric, values in results.items()} 172 | epoch_metrics['epoch'] = epoch 173 | wandb.log(epoch_metrics) 174 | 175 | if epoch_val_loss < best_val_loss: 176 | best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_classifier.state_dict().items()}) 177 | best_epoch = epoch 178 | best_val_loss = epoch_val_loss 179 | epoch_data = { 180 | 'epoch': epoch, 181 | 'results': results, 182 | 'best_val_loss': best_val_loss, 183 | 'done': 0, 184 | } 185 | torch.save(evidence_classifier.state_dict(), model_save_file) 186 | torch.save(epoch_data, epoch_save_file) 187 | logging.debug(f'Epoch {epoch} new best model with val loss {epoch_val_loss}') 188 | if epoch - best_epoch > patience: 189 | logging.info(f'Exiting after epoch {epoch} due to no improvement') 190 | 
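            # Patience exceeded: mark the run as done so a restart will skip straight past
            # training (see the `epoch_data.get('done', 0)` check in the restore block above).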
epoch_data['done'] = 1 191 | torch.save(epoch_data, epoch_save_file) 192 | break 193 | 194 | epoch_data['done'] = 1 195 | epoch_data['results'] = results 196 | torch.save(epoch_data, epoch_save_file) 197 | evidence_classifier.load_state_dict(best_model_state_dict) 198 | evidence_classifier = evidence_classifier.to(device=device) 199 | evidence_classifier.eval() 200 | return evidence_classifier, results 201 | -------------------------------------------------------------------------------- /.idea/inspectionProfiles/Project_Default.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 176 | -------------------------------------------------------------------------------- /expred/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import os 4 | import re 5 | 6 | from dataclasses import dataclass, asdict, is_dataclass 7 | from itertools import chain 8 | from typing import Dict, List, Set, Tuple, Union, FrozenSet 9 | 10 | import transformers 11 | 12 | 13 | 14 | # tensorflow compatibility import, uncomment if needed 15 | # import tensorflow 16 | # if tensorflow.__version__.startswith('2'): 17 | # import tensorflow.compat.v1 as tf 18 | # 19 | # tf.disable_v2_behavior() 20 | # else: 21 | # import tensorflow as tf 22 | 23 | 24 | @dataclass(eq=True, frozen=True) 25 | class Evidence: 26 | """ 27 | (docid, start_token, end_token) form the only official Evidence; sentence level annotations are for convenience. 28 | Args: 29 | text: Some representation of the evidence text 30 | docid: Some identifier for the document 31 | start_token: The canonical start token, inclusive 32 | end_token: The canonical end token, exclusive 33 | start_sentence: Best guess start sentence, inclusive 34 | end_sentence: Best guess end sentence, exclusive 35 | """ 36 | text: Union[str, Tuple[int], Tuple[str]] 37 | docid: str 38 | start_token: int = -1 39 | end_token: int = -1 40 | start_sentence: int = -1 41 | end_sentence: int = -1 42 | 43 | 44 | @dataclass(eq=True, frozen=True) 45 | class Annotation: 46 | """ 47 | Args: 48 | annotation_id: unique ID for this annotation element 49 | query: some representation of a query string 50 | evidences: a set of "evidence groups". 51 | Each evidence group is: 52 | * sufficient to respond to the query (or justify an answer) 53 | * composed of one or more Evidences 54 | * may have multiple documents in it (depending on the dataset) 55 | - e-snli has multiple documents 56 | - other datasets do not 57 | classification: str 58 | query_type: Optional str, additional information about the query 59 | docids: a set of docids in which one may find evidence. 
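    Illustrative instance (all field values are made up):
        Annotation(annotation_id='ann_0', query='what is the sentiment?',
                   evidences=frozenset({(Evidence(text='a great film', docid='doc_0',
                                                  start_token=5, end_token=8),)}),
                   classification='POS')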
60 | """ 61 | annotation_id: str 62 | query: Union[str, Tuple[int]] 63 | evidences: Union[Set[Tuple[Evidence]], FrozenSet[Tuple[Evidence]]] 64 | classification: str 65 | query_type: str = None 66 | docids: Set[str] = None 67 | 68 | def all_evidences(self) -> Tuple[Evidence]: 69 | return tuple(list(chain.from_iterable(self.evidences))) 70 | 71 | 72 | def annotations_to_jsonl(annotations, output_file): 73 | with open(output_file, 'w') as of: 74 | for ann in sorted(annotations, key=lambda x: x.annotation_id): 75 | as_json = _annotation_to_dict(ann) 76 | as_str = json.dumps(as_json, sort_keys=True) 77 | of.write(as_str) 78 | of.write('\n') 79 | 80 | 81 | def _annotation_to_dict(dc): 82 | # convenience method 83 | if is_dataclass(dc): 84 | d = asdict(dc) 85 | ret = dict() 86 | for k, v in d.items(): 87 | ret[k] = _annotation_to_dict(v) 88 | return ret 89 | elif isinstance(dc, dict): 90 | ret = dict() 91 | for k, v in dc.items(): 92 | k = _annotation_to_dict(k) 93 | v = _annotation_to_dict(v) 94 | ret[k] = v 95 | return ret 96 | elif isinstance(dc, str): 97 | return dc 98 | elif isinstance(dc, (set, frozenset, list, tuple)): 99 | ret = [] 100 | for x in dc: 101 | ret.append(_annotation_to_dict(x)) 102 | return tuple(ret) 103 | else: 104 | return dc 105 | 106 | 107 | def load_jsonl(fp: str) -> List[dict]: 108 | ret = [] 109 | with open(fp, 'r') as inf: 110 | for line in inf: 111 | content = json.loads(line) 112 | ret.append(content) 113 | return ret 114 | 115 | 116 | def write_jsonl(jsonl, output_file): 117 | with open(output_file, 'w') as of: 118 | for js in jsonl: 119 | as_str = json.dumps(js, sort_keys=True) 120 | of.write(as_str) 121 | of.write('\n') 122 | 123 | 124 | def annotations_from_jsonl(fp: str) -> List[Annotation]: 125 | ret = [] 126 | with open(fp, 'r') as inf: 127 | for line in inf: 128 | content = json.loads(line) 129 | ev_groups = [] 130 | for ev_group in content['evidences']: 131 | ev_group = tuple([Evidence(**ev) for ev in ev_group]) 132 | ev_groups.append(ev_group) 133 | content['evidences'] = frozenset(ev_groups) 134 | ret.append(Annotation(**content)) 135 | return ret 136 | 137 | 138 | def decorate_with_docs_ids(annotation : Annotation): 139 | """ 140 | Extracts the docids if not already set from the annotations_id. 141 | Warning does not work for esnli. For esnli_flat it works! 142 | :param annotation: 143 | :return: annotation 144 | """ 145 | if annotation.docids is not None and len(annotation.docids) > 0: 146 | return annotation 147 | 148 | docids = [annotation.annotation_id] 149 | new_annotation = asdict(annotation) 150 | new_annotation['docids'] = docids 151 | assert docids is not None 152 | return Annotation(**new_annotation) 153 | 154 | def load_datasets(data_dir: str) -> Tuple[List[Annotation], List[Annotation], List[Annotation]]: 155 | """Loads a training, validation, and test dataset 156 | 157 | Each dataset is assumed to have been serialized by annotations_to_jsonl, 158 | that is it is a list of json-serialized Annotation instances. 
159 | """ 160 | train_data = annotations_from_jsonl(os.path.join(data_dir, 'train.jsonl')) 161 | val_data = annotations_from_jsonl(os.path.join(data_dir, 'val.jsonl')) 162 | test_data = annotations_from_jsonl(os.path.join(data_dir, 'test.jsonl')) 163 | 164 | splits = train_data, val_data, test_data 165 | splits = tuple([decorate_with_docs_ids(ann) for ann in split] for split in splits) 166 | return splits 167 | 168 | 169 | def load_documents(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]: 170 | """Loads a subset of available documents from disk. 171 | 172 | Each document is assumed to be serialized as newline ('\n') separated sentences. 173 | Each sentence is assumed to be space (' ') joined tokens. 174 | """ 175 | if os.path.exists(os.path.join(data_dir, 'docs.jsonl')): 176 | assert not os.path.exists(os.path.join(data_dir, 'docs')) 177 | return load_documents_from_file(data_dir, docids) 178 | 179 | docs_dir = os.path.join(data_dir, 'docs') 180 | res = dict() 181 | if docids is None: 182 | docids = sorted(os.listdir(docs_dir)) 183 | else: 184 | docids = sorted(set(str(d) for d in docids)) 185 | for d in docids: 186 | with open(os.path.join(docs_dir, d), 'r') as inf: 187 | lines = [l.strip() for l in inf.readlines()] 188 | lines = list(filter(lambda x: bool(len(x)), lines)) 189 | tokenized = [list(filter(lambda x: bool(len(x)), line.strip().split(' '))) for line in lines] 190 | res[d] = tokenized 191 | return res 192 | 193 | 194 | def load_flattened_documents(data_dir: str, docids: Set[str]) -> Dict[str, List[str]]: 195 | """Loads a subset of available documents from disk. 196 | 197 | Returns a tokenized version of the document. 198 | """ 199 | unflattened_docs = load_documents(data_dir, docids) 200 | flattened_docs = dict() 201 | for doc, unflattened in unflattened_docs.items(): 202 | flattened_docs[doc] = list(chain.from_iterable(unflattened)) 203 | return flattened_docs 204 | 205 | 206 | def intern_documents(documents: Dict[str, List[List[str]]], word_interner: Dict[str, int], unk_token: str): 207 | """ 208 | Replaces every word with its index in an embeddings file. 209 | 210 | If a word is not found, uses the unk_token instead 211 | """ 212 | ret = dict() 213 | unk = word_interner[unk_token] 214 | for docid, sentences in documents.items(): 215 | ret[docid] = [[word_interner.get(w, unk) for w in s] for s in sentences] 216 | return ret 217 | 218 | 219 | def intern_annotations(annotations: List[Annotation], word_interner: Dict[str, int], unk_token: str): 220 | ret = [] 221 | for ann in annotations: 222 | ev_groups = [] 223 | for ev_group in ann.evidences: 224 | evs = [] 225 | for ev in ev_group: 226 | evs.append(Evidence( 227 | text=tuple([word_interner.get(t, word_interner[unk_token]) for t in ev.text.split()]), 228 | docid=ev.docid, 229 | start_token=ev.start_token, 230 | end_token=ev.end_token, 231 | start_sentence=ev.start_sentence, 232 | end_sentence=ev.end_sentence)) 233 | ev_groups.append(tuple(evs)) 234 | ret.append(Annotation(annotation_id=ann.annotation_id, 235 | query=tuple([word_interner.get(t, word_interner[unk_token]) for t in ann.query.split()]), 236 | evidences=frozenset(ev_groups), 237 | classification=ann.classification, 238 | query_type=ann.query_type)) 239 | return ret 240 | 241 | 242 | def load_documents_from_file(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]: 243 | """Loads a subset of available documents from 'docs.jsonl' file on disk. 
244 | 245 | Each document is assumed to be serialized as newline ('\n') separated sentences. 246 | Each sentence is assumed to be space (' ') joined tokens. 247 | """ 248 | docs_file = os.path.join(data_dir, 'docs.jsonl') 249 | documents = load_jsonl(docs_file) 250 | documents = {doc['docid']: doc['document'] for doc in documents} 251 | res = dict() 252 | if docids is None: 253 | docids = sorted(list(documents.keys())) 254 | else: 255 | docids = sorted(set(str(d) for d in docids)) 256 | for d in docids: 257 | lines = documents[d].split('\n') 258 | tokenized = [line.strip().split(' ') for line in lines] 259 | res[d] = tokenized 260 | return res 261 | 262 | 263 | NEG = 0 264 | POS = 1 265 | 266 | pattern = re.compile('') 267 | 268 | 269 | def cache_decorator(*dump_fnames): 270 | def excution_decorator(func): 271 | def wrapper(*args, **kwargs): 272 | if len(dump_fnames) == 1: 273 | dump_fname = dump_fnames[0] 274 | if not os.path.isfile(dump_fname): 275 | ret = func(*args, **kwargs) 276 | with open(dump_fname, 'wb') as fdump: 277 | pickle.dump(ret, fdump) 278 | return ret 279 | 280 | with open(dump_fname, 'rb') as fdump: 281 | ret = pickle.load(fdump) 282 | return ret 283 | 284 | rets = None 285 | for fname in dump_fnames: 286 | if not os.path.isfile(fname): 287 | rets = func(*args, **kwargs) 288 | break 289 | if rets is not None: 290 | for r, fname in zip(rets, dump_fnames): 291 | with open(fname, 'wb') as fdump: 292 | pickle.dump(r, fdump) 293 | return rets 294 | 295 | rets = [] 296 | for fname in dump_fnames: 297 | with open(fname, 'rb') as fdump: 298 | rets.append(pickle.load(fdump)) 299 | return tuple(rets) 300 | 301 | return wrapper 302 | 303 | return excution_decorator 304 | 305 | 306 | def convert_subtoken_ids_to_tokens(ids:List[int], 307 | tokenizer:transformers.BertTokenizer, 308 | token_mapping=None, 309 | exps=None, 310 | raw_sentence=None): 311 | subtokens = tokenizer.convert_ids_to_tokens(ids) 312 | tokens, exps_outputs = [], [] 313 | if not isinstance(exps[0], list): 314 | exps = [exps] 315 | exps_inputs = [[0] * len(ids)] if exps is None else exps 316 | raw_sentence = subtokens if raw_sentence is None else raw_sentence 317 | subtokens = list(reversed([t[2:] if t.startswith('##') else t for t in subtokens])) 318 | if token_mapping is None: 319 | exps_inputs = list(zip(*(list(reversed(e)) for e in exps_inputs))) 320 | for ref_token in raw_sentence: 321 | t, es = '', [0] * len(exps_inputs[0]) 322 | while t != ref_token and len(subtokens) > 0: 323 | t += subtokens.pop() 324 | es = [max(old, new) for old, new in zip(es, exps_inputs.pop())] 325 | tokens.append(t) 326 | exps_outputs.append(es) 327 | if len(subtokens) == 0: 328 | # the last sub-token is incomplete, ditch it directly 329 | if ref_token != tokens[-1]: 330 | tokens = tokens[:-1] 331 | exps_outputs = exps_outputs[:-1] 332 | break 333 | else: 334 | hard_rats, soft_rats = exps 335 | for ref_token_idx, (token_piece_start, token_piece_end) in enumerate(token_mapping): 336 | if token_piece_start >= len(hard_rats): 337 | break 338 | tokens.append(raw_sentence[ref_token_idx]) 339 | max_hard_rat = max(hard_rats[token_piece_start: token_piece_end]) 340 | max_soft_rat = max(soft_rats[token_piece_start: token_piece_end]) 341 | exps_outputs.append((max_hard_rat, max_soft_rat)) 342 | if exps is None: 343 | return tokens 344 | return tokens, exps_outputs 345 | -------------------------------------------------------------------------------- /expred/train.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import asdict 2 | 3 | import sys 4 | 5 | import argparse 6 | import logging 7 | import wandb as wandb 8 | from typing import List, Dict, Set, Tuple 9 | 10 | import torch 11 | import os 12 | import json 13 | import numpy as np 14 | import random 15 | 16 | from itertools import chain 17 | 18 | from expred import metrics 19 | from expred.params import MTLParams 20 | from expred.models.mlp_mtl import BertMTL, BertClassifier 21 | from expred.tokenizer import BertTokenizerWithMapping 22 | from expred.models.pipeline.mtl_pipeline_utils import decode 23 | from expred.utils import load_datasets, load_documents, write_jsonl, Annotation 24 | from expred.models.pipeline.mtl_token_identifier import train_mtl_token_identifier 25 | from expred.models.pipeline.mtl_evidence_classifier import train_mtl_evidence_classifier 26 | from expred.eraser_utils import get_docids 27 | 28 | BATCH_FIRST = True 29 | 30 | 31 | def initialize_models(conf: dict, 32 | tokenizer: BertTokenizerWithMapping, 33 | batch_first: bool) -> Tuple[BertMTL, BertClassifier, Dict[int, str]]: 34 | """ 35 | Does several things: 36 | 1. Create a mapping from label names to ids 37 | 2. Configure and create the multi task learner, the first stage of the model (BertMTL) 38 | 3. Configure and create the evidence classifier, second stage of the model (BertClassifier) 39 | :param conf: 40 | :param tokenizer: 41 | :param batch_first: 42 | :return: BertMTL, BertClassifier, label mapping 43 | """ 44 | assert batch_first 45 | max_length = conf['max_length'] 46 | # label mapping 47 | labels = dict((y, x) for (x, y) in enumerate(conf['classes'])) 48 | 49 | # configure multi task learner 50 | mtl_params = MTLParams 51 | mtl_params.num_labels = len(labels) 52 | mtl_params.dim_exp_gru = conf['dim_exp_gru'] 53 | mtl_params.dim_cls_linear = conf['dim_cls_linear'] 54 | bert_dir = conf['bert_dir'] 55 | use_half_precision = bool(conf['mtl_token_identifier'].get('use_half_precision', 1)) 56 | evidence_identifier = BertMTL(bert_dir=bert_dir, 57 | tokenizer=tokenizer, 58 | mtl_params=mtl_params, 59 | max_length=max_length, 60 | use_half_precision=use_half_precision) 61 | 62 | # set up the evidence classifier 63 | use_half_precision = bool(conf['evidence_classifier'].get('use_half_precision', 1)) 64 | evidence_classifier = BertClassifier(bert_dir=bert_dir, 65 | pad_token_id=tokenizer.pad_token_id, 66 | cls_token_id=tokenizer.cls_token_id, 67 | sep_token_id=tokenizer.sep_token_id, 68 | num_labels=mtl_params.num_labels, 69 | max_length=max_length, 70 | mtl_params=mtl_params, 71 | use_half_precision=use_half_precision) 72 | 73 | return evidence_identifier, evidence_classifier, labels 74 | 75 | 76 | logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s') 77 | logger = logging.getLogger(__name__) 78 | 79 | # let's make this more or less deterministic (not resistent to restarts) 80 | random.seed(12345) 81 | np.random.seed(67890) 82 | torch.manual_seed(10111213) 83 | torch.backends.cudnn.deterministic = True 84 | torch.backends.cudnn.benchmark = False 85 | 86 | 87 | # or, uncomment the following sentences to make it more than random 88 | # rand_seed_1 = ord(os.urandom(1)) * ord(os.urandom(1)) 89 | # rand_seed_2 = ord(os.urandom(1)) * ord(os.urandom(1)) 90 | # rand_seed_3 = ord(os.urandom(1)) * ord(os.urandom(1)) 91 | # random.seed(rand_seed_1) 92 | # np.random.seed(rand_seed_2) 93 | # torch.manual_seed(rand_seed_3) 94 | # 
torch.backends.cudnn.deterministic = False 95 | # torch.backends.cudnn.benchmark = True 96 | 97 | 98 | def main(args : List[str]): 99 | # setup the Argument Parser 100 | parser = argparse.ArgumentParser(description=('Trains a pipeline model.\n' 101 | '\n' 102 | 'Step 1 is evidence identification, the MTL happens here. It ' 103 | 'predicts the label of the current sentence and tags its\n ' 104 | ' sub-tokens in the same time \n' 105 | ' Step 2 is evidence classification, a BERT classifier takes the output of the evidence identifier and predicts its \n' 106 | ' sentiment. Unlike in Deyong et al. this classifier takes in the same length as the identifier\'s input but with \n' 107 | ' irrational sub-tokens masked.\n' 108 | '\n' 109 | ' These models should be separated into two separate steps, but at the moment:\n' 110 | ' * prep data (load, intern documents, load json)\n' 111 | ' * convert data for evidence identification - in the case of training data we take all the positives and sample some negatives\n' 112 | ' * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a broader sampling of negative values.\n' 113 | ' * train evidence identification\n' 114 | ' * convert data for evidence classification - take all rationales + decisions and use this as input\n' 115 | ' * train evidence classification\n' 116 | ' * decode first the evidence, then run classification for each split\n' 117 | '\n' 118 | ' '), formatter_class=argparse.RawTextHelpFormatter) 119 | parser.add_argument('--data_dir', dest='data_dir', required=True, 120 | help='Which directory contains a {train,val,test}.jsonl file?') 121 | parser.add_argument('--output_dir', dest='output_dir', required=True, 122 | help='Where shall we write intermediate models + final data to?') 123 | parser.add_argument('--conf', dest='conf', required=True, 124 | help='JSoN file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.') 125 | parser.add_argument('--batch_size', type=int, required=False, default=None, 126 | help='Overrides the batch_size given in the config file. 
Helpful for debugging') 127 | args = parser.parse_args(args) 128 | 129 | wandb.init(entity="explainable-nlp", project="expred") 130 | # Configure 131 | os.makedirs(args.output_dir, exist_ok=True) 132 | 133 | # loads the config 134 | with open(args.conf, 'r') as fp: 135 | logger.info(f'Loading configuration from {args.conf}') 136 | conf = json.load(fp) 137 | if args.batch_size is not None: 138 | logger.info( 139 | 'Overwriting batch_sizes' 140 | f'(mtl_token_identifier:{conf["mtl_token_identifier"]["batch_size"]}' 141 | f'evidence_classifier:{conf["evidence_classifier"]["batch_size"]})' 142 | f'provided in config by command line argument({args.batch_size})' 143 | ) 144 | conf['mtl_token_identifier']['batch_size'] = args.batch_size 145 | conf['evidence_classifier']['batch_size'] = args.batch_size 146 | logger.info(f'Configuration: {json.dumps(conf, indent=2, sort_keys=True)}') 147 | 148 | # todo add seeds 149 | wandb.config.update(conf) 150 | wandb.config.update(args) 151 | 152 | # load the annotation data 153 | train, val, test = load_datasets(args.data_dir) 154 | 155 | # get's all docids needed that are contained in the loaded splits 156 | docids: Set[str] = set(chain.from_iterable(map(lambda ann: get_docids(ann), 157 | chain(train, val, test)))) 158 | 159 | documents: Dict[str, List[List[str]]] = load_documents(args.data_dir, docids) 160 | logger.info(f'Load {len(documents)} documents') 161 | # this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important 162 | 163 | tokenizer = BertTokenizerWithMapping.from_pretrained(conf['bert_vocab']) 164 | mtl_token_identifier, evidence_classifier, labels_mapping = \ 165 | initialize_models(conf, tokenizer, batch_first=BATCH_FIRST) 166 | # logger.info(f'We have {len(word_interner)} wordpieces') 167 | 168 | # tokenizes and caches tokenized_docs, same for annotations 169 | # todo typo here? slides = slices (words?) 
170 | tokenized_docs, tokenized_doc_token_slices = tokenizer.encode_docs(documents, args.output_dir) 171 | indexed_train, indexed_val, indexed_test = [tokenizer.encode_annotations(data) for data in [train, val, test]] 172 | 173 | logger.info('Beginning training of the MTL identifier') 174 | mtl_token_identifier = mtl_token_identifier.cuda() 175 | mtl_token_identifier, mtl_token_identifier_results, \ 176 | train_machine_annotated, eval_machine_annotated, test_machine_annotated = \ 177 | train_mtl_token_identifier(mtl_token_identifier, 178 | args.output_dir, 179 | indexed_train, 180 | indexed_val, 181 | indexed_test, 182 | labels_mapping=labels_mapping, 183 | interned_documents=tokenized_docs, 184 | source_documents=documents, 185 | token_mapping=tokenized_doc_token_slices, 186 | model_pars=conf, 187 | tensorize_model_inputs=True) 188 | mtl_token_identifier = mtl_token_identifier.cpu() 189 | # evidence identifier ends 190 | 191 | logger.info('Beginning training of the evidence classifier') 192 | evidence_classifier = evidence_classifier.cuda() 193 | optimizer = None 194 | scheduler = None 195 | 196 | # trains the classifier on the masked (based on rationales) documents 197 | evidence_classifier, evidence_class_results = train_mtl_evidence_classifier(evidence_classifier, 198 | args.output_dir, 199 | train_machine_annotated, 200 | eval_machine_annotated, 201 | tokenized_docs, 202 | conf, 203 | optimizer=optimizer, 204 | scheduler=scheduler, 205 | class_interner=labels_mapping, 206 | tensorize_model_inputs=True) 207 | # evidence classifier ends 208 | 209 | logger.info('Beginning final decoding') 210 | mtl_token_identifier = mtl_token_identifier.cuda() 211 | pipeline_batch_size = min( 212 | [conf['evidence_classifier']['batch_size'], conf['mtl_token_identifier']['batch_size']]) 213 | pipeline_results, train_decoded, val_decoded, test_decoded = decode(evidence_identifier=mtl_token_identifier, 214 | evidence_classifier=evidence_classifier, 215 | train=indexed_train, 216 | mrs_train=train_machine_annotated, 217 | val=indexed_val, 218 | mrs_eval=eval_machine_annotated, 219 | test=indexed_test, 220 | mrs_test=test_machine_annotated, 221 | source_documents=documents, 222 | interned_documents=tokenized_docs, 223 | token_mapping=tokenized_doc_token_slices, 224 | class_interner=labels_mapping, 225 | tensorize_modelinputs=True, 226 | batch_size=pipeline_batch_size, 227 | tokenizer=tokenizer) 228 | write_jsonl(train_decoded, os.path.join(args.output_dir, 'train_decoded.jsonl')) 229 | write_jsonl(val_decoded, os.path.join(args.output_dir, 'val_decoded.jsonl')) 230 | write_jsonl(test_decoded, os.path.join(args.output_dir, 'test_decoded.jsonl')) 231 | with open(os.path.join(args.output_dir, 'identifier_results.json'), 'w') as ident_output, \ 232 | open(os.path.join(args.output_dir, 'classifier_results.json'), 'w') as class_output: 233 | ident_output.write(json.dumps(mtl_token_identifier_results)) 234 | class_output.write(json.dumps(evidence_class_results)) 235 | for k, v in pipeline_results.items(): 236 | if type(v) is dict: 237 | for k1, v1 in v.items(): 238 | logging.info(f'Pipeline results for {k}, {k1}={v1}') 239 | else: 240 | logging.info(f'Pipeline results {k}\t={v}') 241 | # decode ends 242 | 243 | scores = metrics.main( 244 | [ 245 | '--data_dir', args.data_dir, 246 | '--split', 'test', 247 | '--results', os.path.join(args.output_dir, 'test_decoded.jsonl'), 248 | '--score_file', os.path.join(args.output_dir, 'test_scores.jsonl') 249 | ] 250 | ) 251 | 252 | wandb.log(scores) 253 | 254 | 
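# main() can also be invoked programmatically with an argv-style list instead of
# the command line; a minimal sketch (all paths are placeholders):
#
#   from expred.train import main
#   main(['--data_dir', '/path/to/eraser/movies',
#         '--output_dir', 'output/',
#         '--conf', 'params/movies_expred.json',
#         '--batch_size', '16'])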
wandb.save(os.path.join(args.output_dir, '*.jsonl')) 255 | 256 | 257 | 258 | if __name__ == '__main__': 259 | main(sys.argv[1:]) 260 | -------------------------------------------------------------------------------- /expred/models/pipeline/mtl_token_identifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import logging 4 | import numpy as np 5 | import random 6 | import wandb 7 | 8 | from typing import List, Dict, Tuple, Callable, Union, Any 9 | from collections import OrderedDict, namedtuple, defaultdict 10 | from sklearn.metrics import accuracy_score, classification_report 11 | from itertools import chain 12 | from torch import nn 13 | from torch.nn import Module 14 | 15 | from expred.eraser_utils import chain_sentence_evidences 16 | from expred.models.model_utils import PaddedSequence 17 | from expred.utils import Annotation 18 | from expred.models.pipeline.pipeline_utils import ( 19 | SentenceEvidence, score_token_rationales) 20 | from expred.models.pipeline.mtl_pipeline_utils import ( 21 | annotations_to_mtl_token_identification, 22 | make_mtl_token_preds_epoch 23 | ) 24 | from expred.models.losses import resampling_rebalanced_crossentropy 25 | 26 | AnnotatedDocument = namedtuple('AnnotatedDocument', 'kls evd ann_id query docid index sentence') 27 | 28 | 29 | def _get_sampling_method(params): 30 | if params['sampling_method'] == 'whole_document': 31 | def whole_document_sampler(sentences, _): 32 | return chain_sentence_evidences(sentences) 33 | 34 | return whole_document_sampler 35 | else: 36 | raise NotImplementedError 37 | 38 | 39 | def _prep_data_for_epoch(evidence_data: Tuple[str, Dict[str, Dict[str, List[SentenceEvidence]]]], 40 | sampler: Callable[ 41 | [List[SentenceEvidence], Dict[str, List[SentenceEvidence]]], List[SentenceEvidence]] 42 | ) -> List[SentenceEvidence]: 43 | """ 44 | Shuffle the annotations and sample from documents (can also be the whole document depending on the sampler) 45 | :param evidence_data: 46 | :param sampler: 47 | :return: 48 | """ 49 | output_annotations = [] 50 | ann_ids = sorted(evidence_data.keys()) 51 | # in place shuffle so we get a different per-epoch ordering 52 | random.shuffle(ann_ids) 53 | for ann_id in ann_ids: 54 | for docid, sentences in evidence_data[ann_id][1].items(): 55 | data = sampler(sentences, None) 56 | output_annotations.append((evidence_data[ann_id][0], data)) 57 | return output_annotations 58 | 59 | 60 | def train_mtl_token_identifier(mtl_token_identifier: nn.Module, 61 | save_dir: str, 62 | train: List[Annotation], 63 | val: List[Annotation], 64 | test: List[Annotation], 65 | interned_documents: Dict[str, List[List[int]]], 66 | source_documents: Dict[str, List[List[str]]], 67 | token_mapping: Dict[str, List[List[Tuple[int, int]]]], 68 | model_pars: dict, 69 | labels_mapping: Dict[str, int], 70 | optimizer=None, 71 | scheduler=None, 72 | tensorize_model_inputs: bool = True) -> Tuple[ 73 | Module, Union[Dict[str, list], Any], List[Tuple[Union[SentenceEvidence, List[SentenceEvidence]], Any, Any]], List[ 74 | Tuple[Union[SentenceEvidence, List[SentenceEvidence]], Any, Any]], List[ 75 | Tuple[Union[SentenceEvidence, List[SentenceEvidence]], Any, Any]]]: 76 | """Trains a module for token-level rationale identification. 77 | This method tracks loss on the entire validation set, saves intermediate 78 | models, and supports restoring from an unfinished state. 
The best model on 79 | the validation set is maintained, and the model stops training if a patience 80 | (see below) number of epochs with no improvement is exceeded. 81 | As there are likely too many negative examples to reasonably train a 82 | classifier on everything, every epoch we subsample the negatives. 83 | 84 | :param mtl_token_identifier: token-wise evidence identifier using Multi-Task Learning 85 | :param save_dir: a place to save intermediate and final results and models. 86 | :param train: a List of interned Annotation objects. 87 | :param val: a List of interned Annotation objects. 88 | :param test: a List of interned Annotation objects. 89 | :param interned_documents: a Dict of interned sentences 90 | :param source_documents: a Dict of original sentences, used for sub-token alignment 91 | :param token_mapping: a mapping from original token to sub tokens 92 | :param model_pars: model parameters 93 | :param labels_mapping: a mapping from str labels to their int ids 94 | :param optimizer: what pytorch optimizer to use, if none, initialize Adam 95 | :param scheduler: optional, do we want a scheduler involved in learning? 96 | :param tensorize_model_inputs: should we convert our data to tensors before passing it to the model? 97 | Useful if we have a model that performs its own tokenization (e.g. BERT as a Service) 98 | :returns 99 | the trained MTL evidence token identifier 100 | the intermediate results 101 | machine-annotated train/eval/test datasets 102 | """ 103 | 104 | 105 | 106 | # set up output folder structure 107 | logging.info(f'Beginning training with {len(train)} annotations, {len(val)} for validation') 108 | evidence_identifier_output_dir = os.path.join(save_dir, 'evidence_token_identifier') 109 | os.makedirs(save_dir, exist_ok=True) 110 | os.makedirs(evidence_identifier_output_dir, exist_ok=True) 111 | 112 | model_save_file = os.path.join(evidence_identifier_output_dir, 'evidence_token_identifier.pt') 113 | epoch_save_file = os.path.join(evidence_identifier_output_dir, 'evidence_token_identifier_epoch_data.pt') 114 | 115 | # set up training (optimizer, loss (both), sampling_method, ...) 
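# The optimizer and scheduler arguments are optional: when both are None, a plain
# Adam optimizer is created below and no scheduler is used. A caller could pass its
# own pair instead, e.g. (a sketch only; AdamW plus the transformers warm-up
# schedule is an assumption, not what train.py currently does):
#
#   from transformers import get_linear_schedule_with_warmup
#   optimizer = torch.optim.AdamW(mtl_token_identifier.parameters(), lr=1e-5)
#   scheduler = get_linear_schedule_with_warmup(optimizer,
#                                               num_warmup_steps=0,
#                                               num_training_steps=10000)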
116 | if optimizer is None: 117 | optimizer = torch.optim.Adam(mtl_token_identifier.parameters(), lr=model_pars['mtl_token_identifier']['lr']) 118 | cls_criterion = nn.BCELoss(reduction='none') 119 | 120 | exp_criterion = resampling_rebalanced_crossentropy(seq_reduction='none') # nn.CrossEntropyLoss(reduction='none') 121 | sampling_method = _get_sampling_method(model_pars['mtl_token_identifier']) 122 | batch_size = model_pars['mtl_token_identifier']['batch_size'] 123 | max_length = model_pars['max_length'] 124 | epochs = model_pars['mtl_token_identifier']['epochs'] 125 | 126 | patience = model_pars['mtl_token_identifier']['patience'] 127 | max_grad_norm = model_pars['mtl_token_identifier'].get('max_grad_norm', None) 128 | par_lambda = model_pars['mtl_token_identifier']['par_lambda'] 129 | # annotation id -> docid -> [SentenceEvidence]) 130 | # calculates the classification of the sequence tokens and some other stuff 131 | evidence_train_data = annotations_to_mtl_token_identification(train, 132 | source_documents=source_documents, 133 | interned_documents=interned_documents, 134 | token_mapping=token_mapping) 135 | evidence_val_data = annotations_to_mtl_token_identification(val, 136 | source_documents=source_documents, 137 | interned_documents=interned_documents, 138 | token_mapping=token_mapping) 139 | 140 | evidence_test_data = annotations_to_mtl_token_identification(test, 141 | source_documents=source_documents, 142 | interned_documents=interned_documents, 143 | token_mapping=token_mapping) 144 | 145 | device = next(mtl_token_identifier.parameters()).device 146 | 147 | results = { 148 | 'sampled_epoch_train_losses': [], 149 | 'epoch_val_total_losses': [], 150 | 'epoch_val_cls_losses': [], 151 | 'epoch_val_exp_losses': [], 152 | 'epoch_val_exp_acc': [], 153 | 'epoch_val_exp_f': [], 154 | 'epoch_val_cls_acc': [], 155 | 'epoch_val_cls_f': [], 156 | 'full_epoch_val_rationale_scores': [] 157 | } 158 | 159 | # allow restoring an existing training run 160 | start_epoch = 0 161 | best_epoch = -1 162 | best_val_total_loss = float('inf') 163 | best_model_state_dict = None 164 | epoch_data = {} 165 | if os.path.exists(epoch_save_file): 166 | mtl_token_identifier.load_state_dict(torch.load(model_save_file)) 167 | epoch_data = torch.load(epoch_save_file) 168 | start_epoch = epoch_data['epoch'] + 1 169 | # handle finishing because patience was exceeded or we didn't get the best final epoch 170 | if bool(epoch_data.get('done', 0)): 171 | start_epoch = epochs 172 | results = epoch_data['results'] 173 | best_epoch = start_epoch 174 | best_model_state_dict = OrderedDict({k: v.cpu() for k, v in mtl_token_identifier.state_dict().items()}) 175 | logging.info(f'Training evidence identifier from epoch {start_epoch} until epoch {epochs}') 176 | optimizer.zero_grad() 177 | for epoch in range(start_epoch, epochs): 178 | epoch_train_data = _prep_data_for_epoch(evidence_train_data, sampling_method) 179 | epoch_val_data = _prep_data_for_epoch(evidence_val_data, sampling_method) 180 | sampled_epoch_train_loss = 0 181 | losses = defaultdict(lambda: []) 182 | mtl_token_identifier.train() 183 | logging.info( 184 | f'Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples') 185 | # one epoch of training 186 | for batch_start in range(0, len(epoch_train_data), batch_size): 187 | batch_elements = epoch_train_data[batch_start:min(batch_start + batch_size, len(epoch_train_data))] 188 | # we sample every time to thereoretically get a better representation of instances over the corpus. 
189 | # this might just take more time than doing so in advance. 190 | labels, targets, queries, sentences, has_evidence = zip( 191 | *[(s[0], s[1].kls, s[1].query, s[1].sentence, s[1].has_evidence) for s in batch_elements]) 192 | 193 | # one hot encoding for classification 194 | labels = [[i == labels_mapping[label] for i in range(len(labels_mapping))] for label in labels] 195 | labels = torch.tensor(labels, dtype=torch.float, device=device) 196 | 197 | ids = [(s[1].ann_id, s[1].docid, s[1].index) for s in batch_elements] 198 | 199 | # truncation 200 | cropped_targets = [[0] * (len(query) + 2) # length of query and overheads such as [cls] and [sep] 201 | + list(target[:(max_length - len(query) - 2)]) for query, target in 202 | zip(queries, targets)] 203 | cropped_targets = PaddedSequence.autopad( 204 | [torch.tensor(t, dtype=torch.float, device=device) for t in cropped_targets], 205 | batch_first=True, device=device) 206 | targets = [[0] * (len(query) + 2) # length of query and overheads such as [cls] and [sep] 207 | + list(target) for query, target in zip(queries, targets)] 208 | targets = PaddedSequence.autopad([torch.tensor(t, dtype=torch.float, device='cpu') for t in targets], 209 | batch_first=True, device='cpu') 210 | if tensorize_model_inputs: 211 | if all(q is None for q in queries): 212 | queries = [torch.tensor([], dtype=torch.long) for _ in queries] 213 | else: 214 | assert all(q is not None for q in queries) 215 | queries = [torch.tensor(q, dtype=torch.long) for q in queries] 216 | sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] 217 | 218 | # prediction 219 | preds = mtl_token_identifier(queries, ids, sentences) 220 | cls_preds, exp_preds, attention_masks = preds 221 | cls_loss = cls_criterion(cls_preds, labels).mean(dim=-1).sum() 222 | 223 | exp_loss_per_instance = exp_criterion(exp_preds, cropped_targets.data.squeeze()).mean(dim=-1) 224 | has_evidence_mask = torch.tensor(has_evidence, dtype=float, device=exp_loss_per_instance.device) 225 | 226 | exp_loss = (exp_loss_per_instance * has_evidence_mask).sum() 227 | loss = cls_loss + par_lambda * exp_loss 228 | 229 | losses['cls_loss'].append(cls_loss.item()) 230 | losses['exp_loss'].append(exp_loss.item()) 231 | losses['loss'].append(loss.item()) 232 | 233 | sampled_epoch_train_loss += loss.item() 234 | loss.backward() 235 | if max_grad_norm: 236 | torch.nn.utils.clip_grad_norm_(mtl_token_identifier.parameters(), max_grad_norm) 237 | optimizer.step() 238 | if scheduler: 239 | scheduler.step() 240 | optimizer.zero_grad() 241 | sampled_epoch_train_loss /= len(epoch_train_data) 242 | results['sampled_epoch_train_losses'].append(sampled_epoch_train_loss) 243 | 244 | mean_losses = {f'train_{key}':np.mean(loss) for key, loss in losses.items()} 245 | mean_losses['epoch'] = epoch 246 | wandb.log(mean_losses) 247 | 248 | logging.info(f'Epoch {epoch} training loss {sampled_epoch_train_loss}') 249 | 250 | # validation 251 | with torch.no_grad(): 252 | mtl_token_identifier.eval() 253 | epoch_val_total_loss, epoch_val_cls_loss, epoch_val_exp_loss, \ 254 | epoch_val_soft_pred, epoch_val_hard_pred, epoch_val_token_targets, \ 255 | epoch_val_pred_labels, epoch_val_labels = \ 256 | make_mtl_token_preds_epoch(mtl_token_identifier, 257 | epoch_val_data, 258 | labels_mapping, 259 | token_mapping, 260 | batch_size, 261 | max_length, 262 | par_lambda, 263 | device, 264 | cls_criterion, 265 | exp_criterion, 266 | tensorize_model_inputs) 267 | # epoch_val_soft_pred = list(chain.from_iterable(epoch_val_soft_pred.tolist())) 268 | # 
epoch_val_hard_pred = list(chain.from_iterable(epoch_val_hard_pred)) 269 | # epoch_val_truth = list(chain.from_iterable(epoch_val_truth)) 270 | results['epoch_val_total_losses'].append(epoch_val_total_loss) 271 | results['epoch_val_cls_losses'].append(epoch_val_cls_loss) 272 | results['epoch_val_exp_losses'].append(epoch_val_exp_loss) 273 | epoch_val_hard_pred_chained = list(chain.from_iterable(epoch_val_hard_pred)) 274 | epoch_val_token_targets_chained = list(chain.from_iterable(epoch_val_token_targets)) 275 | results['epoch_val_exp_acc'].append(accuracy_score(epoch_val_token_targets_chained, 276 | epoch_val_hard_pred_chained)) 277 | results['epoch_val_exp_f'].append(classification_report(epoch_val_token_targets_chained, 278 | epoch_val_hard_pred_chained, 279 | labels=[0, 1], # 0 = not evidence, 1 = evidence 280 | output_dict=True)) 281 | flattened_epoch_val_pred_labels = [np.argmax(x) for x in epoch_val_pred_labels] 282 | flattened_epoch_val_labels = [np.argmax(x) for x in epoch_val_labels] 283 | results['epoch_val_cls_acc'].append(accuracy_score(flattened_epoch_val_pred_labels, 284 | flattened_epoch_val_labels)) 285 | # print(flattened_epoch_val_labels) 286 | # print(flattened_epoch_val_pred_labels) 287 | results['epoch_val_cls_f'].append(classification_report(flattened_epoch_val_labels, 288 | flattened_epoch_val_pred_labels, 289 | labels=[v for _, v in labels_mapping.items()], 290 | output_dict=True)) 291 | results['full_epoch_val_rationale_scores'].append( 292 | score_token_rationales(val, source_documents, 293 | epoch_val_data, 294 | token_mapping, 295 | epoch_val_hard_pred, 296 | epoch_val_soft_pred)) 297 | 298 | validation_metrics = {metric:values[-1] for metric, values in results.items()} 299 | validation_metrics['epoch'] = epoch 300 | # epoch_val_soft_pred_for_scoring = [[[1 - z, z] for z in y] for y in epoch_val_soft_pred] 301 | # logging.info( 302 | # f'Epoch {epoch} full val loss {epoch_val_total_loss}, accuracy: {results["epoch_val_acc"][-1]}, f: {results["epoch_val_f"][-1]}, rationale scores: look, it\'s already a pain to duplicate this code. What do you want from me.') 303 | 304 | # if epoch_val_loss < best_val_loss: 305 | if epoch_val_total_loss < best_val_total_loss: 306 | logging.debug(f'Epoch {epoch} new best model with val loss {epoch_val_total_loss}') 307 | best_model_state_dict = OrderedDict({k: v.cpu() for k, v in mtl_token_identifier.state_dict().items()}) 308 | best_epoch = epoch 309 | best_val_total_loss = epoch_val_total_loss 310 | torch.save(mtl_token_identifier.state_dict(), model_save_file) 311 | epoch_data = { 312 | 'epoch': epoch, 313 | 'results': results, 314 | 'best_val_loss': best_val_total_loss, 315 | 'done': 0 316 | } 317 | torch.save(epoch_data, epoch_save_file) 318 | if epoch - best_epoch > patience: 319 | epoch_data['done'] = 1 320 | torch.save(epoch_data, epoch_save_file) 321 | break 322 | 323 | epoch_data['done'] = 1 324 | epoch_data['results'] = results 325 | torch.save(epoch_data, epoch_save_file) 326 | mtl_token_identifier.load_state_dict(best_model_state_dict) 327 | mtl_token_identifier = mtl_token_identifier.to(device=device) 328 | mtl_token_identifier.eval() 329 | 330 | def prepare_for_cl(input_data, keep_corrected_only=False): 331 | """ 332 | Extract rationale prediction from document only. 
333 | If keep_corrected_only=True keep only the annotation where the classification was correct 334 | :param input_data: 335 | :param keep_corrected_only: 336 | :return: 337 | """ 338 | epoch_input_data = _prep_data_for_epoch(input_data, sampling_method) 339 | _, _, _, soft_pred_for_cl, hard_pred_for_cl, _, \ 340 | pred_labels_for_cl, labels_for_cl = \ 341 | make_mtl_token_preds_epoch(mtl_token_identifier, epoch_input_data, labels_mapping, 342 | token_mapping, batch_size, max_length, par_lambda, 343 | device, cls_criterion, exp_criterion, tensorize_model_inputs) 344 | hard_pred_for_cl = [h.cpu().tolist() for h in hard_pred_for_cl] 345 | hard_pred_for_cl = [h[len(d[1].query) + 2:] for h, d in zip(hard_pred_for_cl, epoch_input_data)] 346 | soft_pred_for_cl = [s[len(d[1].query) + 2:] for s, d in zip(soft_pred_for_cl, epoch_input_data)] 347 | train_ids = list(range(len(labels_for_cl))) 348 | if keep_corrected_only: 349 | labels_for_cl = [np.argmax(x) for x in labels_for_cl] 350 | pred_labels_for_cl = [np.argmax(x) for x in pred_labels_for_cl] 351 | train_ids = list(filter(lambda i: labels_for_cl[i] == pred_labels_for_cl[i], 352 | range(len(labels_for_cl)))) 353 | return [(epoch_input_data[i], soft_pred_for_cl[i], hard_pred_for_cl[i]) for i in train_ids] 354 | 355 | train_machine_annotated = prepare_for_cl(evidence_train_data, True) 356 | eval_machine_annotated = prepare_for_cl(evidence_val_data, False) 357 | test_machine_annotated = prepare_for_cl(evidence_test_data, False) 358 | return mtl_token_identifier, results, \ 359 | train_machine_annotated, eval_machine_annotated, test_machine_annotated 360 | -------------------------------------------------------------------------------- /expred/models/pipeline/mtl_pipeline_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import defaultdict, namedtuple 3 | from itertools import chain 4 | from typing import Any, Dict, List, Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | from sklearn.metrics import classification_report, accuracy_score 9 | 10 | from expred.eraser_utils import chain_sentence_evidences, get_docids 11 | from expred.models.model_utils import PaddedSequence 12 | from expred.models.pipeline.pipeline_utils import SentenceEvidence, _grouper, \ 13 | score_rationales 14 | from expred.utils import Annotation 15 | from expred.eraser_benchmark import rational_bits_to_ev_generator 16 | from expred.utils import convert_subtoken_ids_to_tokens 17 | 18 | 19 | def mask_annotations_to_evidence_classification(mrs: List[Tuple[Tuple[str, SentenceEvidence], Any]], # mrs for machine rationales 20 | class_interner: dict) -> List[SentenceEvidence]: 21 | """ 22 | 23 | :param mrs: 24 | :param class_interner: 25 | :return: 26 | """ 27 | ret = [] 28 | for mr, _, hard_prediction in mrs: 29 | kls = class_interner[mr[0]] 30 | evidence = mr[1] 31 | sentence = evidence.sentence 32 | query = evidence.query 33 | try: 34 | assert len(hard_prediction) == len(sentence) 35 | except Exception: 36 | print(mr) 37 | print(len(hard_prediction)) 38 | print(len(query)) 39 | print(len(sentence)) 40 | masked_sentence = [p * d for p, d in zip(hard_prediction, sentence)] 41 | ret.append(SentenceEvidence(kls=kls, 42 | query=query, 43 | ann_id=evidence.ann_id, 44 | docid=evidence.docid, 45 | index=-1, 46 | sentence=masked_sentence, 47 | has_evidence=evidence.has_evidence)) 48 | return ret 49 | 50 | 51 | def annotations_to_evidence_token_identification(annotations: List[Annotation], 52 | 
source_documents: Dict[str, List[List[str]]], 53 | interned_documents: Dict[str, List[List[int]]], 54 | token_mapping: Dict[str, List[List[Tuple[int, int]]]] 55 | ) -> Dict[str, Dict[str, List[SentenceEvidence]]]: 56 | """ 57 | Calculates the start and end positions for 58 | * tokens 59 | * sentences 60 | * rationales 61 | and creates a classification map for the tokens. 62 | 63 | 64 | :param annotations: 65 | :param source_documents: 66 | :param interned_documents: 67 | :param token_mapping: 68 | :return: dict[ann_id]dict[doc_id][SentenceEvidence, ...] 69 | """ 70 | # TODO document 71 | # TODO should we simplify to use only source text? 72 | ret = defaultdict(lambda: defaultdict(list)) # annotation id -> docid -> sentences 73 | positive_tokens = 0 74 | negative_tokens = 0 75 | for ann in annotations: 76 | annid = ann.annotation_id 77 | docids = list(set(get_docids(ann))) 78 | sentence_offsets = defaultdict(list) # docid -> [(start, end)] 79 | classes = defaultdict(list) # docid -> [token is yea or nay] 80 | absolute_word_mapping = defaultdict(list) # docid -> [(absolute wordpiece start, absolute wordpiece end)] 81 | 82 | # chain the sentences and store the start and end wordpiece offsets of each sentence and of each original token 83 | # also prepare classes (list of zeros for each word piece token) 84 | for docid in docids: 85 | start = 0 86 | assert len(source_documents[docid]) == len(interned_documents[docid]) 87 | for sentence_id, (whole_token_sent, wordpiece_sent) in enumerate( 88 | zip(source_documents[docid], interned_documents[docid])): 89 | classes[docid].extend([0 for _ in wordpiece_sent]) 90 | end = start + len(wordpiece_sent) 91 | sentence_offsets[docid].append((start, end)) 92 | absolute_word_mapping[docid].extend([(start + relative_wp_start, 93 | start + relative_wp_end) 94 | for relative_wp_start, 95 | relative_wp_end in token_mapping[docid][sentence_id]]) 96 | start = end 97 | # calculate the start and end tokens for the evidence spans and set classes to 1 respectively 98 | for ev in chain.from_iterable(ann.evidences): 99 | if len(ev.text) == 0: 100 | continue 101 | flat_token_map = list(chain.from_iterable(token_mapping[ev.docid])) 102 | if ev.start_token != -1 and ev.start_sentence != -1: 103 | # start, end = token_mapping[ev.docid][ev.start_token][0], token_mapping[ev.docid][ev.end_token][1] 104 | sentence_offset_start = sentence_offsets[ev.docid][ev.start_sentence][0] 105 | sentence_offset_end = sentence_offsets[ev.docid][ev.end_sentence - 1][0] 106 | start = sentence_offset_start + flat_token_map[ev.start_token][0] 107 | end = sentence_offset_end + flat_token_map[ev.end_token - 1][1] 108 | elif ev.start_token == -1 and ev.start_sentence != -1: 109 | start = sentence_offsets[ev.docid][ev.start_sentence][0] 110 | end = sentence_offsets[ev.docid][ev.end_sentence - 1][1] 111 | elif ev.start_token != -1 and ev.start_sentence == -1: 112 | start = absolute_word_mapping[ev.docid][ev.start_token][0] 113 | end = absolute_word_mapping[ev.docid][ev.end_token][1] 114 | else: 115 | continue 116 | for i in range(start, end): 117 | try: 118 | classes[ev.docid][i] = 1 119 | except IndexError: 120 | print(ev) 121 | print(ev.docid) 122 | print(classes) 123 | print(len(classes[ev.docid])) 124 | print(i) 125 | raise 126 | for docid, offsets in sentence_offsets.items(): 127 | token_assignments = classes[docid] 128 | positive_tokens += sum(token_assignments) 129 | negative_tokens += len(token_assignments) - sum(token_assignments) 130 | for s, (start, end) in enumerate(offsets): 131 | sent = interned_documents[docid][s] 132 | 
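# Note that for this identification stage kls is not a class id but the tuple of
# per-wordpiece 0/1 evidence labels of sentence s (sliced from token_assignments),
# and has_evidence flags whether the annotation carries any gold rationale at all.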
ret[annid][docid].append(SentenceEvidence(kls=tuple(token_assignments[start:end]), 133 | query=ann.query, 134 | ann_id=ann.annotation_id, 135 | docid=docid, 136 | index=s, 137 | sentence=sent, 138 | has_evidence=len(ann.evidences) > 0)) 139 | logging.info(f"Have {positive_tokens} positive wordpiece tokens, {negative_tokens} negative wordpiece tokens") 140 | return ret 141 | 142 | 143 | def annotations_to_mtl_token_identification(annotations: object, 144 | source_documents: object, 145 | interned_documents: object, 146 | token_mapping: object 147 | ) -> object: 148 | """ 149 | See annotations_to_evidence_token_identification for more details 150 | :param annotations: 151 | :param source_documents: 152 | :param interned_documents: 153 | :param token_mapping: 154 | :return: dict[ann_id][classification ,dict[doc_id][SentenceEvidence, ...]] 155 | """ 156 | rets = annotations_to_evidence_token_identification( 157 | annotations, 158 | source_documents, 159 | interned_documents, 160 | token_mapping 161 | ) 162 | # adds the final sequence classification 163 | for ann in annotations: 164 | ann_id = ann.annotation_id 165 | ann_kls = ann.classification 166 | rets[ann_id] = [ann_kls, rets[ann_id]] 167 | return rets 168 | 169 | 170 | # for mtl pipeline 171 | def make_mtl_token_preds_batch(classifier: nn.Module, 172 | batch_elements: List[SentenceEvidence], 173 | labels_mapping: Dict[str, int], 174 | token_mapping: Dict[str, List[List[Tuple[int, int]]]], 175 | max_length: int, 176 | par_lambda: int, 177 | device=None, 178 | cls_criterion: nn.Module = None, 179 | exp_criterion: nn.Module = None, 180 | tensorize_model_inputs: bool = True) -> Tuple[float, float, float, List[float], List[int], List[int]]: 181 | batch_elements = [s for s in batch_elements if s is not None] 182 | labels, targets, queries, sentences = zip(*[(s[0], s[1].kls, s[1].query, s[1].sentence) 183 | for s in batch_elements]) 184 | labels = [[i == labels_mapping[label] for i in range(len(labels_mapping))] for label in labels] 185 | labels = torch.tensor(labels, dtype=torch.float, device=device) 186 | ids = [(s[1].ann_id, s[1].docid, s[1].index) for s in batch_elements] 187 | 188 | cropped_targets = [[0] * (len(query) + 2) # length of query and overheads such as [cls] and [sep] 189 | + list(target[:(max_length - len(query) - 2)]) for query, target in zip(queries, targets)] 190 | cropped_targets = PaddedSequence.autopad( 191 | [torch.tensor(t, dtype=torch.float, device=device) for t in cropped_targets], 192 | batch_first=True, device=device) 193 | 194 | targets = [[0] * (len(query) + 2) # length of query and overheads such as [cls] and [sep] 195 | + list(target) for query, target in zip(queries, targets)] 196 | targets = PaddedSequence.autopad([torch.tensor(t, dtype=torch.float, device=device) for t in targets], 197 | batch_first=True, device=device) 198 | 199 | if tensorize_model_inputs: 200 | if all(q is None for q in queries): 201 | queries = [torch.tensor([], dtype=torch.long) for _ in queries] 202 | else: 203 | assert all(q is not None for q in queries) 204 | queries = [torch.tensor(q, dtype=torch.long) for q in queries] 205 | sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] 206 | preds = classifier(queries, ids, sentences) 207 | cls_preds, exp_preds, attention_masks = preds 208 | cls_loss = cls_criterion(cls_preds, labels).mean(dim=-1).sum() 209 | cls_preds = [x.cpu().tolist() for x in cls_preds] 210 | labels = [x.cpu().tolist() for x in labels] 211 | 212 | exp_loss = exp_criterion(exp_preds, 
cropped_targets.data.squeeze()).mean(dim=-1).sum() 213 | #print(exp_loss.shape, cls_loss.shape) 214 | exp_preds = [x.cpu() for x in exp_preds] 215 | hard_preds = [torch.round(x).to(dtype=torch.int).cpu() for x in targets.unpad(exp_preds)] 216 | exp_preds = [x.tolist() for x in targets.unpad(exp_preds)] 217 | token_targets = [[y.item() for y in x] for x in targets.unpad(targets.data.cpu())] 218 | total_loss = cls_loss + par_lambda * exp_loss 219 | 220 | return total_loss, cls_loss, exp_loss, \ 221 | exp_preds, hard_preds, token_targets, \ 222 | cls_preds, labels 223 | 224 | 225 | # for mtl pipeline 226 | def make_mtl_token_preds_epoch(classifier: nn.Module, 227 | data: List[SentenceEvidence], 228 | labels_mapping: Dict[str, int], 229 | token_mapping: Dict[str, List[List[Tuple[int, int]]]], 230 | batch_size: int, 231 | max_length: int, 232 | par_lambda: int, 233 | device=None, 234 | cls_criterion: nn.Module = None, 235 | exp_criterion: nn.Module = None, 236 | tensorize_model_inputs: bool = True): 237 | epoch_total_loss = 0 238 | epoch_cls_loss = 0 239 | epoch_exp_loss = 0 240 | epoch_soft_pred = [] 241 | epoch_hard_pred = [] 242 | epoch_token_targets = [] 243 | epoch_pred_labels = [] 244 | epoch_labels = [] 245 | batches = _grouper(data, batch_size) 246 | classifier.eval() 247 | #for p in classifier.parameters(): 248 | # print(str(p.device)) 249 | # print('cuda:0') 250 | # print(str(p.device) == 'cuda:0') 251 | # if str(p.device) != 'cuda:0': 252 | # print(p) 253 | # assert False 254 | for batch in batches: 255 | total_loss, cls_loss, exp_loss, \ 256 | soft_preds, hard_preds, token_targets, \ 257 | pred_labels, labels = make_mtl_token_preds_batch(classifier, 258 | batch, 259 | labels_mapping, 260 | token_mapping, 261 | max_length, 262 | par_lambda, 263 | device, 264 | cls_criterion=cls_criterion, 265 | exp_criterion=exp_criterion, 266 | tensorize_model_inputs=tensorize_model_inputs) 267 | if total_loss is not None: 268 | epoch_total_loss += total_loss.sum().item() 269 | if cls_loss is not None: 270 | epoch_cls_loss += cls_loss.sum().item() 271 | if exp_loss is not None: 272 | epoch_exp_loss += exp_loss.sum().item() 273 | epoch_hard_pred.extend(hard_preds) 274 | epoch_soft_pred.extend(soft_preds) 275 | epoch_token_targets.extend(token_targets) 276 | epoch_pred_labels.extend(pred_labels) 277 | epoch_labels.extend(labels) 278 | epoch_total_loss /= len(data) 279 | epoch_cls_loss /= len(data) 280 | epoch_exp_loss /= len(data) 281 | return epoch_total_loss, epoch_cls_loss, epoch_exp_loss, \ 282 | epoch_soft_pred, epoch_hard_pred, epoch_token_targets, \ 283 | epoch_pred_labels, epoch_labels 284 | 285 | 286 | def make_mtl_classification_preds_batch(classifier: nn.Module, 287 | batch_elements: List[SentenceEvidence], 288 | class_interner: dict, 289 | device=None, 290 | criterion: nn.Module = None, 291 | tensorize_model_inputs: bool = True) -> Tuple[float, List[float], List[int], List[int]]: 292 | batch_elements = filter(lambda x: x is not None, batch_elements) 293 | targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements]) 294 | ids = [(s.ann_id, s.docid, s.index) for s in batch_elements] 295 | targets = [[i == target for i in range(len(class_interner))] for target in targets] 296 | targets = torch.tensor(targets, dtype=torch.float, device=device) 297 | if tensorize_model_inputs: 298 | queries = [torch.tensor(q, dtype=torch.long) for q in queries] 299 | sentences = [torch.tensor(s, dtype=torch.long) for s in sentences] 300 | preds = classifier(queries, ids, sentences) 
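# Unlike the MTL token identifier, whose forward pass yields a
# (cls_preds, exp_preds, attention_masks) triple, the evidence classifier is
# expected to return a single [batch, num_classes] score tensor, so preds is
# used directly for the loss and for the argmax below.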
301 | targets = targets.to(device=preds.device) 302 | if criterion: 303 | loss = criterion(preds, targets) 304 | else: 305 | loss = None 306 | # .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16 307 | hard_preds = torch.argmax(preds.float(), dim=-1) 308 | return loss, preds, hard_preds, targets 309 | 310 | 311 | def make_mtl_classification_preds_epoch(classifier: nn.Module, 312 | data: List[SentenceEvidence], 313 | class_interner: dict, 314 | batch_size: int, 315 | device=None, 316 | criterion: nn.Module = None, 317 | tensorize_model_inputs: bool = True): 318 | epoch_loss = 0 319 | epoch_soft_pred = [] 320 | epoch_hard_pred = [] 321 | epoch_truth = [] 322 | batches = _grouper(data, batch_size) 323 | classifier.eval() 324 | for batch in batches: 325 | loss, soft_preds, hard_preds, targets = make_mtl_classification_preds_batch(classifier=classifier, 326 | batch_elements=batch, 327 | class_interner=class_interner, 328 | device=device, 329 | criterion=criterion, 330 | tensorize_model_inputs=tensorize_model_inputs) 331 | if loss is not None: 332 | epoch_loss += loss.sum().item() 333 | epoch_hard_pred.extend(hard_preds) 334 | epoch_soft_pred.extend(soft_preds.cpu()) 335 | epoch_truth.extend(targets) 336 | epoch_loss /= len(data) 337 | epoch_hard_pred = [x.item() for x in epoch_hard_pred] 338 | epoch_truth = [x.argmax().item() for x in epoch_truth] 339 | return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth 340 | 341 | 342 | def convert_to_global_token_mapping(token_mapping): 343 | ret = [] 344 | sent_offset = 0 345 | for sent in token_mapping: 346 | for mapping in sent: 347 | ret.append((mapping[0]+sent_offset, mapping[1]+sent_offset)) 348 | sent_offset = ret[-1][-1] 349 | return ret 350 | 351 | 352 | def decode(evidence_identifier: nn.Module, 353 | evidence_classifier: nn.Module, 354 | train: List[Annotation], 355 | val: List[Annotation], 356 | test: List[Annotation], 357 | source_documents, 358 | token_mapping, 359 | mrs_train, 360 | mrs_eval, 361 | mrs_test, 362 | class_interner: Dict[str, int], 363 | batch_size: int, 364 | tensorize_modelinputs: bool, 365 | interned_documents: bool=None, 366 | tokenizer=None) -> dict: 367 | device = None 368 | class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])] 369 | if interned_documents is None: 370 | interned_documents = source_documents 371 | 372 | def prep(data: List[Annotation]) -> List[Tuple[SentenceEvidence, SentenceEvidence]]: 373 | identification_data = annotations_to_mtl_token_identification(annotations=data, 374 | source_documents=source_documents, 375 | interned_documents=interned_documents, 376 | token_mapping=token_mapping) 377 | 378 | identification_data = {ann_id: [v[0], {docid: chain_sentence_evidences(sentences) 379 | for docid, sentences in v[1].items()}] 380 | for ann_id, v in identification_data.items()} 381 | classification_data = mask_annotations_to_evidence_classification(mrs=mrs_test, 382 | class_interner=class_interner) 383 | #ann_doc_sents = defaultdict(lambda: defaultdict(dict)) # ann id -> docid -> sent idx -> sent data 384 | ret = [] 385 | for sent_ev in classification_data: 386 | id_data = identification_data[sent_ev.ann_id][1][sent_ev.docid] 387 | ret.append((id_data, sent_ev)) 388 | assert id_data.ann_id == sent_ev.ann_id 389 | assert id_data.docid == sent_ev.docid 390 | #assert id_data.index == sent_ev.index 391 | assert len(ret) == len(classification_data) 392 | return ret 393 | 394 | def decode_batch(data: List[Tuple[SentenceEvidence, 
SentenceEvidence]], 395 | mrs, 396 | name: str, 397 | score: bool = False, 398 | annotations: List[Annotation] = None, 399 | tokenizer=None) -> dict: 400 | """Identifies evidence statements and then makes classifications based on it. 401 | 402 | Args: 403 | data: a paired list of SentenceEvidences, differing only in the kls field. 404 | The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class 405 | name: a name for a results dict 406 | """ 407 | 408 | num_uniques = len(set((x.ann_id, x.docid) for x, _ in data)) 409 | logging.info(f'Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations') 410 | identifier_data, classifier_data = zip(*data) 411 | results = dict() 412 | IdentificationClassificationResult = namedtuple('IdentificationClassificationResult', 413 | 'identification_data classification_data soft_identification hard_identification soft_classification hard_classification') 414 | with torch.no_grad(): 415 | # make predictions for the evidence_identifier 416 | evidence_identifier.eval() 417 | evidence_classifier.eval() 418 | _, soft_identification_preds, hard_identification_preds = zip(*mrs) 419 | assert len(soft_identification_preds) == len(data) 420 | identification_results = defaultdict(list) 421 | for id_data, cls_data, soft_id_pred, hard_id_pred in zip(identifier_data, classifier_data, 422 | soft_identification_preds, 423 | hard_identification_preds): 424 | res = IdentificationClassificationResult(identification_data=id_data, 425 | classification_data=cls_data, 426 | # 1 is p(evidence|sent,query) 427 | soft_identification=soft_id_pred, 428 | hard_identification=hard_id_pred, 429 | soft_classification=None, 430 | hard_classification=False) 431 | identification_results[(id_data.ann_id, id_data.docid)].append(res) # in original eraser, each sentence 432 | # is stored separately, thence for each ann_idxdocid key there is a list of identification results, each 433 | # corresponds to a sentence. 
While in our approach a document is chained together from the begining and 434 | # rationalities are predicted in token-level granularity 435 | 436 | best_identification_results = {key: max(value, key=lambda x: x.soft_identification) for key, value in 437 | identification_results.items()} 438 | logging.info( 439 | f'Selected the best sentence for {len(identification_results)} examples from a total of {len(soft_identification_preds)} sentences') 440 | ids, classification_data = zip( 441 | *[(k, v.classification_data) for k, v in best_identification_results.items()]) 442 | _, soft_classification_preds, hard_classification_preds, classification_truth = \ 443 | make_mtl_classification_preds_epoch(classifier=evidence_classifier, 444 | data=classification_data, 445 | class_interner=class_interner, 446 | batch_size=batch_size, 447 | device=device, 448 | tensorize_model_inputs=tensorize_modelinputs) 449 | classification_results = dict() 450 | for eyeD, soft_class, hard_class in zip(ids, soft_classification_preds, hard_classification_preds): 451 | input_id_result = best_identification_results[eyeD] 452 | res = IdentificationClassificationResult(identification_data=input_id_result.identification_data, 453 | classification_data=input_id_result.classification_data, 454 | soft_identification=input_id_result.soft_identification, 455 | hard_identification=input_id_result.hard_identification, 456 | soft_classification=soft_class, 457 | hard_classification=hard_class) 458 | classification_results[eyeD] = res 459 | 460 | if score: 461 | truth = [] 462 | pred = [] 463 | for res in classification_results.values(): 464 | truth.append(res.classification_data.kls) 465 | pred.append(res.hard_classification) 466 | # results[f'{name}_f1'] = classification_report(classification_truth, pred, target_names=class_labels, output_dict=True) 467 | results[f'{name}_f1'] = classification_report(classification_truth, hard_classification_preds, 468 | target_names=class_labels, 469 | labels=list(range(len(class_labels))), output_dict=True) 470 | results[f'{name}_acc'] = accuracy_score(classification_truth, hard_classification_preds) 471 | results[f'{name}_rationale'] = score_rationales(annotations, interned_documents, identifier_data, 472 | soft_identification_preds) 473 | 474 | # turn the above results into a format suitable for scoring via the rationale scorer 475 | # n.b. the sentence-level evidence predictions (hard and soft) are 476 | # broadcast to the token level for scoring. The comprehensiveness class 477 | # score is also a lie since the pipeline model above is faithful by 478 | # design. 
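# Each decoded entry assembled below follows the ERASER results format, roughly
# (the values shown are placeholders; the span format inside
# hard_rationale_predictions is whatever rational_bits_to_ev_generator yields):
#
#   {"annotation_id": "...",
#    "classification": "<predicted label>",
#    "classification_scores": {"<label>": 0.97, ...},
#    "truth": <gold class id>,
#    "rationales": [{"docid": "...",
#                    "hard_rationale_predictions": [...],
#                    "soft_rationale_predictions": [0.01, 0.85, ...]}]}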
479 | decoded = dict() 480 | decoded_scores = defaultdict(list) 481 | for (ann_id, docid), pred in classification_results.items(): 482 | #sentence_prediction_scores = [x.soft_identification for x in identification_results[(ann_id, docid)]] 483 | hard_rationale_predictions = list(chain.from_iterable(x.hard_identification for x in identification_results[(ann_id, docid)])) 484 | soft_rationale_predictions = list(chain.from_iterable(x.soft_identification for x in identification_results[(ann_id, docid)])) 485 | subtoken_ids = list(chain.from_iterable(interned_documents[docid])) 486 | raw_document = [] 487 | for word in chain.from_iterable(source_documents[docid]): 488 | token_ids_origin = tokenizer.encode(word, add_special_tokens=False) 489 | if token_ids_origin[0] == tokenizer.unk_token_id: 490 | raw_document.append('[UNK]') 491 | else: 492 | tokenized = ''.join(tokenizer.basic_tokenizer.tokenize(word)) # dumm ass ˈlʊdvɪɡ_væn_ˈbeɪˌtoʊvən 493 | raw_document.append(tokenized) 494 | global_token_mapping = convert_to_global_token_mapping(token_mapping[docid]) 495 | tokens, exp_outputs = convert_subtoken_ids_to_tokens(subtoken_ids, 496 | tokenizer=tokenizer, 497 | token_mapping=global_token_mapping, 498 | exps=(hard_rationale_predictions, 499 | soft_rationale_predictions), 500 | raw_sentence=raw_document) 501 | hard_rationale_predictions, soft_rationale_predictions = list(zip(*exp_outputs)) 502 | # if docid == 'Ludwig_van_Beethoven': 503 | # #print(len(hard_rationale_predictions)) 504 | # print(len(soft_rationale_predictions)) 505 | ev_generator = rational_bits_to_ev_generator(tokens, 506 | docid, 507 | hard_rationale_predictions) 508 | hard_rationale_predictions = [ev for ev in ev_generator] 509 | 510 | if ann_id not in decoded: 511 | decoded[ann_id] = { 512 | "annotation_id": ann_id, 513 | "rationales": [], 514 | "classification": class_labels[pred.hard_classification], 515 | "classification_scores": {class_labels[i]: s.item() for i, s in 516 | enumerate(pred.soft_classification)}, 517 | # TODO this should turn into the data distribution for the predicted class 518 | # "comprehensiveness_classification_scores": 0.0, 519 | "truth": pred.classification_data.kls, 520 | } 521 | decoded[ann_id]['rationales'].append({ 522 | "docid": docid, 523 | "hard_rationale_predictions": hard_rationale_predictions, 524 | "soft_rationale_predictions": soft_rationale_predictions, 525 | }) 526 | decoded_scores[ann_id].append(pred.soft_classification) 527 | 528 | return results, list(decoded.values()) 529 | 530 | test_results, test_decoded = decode_batch(prep(test), mrs_test, 'test', score=False, tokenizer=tokenizer) 531 | val_results, val_decoded = dict(), [] 532 | train_results, train_decoded = dict(), [] 533 | # val_results, val_decoded = decode_batch(prep(val), 'val', score=True, annotations=val) 534 | # train_results, train_decoded = decode_batch(prep(train), 'train', score=True, annotations=train) 535 | return dict(**train_results, **val_results, **test_results), train_decoded, val_decoded, test_decoded 536 | --------------------------------------------------------------------------------
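To make the sub-token bookkeeping in mtl_pipeline_utils.py concrete, here is a small worked example of convert_to_global_token_mapping; the input values are invented, and the function simply shifts each sentence-relative wordpiece span by the running document offset:

    from expred.models.pipeline.mtl_pipeline_utils import convert_to_global_token_mapping

    # two sentences; each tuple is a (start, end) wordpiece span relative to its own sentence
    per_sentence_mapping = [[(0, 1), (1, 3)],   # sentence 0: the second word splits into two wordpieces
                            [(0, 2), (2, 3)]]   # sentence 1
    print(convert_to_global_token_mapping(per_sentence_mapping))
    # -> [(0, 1), (1, 3), (3, 5), (5, 6)]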