├── expred
│ ├── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── pipeline
│ │ │ ├── __init__.py
│ │ │ ├── mtl_evidence_classifier.py
│ │ │ ├── mtl_token_identifier.py
│ │ │ └── mtl_pipeline_utils.py
│ │ ├── losses.py
│ │ ├── model_utils.py
│ │ └── mlp_mtl.py
│ ├── params.py
│ ├── eraser_utils.py
│ ├── rationale_tokenization.py
│ ├── bert_rational_feature.py
│ ├── preprocessing.py
│ ├── tokenizer.py
│ ├── eraser_benchmark.py
│ ├── utils.py
│ └── train.py
├── test_bogus.py
├── requirements.txt
├── .idea
│ ├── vcs.xml
│ ├── other.xml
│ ├── misc.xml
│ ├── interpretation_by_design.iml
│ ├── modules.xml
│ ├── deployment.xml
│ ├── runConfigurations
│ │ └── mtl_pipeline_movies_regression_test.xml
│ └── inspectionProfiles
│   └── Project_Default.xml
├── scripts
│ ├── run_sweep_wrapper.sh
│ ├── run_fever_final.sh
│ ├── run_fever_gru_final.sh
│ ├── run_fever_rnr_final.sh
│ ├── run_movies_rnr_final.sh
│ ├── run_multirc_gru_final.sh
│ ├── run_fever_b.sh
│ ├── run_movies_gru_final.sh
│ ├── run_multirc_rnr_final_a.sh
│ ├── run_multirc_rnr_final_b.sh
│ ├── run_fever_a.sh
│ ├── run_multirc_rnr_sweep_b.sh
│ ├── run_movies_rnr_sweep.sh
│ ├── run_multirc_rnr_sweep.sh
│ ├── run_multirc_rnr_sweep_a.sh
│ ├── run_probe.sh
│ ├── run_fever_firsts.sh
│ ├── run_multirc_firsts.sh
│ ├── run_gru_firsts.sh
│ ├── run_fever_gru_sweep.sh
│ ├── run_multirc_gru_sweep.sh
│ ├── run_fever_rnr_sweep.sh
│ ├── run_rnr_firsts.sh
│ ├── run_sweep.sh
│ ├── run_b.sh
│ ├── run_a.sh
│ ├── run_multirc_rnr_sweep_a_and_freeze.sh
│ ├── run_multirc_rnr_sweep_c.sh
│ └── run_multirc_rnr_sweep_node05.sh
├── test
│ ├── test_train.py
│ └── test_params
│   └── movies_expred.json
├── params
│ ├── movies_expred.json
│ ├── multirc_expred.json
│ └── fever_expred.json
├── README.md
├── .github
│ └── workflows
│   └── ci.yml
├── .gitignore
└── sample_rationales_dataset.py
/expred/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/expred/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/expred/models/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/test_bogus.py:
--------------------------------------------------------------------------------
1 | def test_test():
2 |     assert True
3 |
--------------------------------------------------------------------------------
/expred/params.py:
--------------------------------------------------------------------------------
1 | class MTLParams:
2 |     dim_cls_linear: int
3 |     num_labels: int
4 |     dim_exp_gru: int
--------------------------------------------------------------------------------
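The three fields of `MTLParams` correspond directly to keys in the JSON configs under `params/`. A minimal sketch (an assumption for illustration, not code from the repository) of how they might be populated; the actual wiring lives in `expred/train.py`:

```python
import json

from expred.params import MTLParams

# Load one of the shipped configs (dim_cls_linear=256, dim_exp_gru=128 in all three).
with open('params/movies_expred.json') as f:
    conf = json.load(f)

mtl_params = MTLParams()
mtl_params.dim_cls_linear = conf['dim_cls_linear']  # hidden size of the classification head
mtl_params.dim_exp_gru = conf['dim_exp_gru']        # hidden size of the explanation GRU
mtl_params.num_labels = len(conf['classes'])        # e.g. 2 for movies (NEG/POS)
```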
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==2.5.1
2 | torch==1.3.0
3 | torchvision==0.4.2
4 | gensim==3.7.1
5 | scikit-learn==0.20.3
6 | wandb==0.10.24
--------------------------------------------------------------------------------
/scripts/run_sweep_wrapper.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | dataset=$1
4 | structure=$2
5 | portion=$3
6 | . .venv/bin/activate
7 | for i in 0 1 2 3; do
8 | ./run_sweep.sh $i ${dataset} ${structure} ${portion}&
9 | done
10 | deactivate
11 |
--------------------------------------------------------------------------------
/test/test_train.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from expred.train import main
3 |
4 | class TestTrain(unittest.TestCase):
5 |     def test_runs_for_one_epoch(self):
6 |         # todo fix path / data problem
7 |         args = [
8 |             "--data_dir", "/home/mreimer/datasets/eraser/movies_debug",
9 |             "--output_dir", "output/",
10 |             "--conf", "test_params/movies_expred.json",
11 |             "--batch_size", "4"]
12 |         main(args)
13 |
14 |
15 | if __name__ == '__main__':
16 |     unittest.main()
17 |
--------------------------------------------------------------------------------
/scripts/run_fever_final.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 2 )
4 | gpu_id=0,1
5 | batch_size=16
6 | num_epochs=10
7 | dataset='fever'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_fever_gru_final.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 5. )
4 | gpu_id=2
5 | batch_size=4
6 | num_epochs=10
7 | dataset='fever'
8 | exp_structure='gru'
9 | benchmark_split='test'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_fever_rnr_final.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.2 )
4 | gpu_id=1
5 | batch_size=16
6 | num_epochs=10
7 | dataset='fever'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_movies_rnr_final.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.1 )
4 | gpu_id=0
5 | batch_size=4
6 | num_epochs=10
7 | dataset='movies'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_multirc_gru_final.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 2. )
4 | gpu_id=3
5 | batch_size=4
6 | num_epochs=10
7 | dataset='multirc'
8 | exp_structure='gru'
9 | benchmark_split='test'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_fever_b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.1 0.2 0.5 1 2 5 )
4 | gpu_id=1
5 | batch_size=16
6 | num_epochs=10
7 | dataset='fever'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 | train_on_portion='0.1'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_movies_gru_final.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 100. 5. )
4 | gpu_id=1
5 | batch_size=4
6 | num_epochs=10
7 | dataset='movies'
8 | exp_structure='gru'
9 | benchmark_split='test'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_multirc_rnr_final_a.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.05 )
4 | gpu_id=0
5 | batch_size=4
6 | num_epochs=10
7 | dataset='multirc'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_multirc_rnr_final_b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.1 )
4 | gpu_id=0
5 | batch_size=4
6 | num_epochs=10
7 | dataset='multirc'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_fever_a.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.001 0.002 0.005 0.02 0.03 0.05 )
4 | gpu_id=0
5 | batch_size=16
6 | num_epochs=10
7 | dataset='fever'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 | train_on_portion='0.1'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_multirc_rnr_sweep_b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.1 0.2 0.5 1. 2. )
4 | gpu_id=1
5 | batch_size=16
6 | num_epochs=10
7 | dataset='multirc'
8 | exp_structure='rnr'
9 | benchmark_split='val'
10 | train_on_portion='0.4'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_movies_rnr_sweep.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 0.1 0.2 0.5 1. 2. 5. )
4 | gpu_id=0
5 | batch_size=16
6 | num_epochs=10
7 | dataset='movies'
8 | exp_structure='rnr'
9 | benchmark_split='val'
10 | train_on_portion='0'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_multirc_rnr_sweep.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 0.1 0.2 0.5 1. 2. 5. )
4 | gpu_id=0
5 | batch_size=16
6 | num_epochs=10
7 | dataset='multirc'
8 | exp_structure='rnr'
9 | benchmark_split='val'
10 | train_on_portion='0.4'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
--------------------------------------------------------------------------------
/scripts/run_multirc_rnr_sweep_a.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 )
4 | #0.1, 0.2, 0.5, 1., 2., 5. )
5 | gpu_id=0
6 | batch_size=16
7 | num_epochs=10
8 | dataset='multirc'
9 | exp_structure='rnr'
10 | benchmark_split='val'
11 | train_on_portion='0.4'
12 |
13 | for par_lambda in ${lambdas[@]}; do
14 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
15 | done
16 |
--------------------------------------------------------------------------------
/scripts/run_probe.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 1. )
4 | gpu_id=0
5 | train_first=( exp cls )
6 | batch_size=16
7 | num_epochs=2
8 | dataset='movies'
9 | exp_structure='rnr'
10 | benchmark_split='test'
11 | train_on_portion='0.1'
12 |
13 | for phase in ${train_first[@]}; do
14 | for par_lambda in ${lambdas[@]}; do
15 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion} --train_${phase}_first;
16 | done
17 | done
18 |
--------------------------------------------------------------------------------
/scripts/run_fever_firsts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 1. )
4 | gpu_id=0
5 | train_first=( exp cls )
6 | batch_size=16
7 | num_epochs=2
8 | dataset='fever'
9 | exp_structure='rnr'
10 | benchmark_split='test'
11 | train_on_portion='0.1'
12 |
13 | for phase in ${train_first[@]}; do
14 | for par_lambda in ${lambdas[@]}; do
15 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion} --train_${phase}_first;
16 | done
17 | done
18 |
--------------------------------------------------------------------------------
/scripts/run_multirc_firsts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 1. )
4 | gpu_id=0
5 | train_first=( exp cls )
6 | batch_size=16
7 | num_epochs=2
8 | dataset='multirc'
9 | exp_structure='rnr'
10 | benchmark_split='test'
11 | train_on_portion='0.1'
12 |
13 | for phase in ${train_first[@]}; do
14 | for par_lambda in ${lambdas[@]}; do
15 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion} --train_${phase}_first;
16 | done
17 | done
18 |
--------------------------------------------------------------------------------
/scripts/run_gru_firsts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | par_lambda=1.
4 | #gpu_id=0
5 | #train_first=( exp cls )
6 | gpu_id=$1
7 | phase=$2
8 | #batch_size=16
9 | batch_size=$3
10 | num_epochs=10
11 | datasets=( fever multirc )
12 | exp_structure='gru'
13 | benchmark_split='test'
14 | train_on_portion='0'
15 |
16 | for dataset in ${datasets[@]}; do
17 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --load_phase1 --merge_evidences --benchmark_split ${benchmark_split} --do_train --start_from_phase1 --train_on_portion ${train_on_portion} --train_${phase}_first;
18 | done
19 |
--------------------------------------------------------------------------------
/scripts/run_fever_gru_sweep.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=$1
4 | l0='0.1,0.2,0.5'
5 | l1='1,2,5'
6 | l2='10,20,50'
7 | l3='100,200,500'
8 | lambdas=( $l0 $l1 $l2 $l3 )
9 | IFS=',' read -r -a lambdas<<<${lambdas[$gpu_id]}
10 | batch_size=5
11 | num_epochs=10
12 | dataset='fever'
13 | exp_structure='gru'
14 | benchmark_split='val'
15 | train_on_portion='0'
16 |
17 | for par_lambda in ${lambdas[@]}; do
18 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
19 | done
20 |
--------------------------------------------------------------------------------
/scripts/run_multirc_gru_sweep.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=$1
4 | l0='0.1,0.2,0.5'
5 | l1='1,2,5'
6 | l2='10,20,50'
7 | l3='100,200,500'
8 | lambdas=( $l0 $l1 $l2 $l3 )
9 | IFS=',' read -r -a lambdas<<<${lambdas[$gpu_id]}
10 | batch_size=5
11 | num_epochs=10
12 | dataset='multirc'
13 | exp_structure='gru'
14 | benchmark_split='val'
15 | train_on_portion='0'
16 |
17 | for par_lambda in ${lambdas[@]}; do
18 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
19 | done
20 |
--------------------------------------------------------------------------------
/scripts/run_fever_rnr_sweep.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=$1
4 | l0='0.001,0.002,0.005'
5 | l1='0.01,0.02,0.05'
6 | l2='0.1,0.2,0.5'
7 | l3='1.,2.,5.'
8 | lambdas=( $l0 $l1 $l2 $l3 )
9 | IFS=',' read -r -a lambdas<<<${lambdas[$gpu_id]}
10 | batch_size=5
11 | num_epochs=10
12 | dataset='fever'
13 | exp_structure='rnr'
14 | benchmark_split='val'
15 | train_on_portion='0.1'
16 |
17 | for par_lambda in ${lambdas[@]}; do
18 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
19 | done
20 |
--------------------------------------------------------------------------------
/scripts/run_rnr_firsts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | par_lambda=1.
4 | gpu_id=1
5 | #train_first=( exp cls )
6 | train_first=( cls )
7 | batch_size=100
8 | num_epochs=10
9 | datasets=( movies fever multirc )
10 | #datasets=( movies )
11 | exp_structure='rnr'
12 | benchmark_split='test'
13 | train_on_portion='0'
14 |
15 | for phase in ${train_first[@]}; do
16 | for dataset in ${datasets[@]}; do
17 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --start_from_phase1 --train_on_portion ${train_on_portion} --train_${phase}_first;
18 | done
19 | done
20 |
--------------------------------------------------------------------------------
/scripts/run_sweep.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_id=$1
4 | dataset=$2
5 | exp_structure=$3
6 | train_on_portion=$4
7 | if [[ $exp_structure == 'gru' ]]; then
8 | l0='0.1,0.2,0.5'
9 | l1='1,2,5'
10 | l2='10,20,50'
11 | l3='100,200,500'
12 | else
13 | l0='0.001,0.002,0.005'
14 | l1='0.01,0.02,0.05'
15 | l2='0.1,0.2,0.5'
16 | l3='1.,2.,5'
17 | fi
18 | lambdas=( $l0 $l1 $l2 $l3 )
19 | IFS=',' read -r -a lambdas<<<${lambdas[$gpu_id]}
20 | batch_size=4
21 | num_epochs=10
22 | benchmark_split='val'
23 |
24 | for par_lambda in ${lambdas[@]}; do
25 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
26 | done
27 |
--------------------------------------------------------------------------------
/params/movies_expred.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "dim_cls_linear": 256,
4 | "dim_exp_gru": 128,
5 | "bert_vocab": "bert-base-uncased",
6 | "bert_dir": "bert-base-uncased",
7 | "exp_structure": "gru",
8 | "rebalance_approach": "resampling",
9 | "merge_evidences": 1,
10 | "classes": [
11 | "NEG",
12 | "POS"
13 | ],
14 | "mtl_token_identifier": {
15 | "par_lambda": 5.0,
16 | "batch_size": 16,
17 | "epochs": 1,
18 | "patience": 3,
19 | "warmup_steps": 50,
20 | "lr": 1e-5,
21 | "use_half_precision": 0,
22 | "sampling_method": "whole_document"
23 | },
24 | "evidence_classifier": {
25 | "batch_size": 16,
26 | "warmup_steps": 50,
27 | "epochs": 1,
28 | "patience": 3,
29 | "lr": 1e-5,
30 | "use_half_precision": 0
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/params/multirc_expred.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "dim_cls_linear": 256,
4 | "dim_exp_gru": 128,
5 | "bert_vocab": "bert-base-uncased",
6 | "bert_dir": "bert-base-uncased",
7 | "exp_structure": "gru",
8 | "rebalance_approach": "resampling",
9 | "merge_evidences": 1,
10 | "classes": [
11 | "True",
12 | "False"
13 | ],
14 | "mtl_token_identifier": {
15 | "par_lambda": 20.0,
16 | "batch_size": 16,
17 | "epochs": 10,
18 | "patience": 3,
19 | "warmup_steps": 50,
20 | "lr": 1e-5,
21 | "use_half_precision": 0,
22 | "sampling_method": "whole_document"
23 | },
24 | "evidence_classifier": {
25 | "batch_size": 16,
26 | "warmup_steps": 50,
27 | "epochs": 10,
28 | "patience": 3,
29 | "lr": 1e-5,
30 | "use_half_precision": 0
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/params/fever_expred.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "dim_cls_linear": 256,
4 | "dim_exp_gru": 128,
5 | "bert_vocab": "bert-base-uncased",
6 | "bert_dir": "bert-base-uncased",
7 | "exp_structure": "gru",
8 | "rebalance_approach": "resampling",
9 | "merge_evidences": 1,
10 | "classes": [
11 | "SUPPORTS",
12 | "REFUTES"
13 | ],
14 | "mtl_token_identifier": {
15 | "par_lambda": 2.0,
16 | "batch_size": 16,
17 | "epochs": 10,
18 | "patience": 3,
19 | "warmup_steps": 50,
20 | "lr": 1e-5,
21 | "use_half_precision": 0,
22 | "sampling_method": "whole_document"
23 | },
24 | "evidence_classifier": {
25 | "batch_size": 16,
26 | "warmup_steps": 50,
27 | "epochs": 10,
28 | "patience": 3,
29 | "lr": 1e-5,
30 | "use_half_precision": 0
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/test/test_params/movies_expred.json:
--------------------------------------------------------------------------------
1 | {
2 | "max_length": 512,
3 | "dim_cls_linear": 256,
4 | "dim_exp_gru": 128,
5 | "bert_vocab": "bert-base-uncased",
6 | "bert_dir": "bert-base-uncased",
7 | "exp_structure": "gru",
8 | "rebalance_approach": "resampling",
9 | "merge_evidences": 1,
10 | "classes": [
11 | "NEG",
12 | "POS"
13 | ],
14 | "mtl_token_identifier": {
15 | "par_lambda": 5.0,
16 | "batch_size": 16,
17 | "epochs": 1,
18 | "patience": 3,
19 | "warmup_steps": 50,
20 | "lr": 1e-5,
21 | "use_half_precision": 0,
22 | "sampling_method": "whole_document"
23 | },
24 | "evidence_classifier": {
25 | "batch_size": 16,
26 | "warmup_steps": 50,
27 | "epochs": 1,
28 | "patience": 3,
29 | "lr": 1e-5,
30 | "use_half_precision": 0
31 | }
32 | }
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ExPred
2 |
3 | This is the implementation of the paper [Explain and Predict, and then Predict Again](https://dl.acm.org/doi/abs/10.1145/3437963.3441758) (accepted at WSDM 2021). The code builds on the pipeline model of the [ERASER benchmark](http://www.eraserbenchmark.com/); all data used by the model can also be obtained from the ERASER benchmark.
4 |
5 | ## Usage:
6 | 1. Install the required packages listed in ```requirements.txt``` with ```pip install -r requirements.txt```.
7 | 2. The training entry point is ```expred/train.py```. To run the training, use the following command:
8 | ``` export PYTHONPATH=$PYTHONPATH:./ && python expred/train.py --data_dir /dir/to/your/datasets/{movies,fever,multirc} --output_dir /dir/to/your/trained_data --conf ./params/{movies,fever,multirc}_expred.json```
9 |
10 | Note that, depending on your hardware, you may have to reduce the `batch_size` values in the config file.
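11 | 
12 | For example, for the movies dataset the call becomes ```export PYTHONPATH=$PYTHONPATH:./ && python expred/train.py --data_dir /path/to/eraser/movies --output_dir ./output --conf ./params/movies_expred.json```, where ```/path/to/eraser/movies``` is a placeholder for wherever you unpacked the ERASER movies data.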
--------------------------------------------------------------------------------
/scripts/run_b.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 10. 20. 50. )
4 | gpu_id=1
5 | batch_size=16
6 | num_epochs=10
7 | dataset='movies'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 |
11 | lambdas=( 2 )
12 | gpu_id=1
13 | dataset='fever'
14 |
15 | for par_lambda in ${lambdas[@]}; do
16 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints;
17 | done
18 |
19 | dataset='multirc'
20 |
21 | for par_lambda in ${lambdas[@]}; do
22 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints;
23 | done
24 |
--------------------------------------------------------------------------------
/expred/eraser_utils.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 |
3 | from expred.models.pipeline.pipeline_utils import SentenceEvidence
4 |
5 | def get_docids(ann):
6 |     ret = []
7 |     for ev_group in ann.evidences:
8 |         for ev in ev_group:
9 |             ret.append(ev.docid)
10 |     return ret
11 | 
12 | def extract_doc_ids_from_annotations(anns):
13 |     ret = set()
14 |     for ann in anns:
15 |         ret |= set(get_docids(ann))
16 |     return ret
17 | 
18 | def chain_sentence_evidences(sentences):
19 |     kls = list(chain.from_iterable(s.kls for s in sentences))
20 |     document = list(chain.from_iterable(s.sentence for s in sentences))
21 |     assert len(kls) == len(document)
22 |     return SentenceEvidence(kls=kls,
23 |                             ann_id=sentences[0].ann_id,
24 |                             sentence=document,
25 |                             docid=sentences[0].docid,
26 |                             index=sentences[0].index,
27 |                             query=sentences[0].query,
28 |                             has_evidence=any(map(lambda s: s.has_evidence, sentences)))
--------------------------------------------------------------------------------
/expred/models/losses.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 |
3 | import torch
4 |
5 |
6 | def resampling_rebalanced_crossentropy(seq_reduction: str) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
7 |     """
8 |     Returns the loss function with the given seq_reduction strategy.
9 | 
10 |     The individual loss of token $i$ in sequence $S$ is $\frac{\vert S \vert}{\vert S_{t^i} \vert} \cdot BCE(p^i, t^i)$, where
11 |     * $p^i$ and $t^i$ are the predicted and target labels of token $i$, respectively
12 |     * $\vert S_{t^i} \vert$ is the number of tokens in $S$ that have the same target label as token $i$
13 |     :param seq_reduction: either 'none' or 'mean'.
14 |     :return: the loss function
15 |     """
16 |     def loss(y_pred, y_true):
17 |         prior_pos = torch.mean(y_true, dim=-1, keepdim=True)  # fraction of positive (rationale) tokens
18 |         prior_neg = torch.mean(1 - y_true, dim=-1, keepdim=True)  # fraction of negative tokens
19 |         eps = 1e-10
20 |         weight = y_true / (prior_pos + eps) + (1 - y_true) / (prior_neg + eps)
21 |         ret = -weight * (y_true * (torch.log(y_pred + eps)) + (1 - y_true) * (torch.log(1 - y_pred + eps)))
22 |         if seq_reduction == 'mean':
23 |             return torch.mean(ret, dim=-1)
24 |         elif seq_reduction == 'none':
25 |             return ret
26 | 
27 |     return loss
28 |
--------------------------------------------------------------------------------
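A minimal usage sketch for `resampling_rebalanced_crossentropy` (not part of the repository; the toy tensors are made up): with 2 of 8 tokens marked as rationale, positive tokens are re-weighted by 8/2 and negative tokens by 8/6 before the binary cross-entropy is averaged over the sequence.

```python
import torch

from expred.models.losses import resampling_rebalanced_crossentropy

loss_fn = resampling_rebalanced_crossentropy(seq_reduction='mean')

y_true = torch.tensor([[0., 0., 1., 1., 0., 0., 0., 0.]])  # 2 rationale tokens out of 8
y_pred = torch.full((1, 8), 0.5)                           # uniform token predictions

print(loss_fn(y_pred, y_true))  # one re-weighted BCE value per sequence
```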
/scripts/run_a.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 10. 20. 50. )
4 | gpu_id=0
5 | batch_size=16
6 | num_epochs=10
7 | dataset='movies'
8 | exp_structure='rnr'
9 | benchmark_split='test'
10 |
11 | for par_lambda in ${lambdas[@]}; do
12 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints;
13 | done
14 |
15 | #lambdas=( 2 )
16 | #gpu_id=1
17 | #dataset='fever'
18 |
19 | #for par_lambda in ${lambdas[@]}; do
20 | # python learn_to_interpret.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split};
21 | #done
22 |
23 | #lambdas=( 2 )
24 | #gpu_id=1
25 | #dataset='multirc'
26 |
27 | #for par_lambda in ${lambdas[@]}; do
28 | # python learn_to_interpret.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split};
29 | #done
30 |
--------------------------------------------------------------------------------
/scripts/run_multirc_rnr_sweep_a_and_freeze.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 )
4 | gpu_id=1
5 | batch_size=16
6 | num_epochs=10
7 | dataset='multirc'
8 | exp_structure='rnr'
9 | benchmark_split='val'
10 | train_on_portion='0.4'
11 |
12 | for par_lambda in ${lambdas[@]}; do
13 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion};
14 | done
15 |
16 | par_lambda=1.
17 | datasets=( movies fever multirc )
18 | exp_structures=( gru rnr )
19 | freezes=( cls exp )
20 | train_on_portion='0'
21 |
22 | for dataset in ${datasets[@]}; do
23 | for exp_structure in ${exp_structures[@]}; do
24 | for freeze in ${freezes[@]}; do
25 | python bert_as_tfkeras_layer.py --par_lambda ${par_lambda} --gpu_id ${gpu_id} --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion} --freeze_${freeze};
26 | done
27 | done
28 | done
29 |
30 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python application
5 |
6 | on:
7 | push:
8 | branches: [ pytorch ]
9 | pull_request:
10 | branches: [ pytorch ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | env:
20 | GIT_TRACE: 1
21 | GIT_CURL_VERBOSE: 1
22 | - name: Set up Python 3.9
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: 3.9
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install flake8 pytest
30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
31 | - name: Lint with flake8
32 | run: |
33 | # stop the build if there are Python syntax errors or undefined names
34 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
35 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
36 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
37 | - name: Test with pytest
38 | run: |
39 | pytest
40 |
--------------------------------------------------------------------------------
/scripts/run_multirc_rnr_sweep_c.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 0.1 0.2 0.5 1. 2. 5. )
4 | batch_size=6
5 | num_epochs=10
6 | dataset='multirc'
7 | exp_structure='rnr'
8 | benchmark_split='test'
9 | train_on_portion='0.4'
10 |
11 | python bert_as_tfkeras_layer.py --par_lambda 0.5 --gpu_id 0 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}&
12 | python bert_as_tfkeras_layer.py --par_lambda 1. --gpu_id 1 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}&
13 | python bert_as_tfkeras_layer.py --par_lambda 2. --gpu_id 2 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}&
14 | python bert_as_tfkeras_layer.py --par_lambda 5. --gpu_id 3 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}&
15 |
--------------------------------------------------------------------------------
/scripts/run_multirc_rnr_sweep_node05.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | lambdas=( 0.001 0.002 0.005 0.01 0.02 0.05 0.1 0.2 0.5 1. 2. 5. )
4 | batch_size=6
5 | num_epochs=10
6 | dataset='multirc'
7 | exp_structure='rnr'
8 | benchmark_split='test'
9 | train_on_portion='0.4'
10 |
11 | python bert_as_tfkeras_layer.py --par_lambda 0.001 --gpu_id 0 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}&
12 | python bert_as_tfkeras_layer.py --par_lambda 0.002 --gpu_id 1 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}&
13 | python bert_as_tfkeras_layer.py --par_lambda 0.005 --gpu_id 2 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}&
14 | python bert_as_tfkeras_layer.py --par_lambda 0.01 --gpu_id 3 --batch_size ${batch_size} --num_epochs ${num_epochs} --dataset ${dataset} --evaluate --exp_benchmark --exp_structure ${exp_structure} --merge_evidences --benchmark_split ${benchmark_split} --do_train --delete_checkpoints --train_on_portion ${train_on_portion}&
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | model_components
2 | *deprecated*
3 | pipeline_outputs
4 | *.swp
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 | *$py.class
9 |
10 | # C extensions
11 | *.so
12 |
13 | # Distribution / packaging
14 | .Python
15 | build/
16 | develop-eggs/
17 | dist/
18 | downloads/
19 | eggs/
20 | .eggs/
21 | lib/
22 | lib64/
23 | parts/
24 | sdist/
25 | var/
26 | wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # pyenv
80 | .python-version
81 |
82 | # celery beat schedule file
83 | celerybeat-schedule
84 |
85 | # SageMath parsed files
86 | *.sage.py
87 |
88 | # Environments
89 | .env
90 | .venv
91 | env/
92 | venv/
93 | ENV/
94 | env.bak/
95 | venv.bak/
96 |
97 | # Spyder project settings
98 | .spyderproject
99 | .spyproject
100 |
101 | # Rope project settings
102 | .ropeproject
103 |
104 | # mkdocs documentation
105 | /site
106 |
107 | # mypy
108 | .mypy_cache/
109 |
110 | rationale_benchmark/data/esnli_previous
111 | data/esnli_previous
112 | esnli_union/
113 |
114 | misc/*
115 |
--------------------------------------------------------------------------------
/expred/rationale_tokenization.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers.tokenization_bert import WordpieceTokenizer, BertTokenizer, BasicTokenizer, whitespace_tokenize
3 | from transformers.file_utils import http_get
4 |
5 |
6 | def printable_text(text):
7 |     if isinstance(text, str):
8 |         return text
9 |     elif isinstance(text, bytes):
10 |         return text.decode("utf-8", "ignore")
11 |     else:
12 |         raise ValueError("Unsupported string type: %s" % (type(text)))
13 | 
14 | 
15 | def convert_to_unicode(text):
16 |     if isinstance(text, str):
17 |         return text
18 |     elif isinstance(text, bytes):
19 |         return text.decode("utf-8", "ignore")
20 |     else:
21 |         raise ValueError("Unsupported string type: %s" % (type(text)))
22 | 
23 | 
24 | class BasicRationalTokenizer(BasicTokenizer):
25 |     def __init__(self, do_lower_case=True):
26 |         self.do_lower_case = do_lower_case
27 | 
28 |     def _is_token_rational(self, token_idx, evidences):
29 |         if evidences is None:
30 |             return 0
31 |         for ev in evidences:
32 |             if token_idx >= ev.start_token and token_idx < ev.end_token:
33 |                 return 1
34 |         return 0
35 | 
36 |     def tokenize(self, text, evidences):
37 |         text = convert_to_unicode(text)
38 |         text = self._clean_text(text)
39 |         orig_tokens = whitespace_tokenize(text)
40 |         split_tokens = []
41 |         split_rations = []
42 |         for token_idx, token in enumerate(orig_tokens):
43 |             if self.do_lower_case:
44 |                 token = token.lower()
45 |                 token = self._run_strip_accents(token)
46 |             sub_tokens = self._run_split_on_punc(token)
47 |             sub_tokens = ' '.join(sub_tokens).strip().split()
48 |             if len(sub_tokens) > 0:
49 |                 split_tokens.extend(sub_tokens)
50 |                 ration = self._is_token_rational(token_idx, evidences)
51 |                 split_rations.extend([ration] * len(sub_tokens))
52 |         return zip(split_tokens, split_rations)
53 | 
54 | 
55 | class FullRationaleTokenizer(BertTokenizer):  # Test passed :)
56 |     def __init__(self, do_lower_case=True):
57 |         if not os.path.isfile('bert-base-uncased-vocab.txt'):
58 |             http_get("https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt",
59 |                      open("bert-base-uncased-vocab.txt", 'wb+'))
60 |         super(FullRationaleTokenizer, self).__init__('bert-base-uncased-vocab.txt')
61 |         self.basic_rational_tokenizer = BasicRationalTokenizer(do_lower_case=do_lower_case)
62 |         self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token="[UNK]")
63 | 
64 |     def tokenize(self, text, evidences=None):
65 |         split_tokens = []
66 |         split_rations = []
67 |         for token, ration in self.basic_rational_tokenizer.tokenize(text, evidences):
68 |             for sub_token in self.wordpiece_tokenizer.tokenize(token):
69 |                 split_tokens.append(sub_token)
70 |                 split_rations.append(ration)
71 |         return list(zip(split_tokens, split_rations))
72 |
73 |
74 |
--------------------------------------------------------------------------------
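A usage sketch for `FullRationaleTokenizer` (not from the repository): word pieces are paired with a 0/1 rationale flag, where the flag comes from token-level `Evidence` spans when they are supplied.

```python
from expred.rationale_tokenization import FullRationaleTokenizer

# Downloads bert-base-uncased-vocab.txt into the working directory on first use.
tokenizer = FullRationaleTokenizer()

# Without evidences every word piece is labelled 0 (irrational).
print(tokenizer.tokenize("a great movie", evidences=None))
# -> [('a', 0), ('great', 0), ('movie', 0)]
```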
/expred/bert_rational_feature.py:
--------------------------------------------------------------------------------
1 | # bert_rational_feature.py
2 | from tqdm import tqdm_notebook
3 | import expred.rationale_tokenization as tokenization
4 | import logging
5 |
6 | IRRATIONAL = 0
7 | RATIONAL = 1
8 |
9 | logger = logging.getLogger('feature converter')
10 | logger.setLevel(logging.INFO)
11 |
12 |
13 | class InputRationalExample(object):
14 |     def __init__(self, guid, text_a, text_b=None, label=None, evidences=None):
15 |         self.guid = guid
16 |         self.text_a = text_a
17 |         self.text_b = text_b
18 |         self.label = label
19 |         self.evidences = evidences
20 | 
21 | 
22 | class InputRationalFeatures(object):
23 |     def __init__(self,
24 |                  input_ids,
25 |                  input_mask,
26 |                  segment_ids,
27 |                  label_id,
28 |                  rations=None,
29 |                  is_real_example=True):
30 |         self.rations = rations
31 |         self.input_ids = input_ids
32 |         self.input_mask = input_mask
33 |         self.segment_ids = segment_ids
34 |         self.label_id = label_id
35 |         self.is_real_example = is_real_example
36 | 
37 | 
38 | def convert_single_rational_example(ex_index, example, label_list, max_seq_length, tokenizer):
39 | 
40 |     def _truncate_seq_pair(tokens_a, tokens_b, max_length):
41 |         while True:
42 |             total_length = len(tokens_a) + len(tokens_b)
43 |             if total_length <= max_length:
44 |                 break
45 |             if len(tokens_a) > len(tokens_b):
46 |                 tokens_a.pop()
47 |             else:
48 |                 tokens_b.pop()
49 | 
50 |     label_map = {}
51 |     for (i, label) in enumerate(label_list):
52 |         label_map[label] = i
53 | 
54 |     tokens_a = tokenizer.tokenize(example.text_a)
55 |     tokens_b = None  # no tokens_b in our tasks
56 |     if example.text_b:
57 |         tokens_b = tokenizer.tokenize(example.text_b, example.evidences)
58 | 
59 |     if tokens_b:
60 |         _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
61 |     if len(tokens_a) > max_seq_length - 2:
62 |         tokens_a = tokens_a[0:(max_seq_length - 2)]
63 | 
64 |     tokens = []
65 |     segment_ids = []
66 |     rations = []
67 |     tokens.append("[CLS]")
68 |     segment_ids.append(0)
69 |     rations.append(IRRATIONAL)
70 | 
71 |     for token, ration in tokens_a:
72 |         tokens.append(token)
73 |         segment_ids.append(0)
74 |         rations.append(ration)
75 |     tokens.append("[SEP]")
76 |     segment_ids.append(0)
77 |     rations.append(IRRATIONAL)
78 | 
79 |     if tokens_b:
80 |         for token, ration in tokens_b:
81 |             tokens.append(token)
82 |             segment_ids.append(1)
83 |             rations.append(ration)
84 |         tokens.append("[SEP]")
85 |         segment_ids.append(1)
86 |         rations.append(IRRATIONAL)
87 | 
88 |     input_ids = tokenizer.convert_tokens_to_ids(tokens)
89 |     input_mask = [1] * len(input_ids)
90 | 
91 |     while len(input_ids) < max_seq_length:
92 |         input_ids.append(0)
93 |         input_mask.append(0)
94 |         segment_ids.append(0)
95 |         rations.append(IRRATIONAL)
96 | 
97 |     assert len(input_ids) == max_seq_length
98 |     assert len(input_mask) == max_seq_length
99 |     assert len(segment_ids) == max_seq_length
100 |     assert len(rations) == max_seq_length
101 | 
102 |     label_id = label_map[example.label]
103 | 
104 |     if ex_index < 5:
105 |         logger.info("*** Example ***")
106 |         logger.info("guid: %s" % (example.guid))
107 |         logger.info("tokens: %s" % " ".join([tokenization.printable_text(x) for x in tokens]))
108 |         logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
109 |         logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
110 |         logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
111 |         logger.info("rations: %s" % " ".join([str(x) for x in rations]))
112 |         logger.info('')
113 |         logger.info("label: %s (id = %d)" % (example.label, label_id))
114 | 
115 |     feature = InputRationalFeatures(
116 |         input_ids=input_ids,
117 |         input_mask=input_mask,
118 |         segment_ids=segment_ids,
119 |         label_id=label_id,
120 |         rations=rations,
121 |         is_real_example=True)
122 |     return feature
123 | 
124 | 
125 | def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
126 |     features = []
127 |     for (ex_index, example) in enumerate(tqdm_notebook(examples, desc="Converting examples to features")):
128 |         if ex_index % 10000 == 0:
129 |             logger.info("Writing example %d of %d" % (ex_index, len(examples)))
130 |         feature = convert_single_rational_example(ex_index, example, label_list,
131 |                                                   max_seq_length, tokenizer)
132 |         features.append(feature)
133 |     return features
134 |
--------------------------------------------------------------------------------
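A sketch (not from the repository) of converting a single example into padded features with the tokenizer above; the text, label list, and `max_seq_length` are made-up values for illustration.

```python
from expred.bert_rational_feature import InputRationalExample, convert_examples_to_features
from expred.rationale_tokenization import FullRationaleTokenizer

example = InputRationalExample(guid=None,
                               text_a="a great movie",
                               text_b=None,      # evidences are only attached to text_b
                               label="POS",
                               evidences=None)

features = convert_examples_to_features([example],
                                        label_list=["NEG", "POS"],
                                        max_seq_length=16,
                                        tokenizer=FullRationaleTokenizer())

f = features[0]
print(f.label_id)        # 1 ("POS")
print(len(f.input_ids))  # 16, padded to max_seq_length
print(f.rations[:6])     # all IRRATIONAL here, since no evidences were given
```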
/sample_rationales_dataset.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | from functools import reduce
4 |
5 | import argparse
6 | import json
7 | import logging
8 | import random
9 | import shutil
10 | import sys
11 | from math import floor
12 | from pathlib import Path
13 |
14 | logger = logging.getLogger(__name__)
15 | logger.setLevel(logging.INFO)
16 |
17 |
18 | def is_valid_file(parser, arg):
19 |     path = Path(arg)
20 |     if not path.exists():
21 |         parser.error("The file %s does not exist!" % arg)
22 |     else:
23 |         return path
24 | 
25 | 
26 | class Range(object):
27 |     def __init__(self, start, end):
28 |         self.start = start
29 |         self.end = end
30 | 
31 |     def __eq__(self, other):
32 |         return self.start <= other <= self.end
33 | 
34 |     def __contains__(self, item):
35 |         return self.__eq__(item)
36 | 
37 |     def __iter__(self):
38 |         yield self
39 | 
40 |     def __str__(self):
41 |         return '[{0},{1}]'.format(self.start, self.end)
42 | 
43 | 
44 | copy_splits = ['val', 'test']
45 | reduce_splits = ['train']
46 | 
47 | def create_output_folder(dataset_dir: Path, keep_rationals_fraction: float, prefix: str):
48 |     dataset_name = dataset_dir.stem
49 |     datasets_root = dataset_dir.parent
50 |     parts = ([prefix] if prefix != '' else []) + [dataset_name, str(keep_rationals_fraction)]
51 |     output_path: Path = datasets_root / '_'.join(parts)
52 | 
53 |     if output_path.exists():
54 |         raise IOError(f'Output folder {output_path} already exists!')
55 | 
56 |     output_path.mkdir(parents=True, exist_ok=False)
57 | 
58 |     return output_path
59 | 
60 | def copy_split(split, input_dir: Path, output_path: Path):
61 |     file_name = f'{split}.jsonl'
62 |     shutil.copy(input_dir / file_name, output_path / file_name)
63 | 
64 | def sample_split(split, input_dir: Path, output_path: Path, keep_rationals_fraction: float):
65 |     file_name = f'{split}.jsonl'
66 |     with open(input_dir / file_name, 'r') as f:
67 |         annotations = [json.loads(line) for line in f]
68 | 
69 |     n = len(annotations)
70 |     n_keep_rationales = floor(n * keep_rationals_fraction)
71 |     indices_to_keep = random.sample(range(n), n_keep_rationales)
72 | 
73 |     def remove_rationales(annotation):
74 |         annotation['evidences'] = []
75 |         return annotation
76 | 
77 |     reduces_annotations = [
78 |         ann if i in indices_to_keep else remove_rationales(ann)
79 |         for i, ann in enumerate(annotations)
80 |     ]
81 | 
82 |     with open(output_path / file_name, 'w') as f:
83 |         dumped_annotations = map(json.dumps, reduces_annotations)
84 |         f.write('\n'.join(dumped_annotations))
85 | 
86 | def sample_dataset(dataset_dir: Path, output_path: Path, keep_rationals_fraction: float):
87 |     logger.info(f'Sampling for keep fraction {keep_rationals_fraction}')
88 | 
89 |     logger.info('Copying documents')
90 |     # copy documents
91 |     if (dataset_dir / 'docs').exists():
92 |         shutil.copytree(dataset_dir / 'docs', output_path / 'docs')
93 |     elif (dataset_dir / 'docs.jsonl').exists():
94 |         shutil.copy(dataset_dir / 'docs.jsonl', output_path / 'docs.jsonl')
95 |     else:
96 |         raise ValueError(f'No documents found in {dataset_dir}')
97 | 
98 |     logger.info(f'Copying unchanged splits {copy_splits}')
99 | 
100 |     # copy
101 |     for split in copy_splits:
102 |         copy_split(split, input_dir=dataset_dir, output_path=output_path)
103 | 
104 |     logger.info(f'Sampling and writing results for {reduce_splits}')
105 | 
106 |     # sample and copy
107 |     for split in reduce_splits:
108 |         sample_split(split, dataset_dir, output_path, keep_rationals_fraction)
109 | 
110 |     return output_path
111 | 
112 | def main(args):
113 |     parser = argparse.ArgumentParser(
114 |         'Takes an eraser dataset and samples a given fraction of the annotations to keep the rationales and removes the rest')
115 |     parser.add_argument('--dataset_dir', type=lambda x: is_valid_file(parser, x))
116 |     parser.add_argument('--prefix', type=str, default='')
117 |     parser.add_argument('--keep_rationals_fractions', nargs='+', type=float, choices=Range(0, 1))
118 | 
119 |     args = parser.parse_args(args)
120 | 
121 |     logger.info(
122 |         f'Running sampling process for fractions {args.keep_rationals_fractions} for dataset {args.dataset_dir}'
123 |     )
124 | 
125 |     keep_rationals_fractions = [1] + sorted(args.keep_rationals_fractions, reverse=True)
126 | 
127 |     relative_fractions = [keep_rationals_fractions[i] / keep_rationals_fractions[i - 1] for i in range(1, len(keep_rationals_fractions))]
128 |     input_dir = args.dataset_dir
129 | 
130 |     for fraction, rel_fraction in zip(keep_rationals_fractions[1:], relative_fractions):
131 |         output_path = create_output_folder(args.dataset_dir, fraction, args.prefix)
132 |         input_dir = sample_dataset(input_dir, output_path, rel_fraction)
133 | 
134 | 
135 | if __name__ == '__main__':
136 |     main(sys.argv[1:])
137 |
--------------------------------------------------------------------------------
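A usage sketch (not from the repository) for the sampler; the dataset path is a placeholder. For each fraction a sibling folder such as `sampled_movies_0.5` is created in which the train-split rationales are removed from all but that fraction of the annotations, while the `val`/`test` splits and the documents are copied unchanged.

```python
from sample_rationales_dataset import main

main([
    "--dataset_dir", "/path/to/eraser/movies",   # placeholder: an unpacked ERASER dataset
    "--prefix", "sampled",
    "--keep_rationals_fractions", "0.5", "0.1",  # each value must lie in [0, 1]
])
```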
/expred/preprocessing.py:
--------------------------------------------------------------------------------
1 | # preprocessing.py
2 | from itertools import chain
3 |
4 | from copy import deepcopy
5 | from expred.bert_rational_feature import InputRationalExample, convert_examples_to_features
6 | # from config import *
7 | from expred.utils import Evidence
8 | import logging
9 |
10 | IRRATIONAL = 0
11 | RATIONAL = 1
12 |
13 | logger = logging.getLogger('preprocessing.py')
14 | logger.setLevel(logging.INFO)
15 |
16 |
17 | def load_bert_features(data, docs, label_list, max_seq_length, merge_evidences, tokenizer):
18 |     input_examples = []
19 |     for ann in data:
20 |         text_a = ann.query
21 |         label = ann.classification
22 |         if not merge_evidences:
23 |             for ev_group in ann.evidences:
24 |                 doc_ids = list(set([ev.docid for ev in ev_group]))
25 |                 sentences = chain.from_iterable(docs[doc_id] for doc_id in doc_ids)
26 |                 flattened_tokens = chain(*sentences)
27 |                 text_b = ' '.join(flattened_tokens)
28 |                 evidences = ev_group
29 |                 input_examples.append(InputRationalExample(guid=None,
30 |                                                            text_a=text_a,
31 |                                                            text_b=text_b,
32 |                                                            label=label,
33 |                                                            evidences=evidences))
34 |         if merge_evidences:
35 |             docids_to_offsets = dict()
36 |             latest_offset = 0
37 |             example_evidences = []
38 |             text_b_tokens = []
39 |             for ev_group in ann.evidences:
40 |                 for ev in ev_group:
41 |                     if ev.docid in docids_to_offsets:
42 |                         offset = docids_to_offsets[ev.docid]
43 |                     else:
44 |                         tokens = list(chain.from_iterable(docs[ev.docid]))
45 |                         docids_to_offsets[ev.docid] = latest_offset
46 |                         offset = latest_offset
47 |                         latest_offset += len(tokens)
48 |                         text_b_tokens += tokens
49 |                     example_ev = Evidence(text=ev.text,
50 |                                           docid=ev.docid,
51 |                                           start_token=offset + ev.start_token,
52 |                                           end_token=offset + ev.end_token,
53 |                                           start_sentence=ev.start_sentence,
54 |                                           end_sentence=ev.end_sentence)
55 |                     example_evidences.append(deepcopy(example_ev))
56 |             input_examples.append(InputRationalExample(guid=None,
57 |                                                        text_b=' '.join(text_b_tokens),
58 |                                                        text_a=text_a,
59 |                                                        label=label,
60 |                                                        evidences=example_evidences))
61 |             # print(input_examples[-1].text_b, input_examples[-1].text_a, input_examples[-1].evi)
62 | 
63 |     features = convert_examples_to_features(input_examples, label_list, max_seq_length, tokenizer)
64 |     return features
65 |
66 |
67 | # def convert_bert_features(features, with_label_id, with_rations, exp_output='gru'):
68 | # feature_names = "input_ids input_mask segment_ids".split()
69 | #
70 | # input_ids, input_masks, segment_ids = \
71 | # list(map(lambda x: [getattr(f, x) for f in features], feature_names))
72 | #
73 | # rets = [input_ids, input_masks, segment_ids]
74 | #
75 | # if with_rations:
76 | # feature_names.append('rations')
77 | # rations = [getattr(f, 'rations') for f in features]
78 | # rations = np.array(rations).reshape([-1, MAX_SEQ_LENGTH, 1])
79 | # if exp_output == 'interval':
80 | # rations = np.concatenate([np.zeros((rations.shape[0], 1, 1)),
81 | # rations,
82 | # np.zeros((rations.shape[0], 1, 1))], axis=-2)
83 | # rations = rations[:, 1:, :] - rations[:, :-1, :]
84 | # rations_start = (rations > 0)[:, :-1, :].astype(np.int32)
85 | # rations_end = (rations < 0)[:, 1:, :].astype(np.int32)
86 | # rations = np.concatenate((rations_start, rations_end), axis=-1)
87 | # rets.append(rations)
88 | # else:
89 | # rets.append(None)
90 | #
91 | # if with_label_id:
92 | # feature_names.append('label_id')
93 | # label_id = [getattr(f, 'label_id') for f in features]
94 | # labels = np.array(label_id).reshape(-1, 1)
95 | # rets.append(labels)
96 | # else:
97 | # rets.append(None)
98 | # return rets
99 |
100 |
101 | # def preprocess(data, docs, label_list, dataset_name, max_seq_length, exp_output, merge_evidences, tokenizer):
102 | # features = load_bert_features(data, docs, label_list, max_seq_length, merge_evidences, tokenizer)
103 | #
104 | # with_rations = ('cls' not in dataset_name)
105 | # with_lable_id = ('seq' not in dataset_name)
106 | #
107 | # return convert_bert_features(features, with_lable_id, with_rations, exp_output)
108 |
--------------------------------------------------------------------------------
/expred/tokenizer.py:
--------------------------------------------------------------------------------
1 | from itertools import chain
2 | from typing import List, Dict, Tuple
3 |
4 | import os
5 | import torch
6 | from transformers import BertTokenizer, logger
7 |
8 | from expred.utils import Evidence, Annotation
9 |
10 |
11 | class BertTokenizerWithMapping(BertTokenizer):
12 | def __init__(self, *args, **kwargs):
13 | super(BertTokenizerWithMapping, self).__init__(*args, **kwargs)
14 |
15 | def tokenize_doc(self, doc: List[List[str]],
16 | special_token_map: Dict[str, int]) -> \
17 | Tuple[List[List[str]], List[List[Tuple[int, int]]]]:
18 | """ Tokenizes a document and returns [start, end) spans to map the wordpieces back to their source words"""
19 | sents = []
20 | sent_token_spans = []
21 | for sent in doc:
22 | tokens = []
23 | spans = []
24 | start = 0
25 | for w in sent:
26 | if w in special_token_map:
27 | tokens.append(w)
28 | else:
29 | tokens.extend(super(BertTokenizerWithMapping, self).tokenize(w))
30 | end = len(tokens)
31 | spans.append((start, end))
32 | start = end
33 | sents.append(tokens)
34 | sent_token_spans.append(spans)
35 | return sents, sent_token_spans
36 |
37 | def encode_doc(self,
38 | doc: List[List[str]],
39 | special_token_map) -> List[List[int]]:
40 | # return [list(chain.from_iterable(special_token_map.get(w, tokenizer.encode(w, add_special_tokens=False))
41 | # for w in s)) for s in doc]
42 | return [[special_token_map.get(w, self.convert_tokens_to_ids(w))
43 | for w in s]
44 | for s in doc]
45 |
46 | def _encode_docs_maybe_load_from_cache(self, documents, cache_fname):
47 | if os.path.exists(cache_fname):
48 | logger.info(f'Loading interned documents from {cache_fname}')
49 | (encoded_docs, encoded_doc_token_slides) = torch.load(cache_fname)
50 | else:
51 | tokenizer = self
52 | logger.info(f'Interning documents')
53 | special_token_map = {
54 | 'SEP': tokenizer.sep_token_id,
55 | '[SEP]': tokenizer.sep_token_id,
56 | '[sep]': tokenizer.sep_token_id,
57 | 'UNK': tokenizer.unk_token_id,
58 | '[UNK]': tokenizer.unk_token_id,
59 | '[unk]': tokenizer.unk_token_id,
60 | 'PAD': tokenizer.unk_token_id,
61 | '[PAD]': tokenizer.unk_token_id,
62 | '[pad]': tokenizer.unk_token_id,
63 | }
64 | encoded_docs = {}
65 | encoded_doc_token_slides = {}
66 | for d, doc in documents.items():
67 | tokenized_doc, w_slices = self.tokenize_doc(doc, special_token_map=special_token_map)
68 | encoded_docs[d] = self.encode_doc(tokenized_doc, special_token_map=special_token_map)
69 | encoded_doc_token_slides[d] = w_slices
70 | torch.save((encoded_docs, encoded_doc_token_slides), cache_fname)
71 | return encoded_docs, encoded_doc_token_slides
72 |
73 | def encode_docs(self, documents, cache_dir):
74 | cache_fname = os.path.join(cache_dir, 'preprocessed.pkl')
75 | encoded_docs, encoded_doc_token_slides = self._encode_docs_maybe_load_from_cache(documents, cache_fname)
76 | return encoded_docs, encoded_doc_token_slides
77 |
78 | def encode_annotations(self, annotations):
79 | ret = []
80 | for ann in annotations:
81 | ev_groups = []
82 | for ev_group in ann.evidences:
83 | evs = []
84 | for ev in ev_group:
85 | text = list(chain.from_iterable(self.tokenize(w)
86 | for w in ev.text.split()))
87 | if len(text) == 0:
88 | continue
89 | text = self.encode(text, add_special_tokens=False)
90 | evs.append(Evidence(text=tuple(text),
91 | docid=ev.docid,
92 | start_token=ev.start_token,
93 | end_token=ev.end_token,
94 | start_sentence=ev.start_sentence,
95 | end_sentence=ev.end_sentence))
96 | ev_groups.append(tuple(evs))
97 | query = list(chain.from_iterable(self.tokenize(w)
98 | for w in ann.query.split()))
99 | if len(query) > 0:
100 | query = self.encode(query, add_special_tokens=False)
101 | else:
102 | query = []
103 | ret.append(Annotation(annotation_id=ann.annotation_id,
104 | query=tuple(query),
105 | evidences=frozenset(ev_groups),
106 | classification=ann.classification,
107 | query_type=ann.query_type,
108 | docids=ann.docids))
109 | return ret
110 |
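111 | 
112 | # Usage sketch (added for illustration; not part of the original module). It
113 | # assumes the standard 'bert-base-uncased' vocabulary, which is downloaded on
114 | # first use.
115 | if __name__ == '__main__':
116 |     demo_tokenizer = BertTokenizerWithMapping.from_pretrained('bert-base-uncased')
117 |     demo_doc = [['The', 'film', 'is', 'unwatchable', '.']]
118 |     pieces, spans = demo_tokenizer.tokenize_doc(demo_doc, special_token_map={})
119 |     # pieces[0] holds the wordpieces; spans[0] maps every source word to its
120 |     # [start, end) wordpiece range (multi-piece words cover more than one index).
121 |     ids = demo_tokenizer.encode_doc(pieces, special_token_map={})
122 |     print(pieces[0], spans[0], ids[0])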
--------------------------------------------------------------------------------
/expred/models/model_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Dict, List, Set
3 |
4 | import numpy as np
5 | from gensim.models import KeyedVectors
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn.utils.rnn import pad_sequence, PackedSequence, pack_padded_sequence, pad_packed_sequence
10 |
11 |
12 | @dataclass(eq=True, frozen=True)
13 | class PaddedSequence:
14 |     """A utility class for padding variable length sequences meant for RNN input.
15 | This class is in the style of PackedSequence from the PyTorch RNN Utils,
16 | but is somewhat more manual in approach. It provides the ability to generate masks
17 | for outputs of the same input dimensions.
18 | The constructor should never be called directly and should only be called via
19 | the autopad classmethod.
20 |
21 |     We'd love to delete this, but pad_sequence, pack_padded_sequence, and
22 | pad_packed_sequence all require shuffling around tuples of information, and some
23 | convenience methods using these are nice to have.
24 | """
25 |
26 | data: torch.Tensor
27 | batch_sizes: torch.Tensor
28 | batch_first: bool = False
29 |
30 | @classmethod
31 | def autopad(cls, data, batch_first: bool = False, padding_value=0, device=None) -> 'PaddedSequence':
32 | # handle tensors of size 0 (single item)
33 | data_ = []
34 | for d in data:
35 | if len(d.size()) == 0:
36 | d = d.unsqueeze(0)
37 | data_.append(d)
38 | padded = pad_sequence(data_, batch_first=batch_first, padding_value=padding_value)
39 | if batch_first:
40 | batch_lengths = torch.LongTensor([len(x) for x in data_])
41 | if any([x == 0 for x in batch_lengths]):
42 | raise ValueError(
43 | "Found a 0 length batch element, this can't possibly be right: {}".format(batch_lengths))
44 | else:
45 | # TODO actually test this codepath
46 | batch_lengths = torch.LongTensor([len(x) for x in data])
47 | return PaddedSequence(padded, batch_lengths, batch_first).to(device=device)
48 |
49 | #@classmethod
50 | #def autopad(cls, data, len_queries, max_length, batch_first, device):
51 |
52 |
53 | def pack_other(self, data: torch.Tensor):
54 | return pack_padded_sequence(data, self.batch_sizes, batch_first=self.batch_first, enforce_sorted=False)
55 |
56 | @classmethod
57 | def from_packed_sequence(cls, ps: PackedSequence, batch_first: bool, padding_value=0) -> 'PaddedSequence':
58 | padded, batch_sizes = pad_packed_sequence(ps, batch_first, padding_value)
59 | return PaddedSequence(padded, batch_sizes, batch_first)
60 |
61 | def cuda(self) -> 'PaddedSequence':
62 | return PaddedSequence(self.data.cuda(), self.batch_sizes.cuda(), batch_first=self.batch_first)
63 |
64 | def to(self, dtype=None, device=None, copy=False, non_blocking=False) -> 'PaddedSequence':
65 | # TODO make to() support all of the torch.Tensor to() variants
66 | return PaddedSequence(
67 | self.data.to(dtype=dtype, device=device, copy=copy, non_blocking=non_blocking),
68 | self.batch_sizes.to(device=device, copy=copy, non_blocking=non_blocking),
69 | batch_first=self.batch_first)
70 |
71 | def mask(self, on=int(0), off=int(0), device='cpu', size=None, dtype=None) -> torch.Tensor:
72 | if size is None:
73 | size = self.data.size()
74 | out_tensor = torch.zeros(*size, dtype=dtype)
75 | # TODO this can be done more efficiently
76 | out_tensor.fill_(off)
77 |         # note to self: these are probably less efficient than explicitly populating the off values instead of the on values.
78 | if self.batch_first:
79 | for i, bl in enumerate(self.batch_sizes):
80 | out_tensor[i, :bl] = on
81 | else:
82 | for i, bl in enumerate(self.batch_sizes):
83 | out_tensor[:bl, i] = on
84 | return out_tensor.to(device)
85 |
86 | def unpad(self, other: torch.Tensor) -> List[torch.Tensor]:
87 | out = []
88 | for o, bl in zip(other, self.batch_sizes):
89 | out.append(torch.cat((o[:bl], torch.zeros(max(0, bl-len(o))))))
90 | return out
91 |
92 | def flip(self) -> 'PaddedSequence':
93 |         return PaddedSequence(self.data.transpose(0, 1), self.batch_sizes, not self.batch_first)
94 |
95 |
96 | def extract_embeddings(vocab: Set[str], embedding_file: str, unk_token: str = 'UNK', pad_token: str = 'PAD') -> (
97 | nn.Embedding, Dict[str, int], List[str]):
98 | vocab = vocab | set([unk_token, pad_token])
99 | if embedding_file.endswith('.bin'):
100 | WVs = KeyedVectors.load_word2vec_format(embedding_file, binary=True)
101 |
102 | word_to_vector = dict()
103 | WV_matrix = np.matrix([WVs[v] for v in WVs.vocab.keys()])
104 |
105 | if unk_token not in WVs:
106 | mean_vector = np.mean(WV_matrix, axis=0)
107 | word_to_vector[unk_token] = mean_vector
108 | if pad_token not in WVs:
109 | word_to_vector[pad_token] = np.zeros(WVs.vector_size)
110 |
111 | for v in vocab:
112 | if v in WVs:
113 | word_to_vector[v] = WVs[v]
114 |
115 | interner = dict()
116 | deinterner = list()
117 | vectors = []
118 | count = 0
119 | for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})):
120 | vector = word_to_vector[word]
121 | vectors.append(np.array(vector))
122 | interner[word] = count
123 | deinterner.append(word)
124 | count += 1
125 | vectors = torch.FloatTensor(np.array(vectors))
126 | embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token])
127 | embedding.weight.requires_grad = False
128 | return embedding, interner, deinterner
129 | elif embedding_file.endswith('.txt'):
130 | word_to_vector = dict()
131 | vector = []
132 | with open(embedding_file, 'r') as inf:
133 | for line in inf:
134 | contents = line.strip().split()
135 | word = contents[0]
136 | vector = torch.tensor([float(v) for v in contents[1:]]).unsqueeze(0)
137 | word_to_vector[word] = vector
138 | embed_size = vector.size()
139 | if unk_token not in word_to_vector:
140 | mean_vector = torch.cat(list(word_to_vector.values()), dim=0).mean(dim=0)
141 | word_to_vector[unk_token] = mean_vector.unsqueeze(0)
142 | if pad_token not in word_to_vector:
143 | word_to_vector[pad_token] = torch.zeros(embed_size)
144 | interner = dict()
145 | deinterner = list()
146 | vectors = []
147 | count = 0
148 | for word in [pad_token, unk_token] + sorted(list(word_to_vector.keys() - {unk_token, pad_token})):
149 | vector = word_to_vector[word]
150 | vectors.append(vector)
151 | interner[word] = count
152 | deinterner.append(word)
153 | count += 1
154 | vectors = torch.cat(vectors, dim=0)
155 | embedding = nn.Embedding.from_pretrained(vectors, padding_idx=interner[pad_token])
156 | embedding.weight.requires_grad = False
157 | return embedding, interner, deinterner
158 | else:
159 | raise ValueError("Unable to open embeddings file {}".format(embedding_file))
160 |
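161 | 
162 | # Minimal usage sketch for PaddedSequence (added for illustration; not part of
163 | # the original module):
164 | if __name__ == '__main__':
165 |     seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
166 |     padded = PaddedSequence.autopad(seqs, batch_first=True, padding_value=0)
167 |     print(padded.data)         # tensor([[1, 2, 3], [4, 5, 0]])
168 |     print(padded.batch_sizes)  # tensor([3, 2])
169 |     print(padded.mask(on=1, off=0, dtype=torch.long))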
--------------------------------------------------------------------------------
/expred/eraser_benchmark.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from itertools import chain
4 | from copy import deepcopy
5 |
6 | from expred.bert_rational_feature import InputRationalExample, convert_examples_to_features
7 | from expred.eraser_utils import extract_doc_ids_from_annotations
8 | from expred.utils import Annotation
9 |
10 |
11 | def remove_rations(sentence, annotation):
12 | sentence = sentence.lower().split()
13 | rationales = annotation['rationales'][0]['hard_rationale_predictions']
14 | rationales = [{'end_token': 0, 'start_token': 0}] \
15 | + sorted(rationales, key=lambda x: x['start_token']) \
16 | + [{'start_token': len(sentence), 'end_token': len(sentence)}]
17 | ret = []
18 | for rat_id, rat in enumerate(rationales[:-1]):
19 | ret += ['.'] * (rat['end_token'] - rat['start_token']) \
20 | + sentence[rat['end_token']
21 | : rationales[rat_id + 1]['start_token']]
22 | return ' '.join(ret)
23 |
24 |
25 | def extract_rations(sentence, rationale):
26 | sentence = sentence.lower().split()
27 | rationales = rationale['rationales'][0]['hard_rationale_predictions']
28 | rationales = [{'end_token': 0, 'start_token': 0}] \
29 | + sorted(rationales, key=lambda x: x['start_token']) \
30 | + [{'start_token': len(sentence), 'end_token': len(sentence)}]
31 | ret = []
32 | for rat_id, rat in enumerate(rationales[:-1]):
33 | ret += sentence[rat['start_token']: rat['end_token']] \
34 | + ['.'] * (rationales[rat_id + 1]
35 | ['start_token'] - rat['end_token'])
36 | return ' '.join(ret)
37 |
38 |
39 | def ce_load_bert_features(rationales, docs, label_list, decorate, max_seq_length, gpu_id, tokenizer=None):
40 | input_examples = []
41 | for r_idx, rational in enumerate(rationales):
42 | text_a = rational['query']
43 | docids = rational['docids']
44 | sentences = chain.from_iterable(docs[docid] for docid in docids)
45 | flattened_tokens = chain(*sentences)
46 | text_b = ' '.join(flattened_tokens)
47 | text_b = decorate(text_b, rational)
48 | label = rational['classification']
49 | evidences = None
50 | input_examples.append(InputRationalExample(guid=None,
51 | text_a=text_a,
52 | text_b=text_b,
53 | label=label,
54 | evidences=evidences))
55 | features = convert_examples_to_features(input_examples, label_list, max_seq_length, tokenizer)
56 | return features
57 |
58 |
59 | # def ce_preprocess(rationales, docs, label_list, dataset_name, decorate, max_seq_length, exp_output, gpu_id, tokenizer):
60 | # features = ce_load_bert_features(rationales, docs, label_list, decorate, max_seq_length, gpu_id, tokenizer)
61 | #
62 | # with_rations = ('cls' not in dataset_name)
63 | # with_lable_id = ('seq' not in dataset_name)
64 | #
65 | # return convert_bert_features(features, with_lable_id, with_rations, exp_output)
66 |
67 |
68 | # def get_cls_score(model, rationales, docs, label_list, dataset, decorate, max_seq_length, exp_output, gpu_id, tokenizer):
69 | # rets = ce_preprocess(rationales, docs, label_list, dataset, decorate, max_seq_length, exp_output, gpu_id, tokenizer)
70 | # _input_ids, _input_masks, _segment_ids, _rations, _labels = rets
71 | #
72 | # _inputs = [_input_ids, _input_masks, _segment_ids]
73 | # _pred = model.predict(_inputs)
74 | # return (np.hstack([1 - _pred[0], _pred[0]]))
75 |
76 |
77 | def add_cls_scores(res, cls, c, s, label_list):
78 | res['classification_scores'] = {label_list[0]: cls[0], label_list[1]: cls[1]}
79 | res['comprehensiveness_classification_scores'] = {label_list[0]: c[0], label_list[1]: c[1]}
80 | res['sufficiency_classification_scores'] = {label_list[0]: s[0], label_list[1]: s[1]}
81 | return res
82 |
83 |
84 | def pred_to_exp_mask(exp_pred, count, threshold):
85 | if count is None:
86 |         return (np.array(exp_pred).astype(float) >= threshold).astype(np.int32)
87 | temp = [(i, p) for i, p in enumerate(exp_pred)]
88 | temp = sorted(temp, key=lambda x: x[1], reverse=True)
89 | ret = np.zeros_like(exp_pred).astype(np.int32)
90 | for i, _ in temp[:count]:
91 | ret[i] = 1
92 | return ret
93 |
94 |
95 | def rational_bits_to_ev_generator(token_list, raw_input_or_docid, exp_pred, hard_selection_count=None,
96 | hard_selection_threshold=0.5):
97 | in_rationale = False
98 | if not isinstance(raw_input_or_docid, Annotation):
99 | docid = raw_input_or_docid
100 | else:
101 | docid = list(extract_doc_ids_from_annotations([raw_input_or_docid]))[0]
102 | ev = {'docid': docid,
103 | 'start_token': -1, 'end_token': -1, 'text': ''}
104 | exp_masks = pred_to_exp_mask(
105 | exp_pred, hard_selection_count, hard_selection_threshold)
106 | for i, p in enumerate(exp_masks):
107 | if p == 0 and in_rationale: # leave rational zone
108 | in_rationale = False
109 | ev['end_token'] = i
110 | ev['text'] = ' '.join(
111 | token_list[ev['start_token']: ev['end_token']])
112 | yield deepcopy(ev)
113 | elif p == 1 and not in_rationale: # enter rational zone
114 | in_rationale = True
115 | ev['start_token'] = i
116 | if in_rationale: # the final non-padding token is rational
117 | ev['end_token'] = len(exp_pred)
118 | ev['text'] = ' '.join(token_list[ev['start_token']: ev['end_token']])
119 | yield deepcopy(ev)
120 |
121 |
122 | # [SEP] == 102
123 | # [CLS] == 101
124 | # [PAD] == 0
125 | def extract_texts(tokens, exps=None, text_a=True, text_b=False):
126 | if tokens[0] == 101:
127 | endp_text_a = tokens.index(102)
128 | if text_b:
129 | endp_text_b = endp_text_a + 1 + \
130 | tokens[endp_text_a + 1:].index(102)
131 | else:
132 | endp_text_a = tokens.index('[SEP]')
133 | if text_b:
134 | endp_text_b = endp_text_a + 1 + \
135 | tokens[endp_text_a + 1:].index('[SEP]')
136 | ret_token = []
137 | if text_a:
138 | ret_token += tokens[1: endp_text_a]
139 | if text_b:
140 | ret_token += tokens[endp_text_a + 1: endp_text_b]
141 | if exps is None:
142 | return ret_token
143 | else:
144 | ret_exps = []
145 | if text_a:
146 | ret_exps += exps[1: endp_text_a]
147 | if text_b:
148 | ret_exps += exps[endp_text_a + 1: endp_text_b]
149 | return ret_token, ret_exps
150 |
151 |
152 | def rnr_matrix_to_rational_mask(rnr_matrix):
153 | start_logits, end_logits = rnr_matrix[:, :1], rnr_matrix[:, 1:]
154 | starts = np.round(start_logits).reshape((-1, 1))
155 | ends = np.triu(end_logits)
156 | ends = starts * ends
157 | ends_args = np.argmax(ends, axis=1)
158 | ends = np.zeros_like(ends)
159 | for i in range(len(ends_args)):
160 | ends[i, ends_args[i]] = 1
161 | ends = starts * ends
162 | ends = np.sum(ends, axis=0, keepdims=True)
163 | rational_mask = np.cumsum(starts.reshape((1, -1)), axis=1) - np.cumsum(ends, axis=1) + ends
164 | return rational_mask
165 |
166 |
167 | # def pred_to_results(raw_input, input_ids, pred,
168 | # hard_selection_count, hard_selection_threshold,
169 | # vocab, docs, label_list,
170 | # exp_output):
171 | # cls_pred, exp_pred = pred
172 | # if exp_output == 'rnr':
173 | # exp_pred = rnr_matrix_to_rational_mask(exp_pred)
174 | # exp_pred = exp_pred.reshape((-1,)).tolist()
175 | # docid = list(raw_input.evidences)[0][0].docid
176 | # raw_sentence = ' '.join(list(chain.from_iterable(docs[docid])))
177 | # raw_sentence = re.sub('\x12', '', raw_sentence)
178 | # raw_sentence = raw_sentence.lower().split()
179 | # token_ids, exp_pred = extract_texts(input_ids, exp_pred, text_a=False, text_b=True)
180 | # token_list, exp_pred = convert_subtoken_ids_to_tokens(token_ids, vocab, exp_pred, raw_sentence)
181 | # result = {'annotation_id': raw_input.annotation_id, 'query': raw_input.query}
182 | # ev_groups = []
183 | # result['docids'] = [docid]
184 | # result['rationales'] = [{'docid': docid}]
185 | # for ev in rational_bits_to_ev_generator(token_list, raw_input, exp_pred, hard_selection_count,
186 | # hard_selection_threshold):
187 | # ev_groups.append(ev)
188 | # result['rationales'][-1]['hard_rationale_predictions'] = ev_groups
189 | # if exp_output != 'rnr':
190 | # result['rationales'][-1]['soft_rationale_predictions'] = exp_pred + [0] * (len(raw_sentence) - len(token_list))
191 | # result['classification'] = label_list[int(round(cls_pred[0]))]
192 | # return result
193 |
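194 | 
195 | # Usage sketch (added for illustration; not part of the original module): turn
196 | # per-token soft rationale scores into hard evidence spans. The token list and
197 | # docid are made up.
198 | if __name__ == '__main__':
199 |     demo_tokens = 'the film is great fun'.split()
200 |     soft_scores = [0.1, 0.2, 0.9, 0.8, 0.3]
201 |     spans = list(rational_bits_to_ev_generator(demo_tokens, 'doc_1', soft_scores,
202 |                                                hard_selection_threshold=0.5))
203 |     # yields [{'docid': 'doc_1', 'start_token': 2, 'end_token': 4, 'text': 'is great'}]
204 |     print(spans)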
--------------------------------------------------------------------------------
/expred/models/mlp_mtl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import BertModel, BertTokenizer
5 | from expred.params import MTLParams
6 | from typing import Any, List
7 | from expred.models.model_utils import PaddedSequence
8 |
9 |
10 | class BertMTL(nn.Module):
11 | def __init__(self,
12 | bert_dir: str,
13 | tokenizer: BertTokenizer,
14 | mtl_params: MTLParams,
15 | max_length: int=512,
16 | use_half_precision=True):
17 | super(BertMTL, self).__init__()
18 | bare_bert = BertModel.from_pretrained(bert_dir)
19 | if use_half_precision:
20 | import apex
21 | bare_bert = bare_bert.half()
22 | self.bare_bert = bare_bert
23 | self.pad_token_id = tokenizer.pad_token_id
24 | self.cls_token_id = tokenizer.cls_token_id
25 | self.sep_token_id = tokenizer.sep_token_id
26 | self.max_length = max_length
27 |
28 | class ExpHead(nn.Module):
29 | def __init__(self, input_size, hidden_size):
30 | super(ExpHead, self).__init__()
31 | self.exp_gru = nn.GRU(input_size, hidden_size)
32 | self.exp_linear = nn.Linear(hidden_size, 1, bias=True)
33 | self.exp_act = nn.Sigmoid()
34 |
35 | def forward(self, x):
36 | return self.exp_act(self.exp_linear(self.exp_gru(x)[0]))
37 |
38 | #self.exp_head = lambda x: exp_act(exp_linear(exp_gru(x)[0]))
39 | self.exp_head = ExpHead(self.bare_bert.config.hidden_size, mtl_params.dim_exp_gru)
40 | self.cls_head = nn.Sequential(
41 | nn.Dropout(0.1),
42 | nn.Linear(self.bare_bert.config.hidden_size, mtl_params.dim_cls_linear, bias=True),
43 | nn.Tanh(),
44 | nn.Linear(mtl_params.dim_cls_linear, mtl_params.num_labels, bias=True),
45 | nn.Softmax(dim=-1)
46 | )
47 | for layer in self.cls_head:
48 | if type(layer) == nn.Linear:
49 | nn.init.xavier_normal_(layer.weight)
50 | nn.init.xavier_normal_(self.exp_head.exp_linear.weight)
51 |
52 | def forward(self,
53 | query: List[torch.tensor],
54 | docids: List[Any],
55 | document_batch: List[torch.tensor]):
56 | #input_ids, token_type_ids=None, attention_mask=None, labels=None):
57 | assert len(query) == len(document_batch)
58 | #print(next(self.cls_head.parameters()).device)
59 | target_device = next(self.parameters()).device
60 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
61 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
62 | input_tensors = []
63 | for q, d in zip(query, document_batch):
64 | if len(q) + len(d) + 2 > self.max_length:
65 | d = d[:(self.max_length - len(q) - 2)]
66 | input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
67 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id,
68 | device=target_device)
69 | attention_mask = bert_input.mask(on=1., off=0., device=target_device)
70 | exp_output, cls_output = self.bare_bert(bert_input.data, attention_mask=attention_mask)
71 | exp_output = self.exp_head(exp_output).squeeze() * attention_mask
72 | cls_output = self.cls_head(cls_output)
73 | assert torch.all(cls_output == cls_output)
74 | assert torch.all(exp_output == exp_output)
75 | return cls_output, exp_output, attention_mask
76 |
77 |
78 | class BertClassifier(nn.Module):
79 | """Thin wrapper around BertForSequenceClassification"""
80 | def __init__(self,
81 | bert_dir: str,
82 | pad_token_id: int,
83 | cls_token_id: int,
84 | sep_token_id: int,
85 | num_labels: int,
86 | mtl_params: MTLParams,
87 | max_length: int=512,
88 | use_half_precision=True):
89 | super(BertClassifier, self).__init__()
90 | bert = BertModel.from_pretrained(bert_dir, num_labels=num_labels)
91 | if use_half_precision:
92 | import apex
93 | bert = bert.half()
94 | self.bert = bert
95 | self.cls_head = nn.Sequential(
96 | nn.Dropout(0.1),
97 | nn.Linear(bert.config.hidden_size, mtl_params.dim_cls_linear, bias=True),
98 | nn.Tanh(),
99 | nn.Linear(mtl_params.dim_cls_linear, mtl_params.num_labels, bias=True),
100 | nn.Softmax(dim=-1)
101 | )
102 | for layer in self.cls_head:
103 | if type(layer) == nn.Linear:
104 | nn.init.xavier_normal_(layer.weight)
105 | self.pad_token_id = pad_token_id
106 | self.cls_token_id = cls_token_id
107 | self.sep_token_id = sep_token_id
108 | self.max_length = max_length
109 |
110 | def forward(self,
111 | query: List[torch.tensor],
112 | docids: List[Any],
113 | document_batch: List[torch.tensor]):
114 | assert len(query) == len(document_batch)
115 | # note about device management:
116 | # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
117 |         # we want to keep these tensors on the input device (assuming CPU) for as long as possible for cheap memory access
118 | target_device = next(self.parameters()).device
119 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
120 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
121 | input_tensors = []
122 | position_ids = []
123 | for q, d in zip(query, document_batch):
124 | if len(q) + len(d) + 2 > self.max_length:
125 | d = d[:(self.max_length - len(q) - 2)]
126 | input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
127 | position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
128 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device)
129 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
130 | _, classes = self.bert(bert_input.data, attention_mask=bert_input.mask(on=1., off=0., device=target_device), position_ids=positions.data)
131 | classes = self.cls_head(classes)
132 | assert torch.all(classes == classes) # for nans
133 | return classes
134 |
135 |
136 | class BertClassifier2(nn.Module):
137 | """Thin wrapper around BertForSequenceClassification"""
138 | def __init__(self,
139 | bert_dir: str,
140 | pad_token_id: int,
141 | cls_token_id: int,
142 | sep_token_id: int,
143 | num_labels: int,
144 | mtl_params: MTLParams,
145 | max_length: int=512,
146 | use_half_precision=True):
147 | super(BertClassifier2, self).__init__()
148 | bert = BertModel.from_pretrained(bert_dir, num_labels=num_labels)
149 | if use_half_precision:
150 | import apex
151 | bert = bert.half()
152 | self.bert = bert
153 | self.cls_head = nn.Sequential(
154 | #nn.Dropout(0.1),
155 | nn.Linear(bert.config.hidden_size, mtl_params.dim_cls_linear, bias=True),
156 | nn.Tanh(),
157 | nn.Linear(mtl_params.dim_cls_linear, mtl_params.num_labels, bias=True),
158 | nn.Softmax(dim=-1)
159 | )
160 | for layer in self.cls_head:
161 | if type(layer) == nn.Linear:
162 | nn.init.xavier_normal_(layer.weight)
163 | self.pad_token_id = pad_token_id
164 | self.cls_token_id = cls_token_id
165 | self.sep_token_id = sep_token_id
166 | self.max_length = max_length
167 |
168 | def forward(self,
169 | query: List[torch.tensor],
170 | docids: List[Any],
171 | document_batch: List[torch.tensor]):
172 | assert len(query) == len(document_batch)
173 | # note about device management:
174 | # since distributed training is enabled, the inputs to this module can be on *any* device (preferably cpu, since we wrap and unwrap the module)
175 |         # we want to keep these tensors on the input device (assuming CPU) for as long as possible for cheap memory access
176 | target_device = next(self.parameters()).device
177 | cls_token = torch.tensor([self.cls_token_id]).to(device=document_batch[0].device)
178 | sep_token = torch.tensor([self.sep_token_id]).to(device=document_batch[0].device)
179 | input_tensors = []
180 | position_ids = []
181 | for q, d in zip(query, document_batch):
182 | if len(q) + len(d) + 2 > self.max_length:
183 | d = d[:(self.max_length - len(q) - 2)]
184 | input_tensors.append(torch.cat([cls_token, q, sep_token, d]))
185 | position_ids.append(torch.tensor(list(range(0, len(q) + 1)) + list(range(0, len(d) + 1))))
186 | bert_input = PaddedSequence.autopad(input_tensors, batch_first=True, padding_value=self.pad_token_id, device=target_device)
187 | positions = PaddedSequence.autopad(position_ids, batch_first=True, padding_value=0, device=target_device)
188 | _, classes = self.bert(bert_input.data, attention_mask=bert_input.mask(on=1., off=0., device=target_device), position_ids=positions.data)
189 | classes = self.cls_head(classes)
190 | assert torch.all(classes == classes) # for nans
191 | return classes
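192 | 
193 | 
194 | # Wiring sketch (added for illustration; not part of the original module). The
195 | # checkpoint name, head dimensions and example texts are assumptions rather
196 | # than values from the shipped configs; half precision is disabled so that
197 | # apex is not required.
198 | if __name__ == '__main__':
199 |     demo_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
200 |     demo_params = MTLParams()
201 |     demo_params.num_labels = 2
202 |     demo_params.dim_cls_linear = 256
203 |     demo_params.dim_exp_gru = 128
204 |     demo_model = BertMTL(bert_dir='bert-base-uncased',
205 |                          tokenizer=demo_tokenizer,
206 |                          mtl_params=demo_params,
207 |                          use_half_precision=False)
208 |     query = [torch.tensor(demo_tokenizer.encode('what is the sentiment ?', add_special_tokens=False))]
209 |     document = [torch.tensor(demo_tokenizer.encode('a thoroughly enjoyable film', add_special_tokens=False))]
210 |     cls_out, exp_out, mask = demo_model(query, docids=None, document_batch=document)
211 |     print(cls_out.shape, exp_out.shape)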
--------------------------------------------------------------------------------
/expred/models/pipeline/mtl_evidence_classifier.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import random
4 |
5 | from collections import OrderedDict
6 |
7 | import wandb
8 | from typing import Dict, List, Tuple, Any
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from sklearn.metrics import accuracy_score, classification_report
14 |
15 | # from expred.utils import Annotation
16 |
17 | from expred.models.pipeline.pipeline_utils import SentenceEvidence
18 | from expred.models.pipeline.mtl_pipeline_utils import (
19 | mask_annotations_to_evidence_classification,
20 | make_mtl_classification_preds_epoch
21 | )
22 |
23 |
24 | def train_mtl_evidence_classifier(evidence_classifier: nn.Module,
25 | save_dir: str,
26 | train: Tuple[List[Tuple[str, SentenceEvidence]], Any],
27 | val: Tuple[List[Tuple[str, SentenceEvidence]], Any],
28 | documents: Dict[str, List[List[int]]],
29 | model_pars: dict,
30 | class_interner: Dict[str, int],
31 | optimizer=None,
32 | scheduler=None,
33 | tensorize_model_inputs: bool = True) -> Tuple[nn.Module, dict]:
34 | """
35 |     Trains the evidence classifier on rationale-masked data and keeps the checkpoint with the best validation loss.
36 |     :param evidence_classifier: the classifier module to train
37 |     :param save_dir: directory where checkpoints and epoch data are written
38 |     :param train: machine-annotated training data
39 |     :param val: machine-annotated validation data, used for model selection
40 |     :param documents: mapping from docid to interned (token-id) sentences
41 |     :param model_pars: config dict; only the 'evidence_classifier' section is read here
42 |     :param class_interner: mapping from class name to class index
43 |     :param optimizer: optional optimizer; defaults to Adam over the classifier parameters
44 |     :param scheduler: optional learning-rate scheduler, stepped once per batch
45 |     :param tensorize_model_inputs: whether to convert queries and sentences to tensors first
46 |     :return: the trained classifier (best validation checkpoint restored) and the results dict
47 | """
48 | logging.info(
49 | f'Beginning training evidence classifier with {len(train[0])} annotations, {len(val[0])} for validation')
50 | # set up output directories
51 | evidence_classifier_output_dir = os.path.join(save_dir, 'evidence_classifier')
52 | os.makedirs(save_dir, exist_ok=True)
53 | os.makedirs(evidence_classifier_output_dir, exist_ok=True)
54 | model_save_file = os.path.join(evidence_classifier_output_dir, 'evidence_classifier.pt')
55 | epoch_save_file = os.path.join(evidence_classifier_output_dir, 'evidence_classifier_epoch_data.pt')
56 |
57 | # set up training (optimizer, loss, patience, ...)
58 | device = next(evidence_classifier.parameters()).device
59 | if optimizer is None:
60 | optimizer = torch.optim.Adam(evidence_classifier.parameters(), lr=model_pars['evidence_classifier']['lr'])
61 | criterion = nn.BCELoss(reduction='none')
62 | batch_size = model_pars['evidence_classifier']['batch_size']
63 | epochs = model_pars['evidence_classifier']['epochs']
64 | patience = model_pars['evidence_classifier']['patience']
65 | max_grad_norm = model_pars['evidence_classifier'].get('max_grad_norm', None)
66 |
67 | # mask out the hard prediction (token 0) and convert to [SentenceEvidence...]
68 | evidence_train_data = mask_annotations_to_evidence_classification(train, class_interner)
69 | evidence_val_data = mask_annotations_to_evidence_classification(val, class_interner)
70 |
71 | class_labels = [k for k, v in sorted(class_interner.items())]
72 |
73 | results = {
74 | 'train_loss': [],
75 | 'train_f1': [],
76 | 'train_acc': [],
77 | 'val_loss': [],
78 | 'val_f1': [],
79 | 'val_acc': [],
80 | }
81 | best_epoch = -1
82 | best_val_loss = float('inf')
83 | best_model_state_dict = None
84 | start_epoch = 0
85 | epoch_data = {}
86 | if os.path.exists(epoch_save_file):
87 | logging.info(f'Restoring model from {model_save_file}')
88 | evidence_classifier.load_state_dict(torch.load(model_save_file))
89 | epoch_data = torch.load(epoch_save_file)
90 | start_epoch = epoch_data['epoch'] + 1
91 | # handle finishing because patience was exceeded or we didn't get the best final epoch
92 | if bool(epoch_data.get('done', 0)):
93 | start_epoch = epochs
94 | results = epoch_data['results']
95 | best_epoch = start_epoch
96 | best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_classifier.state_dict().items()})
97 | logging.info(f'Restoring training from epoch {start_epoch}')
98 | logging.info(f'Training evidence classifier from epoch {start_epoch} until epoch {epochs}')
99 | optimizer.zero_grad()
100 | for epoch in range(start_epoch, epochs):
101 | epoch_train_data = random.sample(evidence_train_data, k=len(evidence_train_data))
102 | epoch_val_data = random.sample(evidence_val_data, k=len(evidence_val_data))
103 | epoch_train_loss = 0
104 | evidence_classifier.train()
105 | logging.info(
106 | f'Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples')
107 | for batch_start in range(0, len(epoch_train_data), batch_size):
108 | batch_elements = epoch_train_data[batch_start:min(batch_start + batch_size, len(epoch_train_data))]
109 | targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements])
110 | ids = [(s.ann_id, s.docid, s.index) for s in batch_elements]
111 | targets = [[i == target for i in range(len(class_interner))] for target in targets]
112 | targets = torch.tensor(targets, dtype=torch.float, device=device)
113 | if tensorize_model_inputs:
114 | queries = [torch.tensor(q, dtype=torch.long) for q in queries]
115 | sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
116 | preds = evidence_classifier(queries, ids, sentences)
117 | loss = criterion(preds, targets.to(device=preds.device)).sum()
118 | epoch_train_loss += loss.item()
119 | loss = loss / len(preds) # accumulate entire loss above
120 | loss.backward()
121 | assert loss == loss # for nans
122 | if max_grad_norm:
123 | torch.nn.utils.clip_grad_norm_(evidence_classifier.parameters(), max_grad_norm)
124 | optimizer.step()
125 | if scheduler:
126 | scheduler.step()
127 | optimizer.zero_grad()
128 | epoch_train_loss /= len(epoch_train_data)
129 | assert epoch_train_loss == epoch_train_loss # for nans
130 | results['train_loss'].append(epoch_train_loss)
131 | logging.info(f'Epoch {epoch} training loss {epoch_train_loss}')
132 |
133 | with torch.no_grad():
134 | evidence_classifier.eval()
135 | epoch_train_loss, \
136 | epoch_train_soft_pred, \
137 | epoch_train_hard_pred, \
138 | epoch_train_truth = make_mtl_classification_preds_epoch(
139 | classifier=evidence_classifier,
140 | data=epoch_train_data,
141 | class_interner=class_interner,
142 | batch_size=batch_size,
143 | device=device,
144 | criterion=criterion,
145 | tensorize_model_inputs=tensorize_model_inputs)
146 | results['train_f1'].append(
147 | classification_report(epoch_train_truth, epoch_train_hard_pred, target_names=class_labels,
148 | labels=list(range(len(class_labels))), output_dict=True))
149 | results['train_acc'].append(accuracy_score(epoch_train_truth, epoch_train_hard_pred))
150 | epoch_val_loss, \
151 | epoch_val_soft_pred, \
152 | epoch_val_hard_pred, \
153 | epoch_val_truth = make_mtl_classification_preds_epoch(
154 | classifier=evidence_classifier,
155 | data=epoch_val_data,
156 | class_interner=class_interner,
157 | batch_size=batch_size,
158 | device=device,
159 | criterion=criterion,
160 | tensorize_model_inputs=tensorize_model_inputs)
161 | results['val_loss'].append(epoch_val_loss)
162 | results['val_f1'].append(
163 | classification_report(epoch_val_truth, epoch_val_hard_pred, target_names=class_labels,
164 | labels=list(range(len(class_labels))), output_dict=True))
165 | results['val_acc'].append(accuracy_score(epoch_val_truth, epoch_val_hard_pred))
166 | assert epoch_val_loss == epoch_val_loss # for nans
167 | logging.info(f'Epoch {epoch} val loss {epoch_val_loss}')
168 | logging.info(f'Epoch {epoch} val acc {results["val_acc"][-1]}')
169 | logging.info(f'Epoch {epoch} val f1 {results["val_f1"][-1]}')
170 |
171 | epoch_metrics = {metric: values[-1] for metric, values in results.items()}
172 | epoch_metrics['epoch'] = epoch
173 | wandb.log(epoch_metrics)
174 |
175 | if epoch_val_loss < best_val_loss:
176 | best_model_state_dict = OrderedDict({k: v.cpu() for k, v in evidence_classifier.state_dict().items()})
177 | best_epoch = epoch
178 | best_val_loss = epoch_val_loss
179 | epoch_data = {
180 | 'epoch': epoch,
181 | 'results': results,
182 | 'best_val_loss': best_val_loss,
183 | 'done': 0,
184 | }
185 | torch.save(evidence_classifier.state_dict(), model_save_file)
186 | torch.save(epoch_data, epoch_save_file)
187 | logging.debug(f'Epoch {epoch} new best model with val loss {epoch_val_loss}')
188 | if epoch - best_epoch > patience:
189 | logging.info(f'Exiting after epoch {epoch} due to no improvement')
190 | epoch_data['done'] = 1
191 | torch.save(epoch_data, epoch_save_file)
192 | break
193 |
194 | epoch_data['done'] = 1
195 | epoch_data['results'] = results
196 | torch.save(epoch_data, epoch_save_file)
197 | evidence_classifier.load_state_dict(best_model_state_dict)
198 | evidence_classifier = evidence_classifier.to(device=device)
199 | evidence_classifier.eval()
200 | return evidence_classifier, results
201 |
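202 | 
203 | # The trainer above only reads the 'evidence_classifier' section of the config
204 | # dict it is given. A minimal sketch of that section (the values below are
205 | # illustrative, not taken from the shipped params files):
206 | #
207 | # model_pars = {
208 | #     'evidence_classifier': {
209 | #         'lr': 1e-5,
210 | #         'batch_size': 16,
211 | #         'epochs': 10,
212 | #         'patience': 3,
213 | #         'max_grad_norm': 1.0,  # optional
214 | #     }
215 | # }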
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
13 |
14 |
15 |
174 |
175 |
176 |
--------------------------------------------------------------------------------
/expred/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import pickle
3 | import os
4 | import re
5 |
6 | from dataclasses import dataclass, asdict, is_dataclass
7 | from itertools import chain
8 | from typing import Dict, List, Set, Tuple, Union, FrozenSet
9 |
10 | import transformers
11 |
12 |
13 |
14 | # tensorflow compatibility import, uncomment if needed
15 | # import tensorflow
16 | # if tensorflow.__version__.startswith('2'):
17 | # import tensorflow.compat.v1 as tf
18 | #
19 | # tf.disable_v2_behavior()
20 | # else:
21 | # import tensorflow as tf
22 |
23 |
24 | @dataclass(eq=True, frozen=True)
25 | class Evidence:
26 | """
27 | (docid, start_token, end_token) form the only official Evidence; sentence level annotations are for convenience.
28 | Args:
29 | text: Some representation of the evidence text
30 | docid: Some identifier for the document
31 | start_token: The canonical start token, inclusive
32 | end_token: The canonical end token, exclusive
33 | start_sentence: Best guess start sentence, inclusive
34 | end_sentence: Best guess end sentence, exclusive
35 | """
36 | text: Union[str, Tuple[int], Tuple[str]]
37 | docid: str
38 | start_token: int = -1
39 | end_token: int = -1
40 | start_sentence: int = -1
41 | end_sentence: int = -1
42 |
43 |
44 | @dataclass(eq=True, frozen=True)
45 | class Annotation:
46 | """
47 | Args:
48 | annotation_id: unique ID for this annotation element
49 | query: some representation of a query string
50 | evidences: a set of "evidence groups".
51 | Each evidence group is:
52 | * sufficient to respond to the query (or justify an answer)
53 | * composed of one or more Evidences
54 | * may have multiple documents in it (depending on the dataset)
55 | - e-snli has multiple documents
56 | - other datasets do not
57 | classification: str
58 | query_type: Optional str, additional information about the query
59 | docids: a set of docids in which one may find evidence.
60 | """
61 | annotation_id: str
62 | query: Union[str, Tuple[int]]
63 | evidences: Union[Set[Tuple[Evidence]], FrozenSet[Tuple[Evidence]]]
64 | classification: str
65 | query_type: str = None
66 | docids: Set[str] = None
67 |
68 | def all_evidences(self) -> Tuple[Evidence]:
69 | return tuple(list(chain.from_iterable(self.evidences)))
70 |
71 |
72 | def annotations_to_jsonl(annotations, output_file):
73 | with open(output_file, 'w') as of:
74 | for ann in sorted(annotations, key=lambda x: x.annotation_id):
75 | as_json = _annotation_to_dict(ann)
76 | as_str = json.dumps(as_json, sort_keys=True)
77 | of.write(as_str)
78 | of.write('\n')
79 |
80 |
81 | def _annotation_to_dict(dc):
82 | # convenience method
83 | if is_dataclass(dc):
84 | d = asdict(dc)
85 | ret = dict()
86 | for k, v in d.items():
87 | ret[k] = _annotation_to_dict(v)
88 | return ret
89 | elif isinstance(dc, dict):
90 | ret = dict()
91 | for k, v in dc.items():
92 | k = _annotation_to_dict(k)
93 | v = _annotation_to_dict(v)
94 | ret[k] = v
95 | return ret
96 | elif isinstance(dc, str):
97 | return dc
98 | elif isinstance(dc, (set, frozenset, list, tuple)):
99 | ret = []
100 | for x in dc:
101 | ret.append(_annotation_to_dict(x))
102 | return tuple(ret)
103 | else:
104 | return dc
105 |
106 |
107 | def load_jsonl(fp: str) -> List[dict]:
108 | ret = []
109 | with open(fp, 'r') as inf:
110 | for line in inf:
111 | content = json.loads(line)
112 | ret.append(content)
113 | return ret
114 |
115 |
116 | def write_jsonl(jsonl, output_file):
117 | with open(output_file, 'w') as of:
118 | for js in jsonl:
119 | as_str = json.dumps(js, sort_keys=True)
120 | of.write(as_str)
121 | of.write('\n')
122 |
123 |
124 | def annotations_from_jsonl(fp: str) -> List[Annotation]:
125 | ret = []
126 | with open(fp, 'r') as inf:
127 | for line in inf:
128 | content = json.loads(line)
129 | ev_groups = []
130 | for ev_group in content['evidences']:
131 | ev_group = tuple([Evidence(**ev) for ev in ev_group])
132 | ev_groups.append(ev_group)
133 | content['evidences'] = frozenset(ev_groups)
134 | ret.append(Annotation(**content))
135 | return ret
136 |
137 |
138 | def decorate_with_docs_ids(annotation : Annotation):
139 | """
140 |     Extracts the docids from the annotation_id if they are not already set.
141 |     Warning: this does not work for esnli, but it does work for esnli_flat.
142 | :param annotation:
143 | :return: annotation
144 | """
145 | if annotation.docids is not None and len(annotation.docids) > 0:
146 | return annotation
147 |
148 | docids = [annotation.annotation_id]
149 | new_annotation = asdict(annotation)
150 | new_annotation['docids'] = docids
151 | assert docids is not None
152 | return Annotation(**new_annotation)
153 |
154 | def load_datasets(data_dir: str) -> Tuple[List[Annotation], List[Annotation], List[Annotation]]:
155 | """Loads a training, validation, and test dataset
156 |
157 | Each dataset is assumed to have been serialized by annotations_to_jsonl,
158 | that is it is a list of json-serialized Annotation instances.
159 | """
160 | train_data = annotations_from_jsonl(os.path.join(data_dir, 'train.jsonl'))
161 | val_data = annotations_from_jsonl(os.path.join(data_dir, 'val.jsonl'))
162 | test_data = annotations_from_jsonl(os.path.join(data_dir, 'test.jsonl'))
163 |
164 | splits = train_data, val_data, test_data
165 | splits = tuple([decorate_with_docs_ids(ann) for ann in split] for split in splits)
166 | return splits
167 |
168 |
169 | def load_documents(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]:
170 | """Loads a subset of available documents from disk.
171 |
172 | Each document is assumed to be serialized as newline ('\n') separated sentences.
173 | Each sentence is assumed to be space (' ') joined tokens.
174 | """
175 | if os.path.exists(os.path.join(data_dir, 'docs.jsonl')):
176 | assert not os.path.exists(os.path.join(data_dir, 'docs'))
177 | return load_documents_from_file(data_dir, docids)
178 |
179 | docs_dir = os.path.join(data_dir, 'docs')
180 | res = dict()
181 | if docids is None:
182 | docids = sorted(os.listdir(docs_dir))
183 | else:
184 | docids = sorted(set(str(d) for d in docids))
185 | for d in docids:
186 | with open(os.path.join(docs_dir, d), 'r') as inf:
187 | lines = [l.strip() for l in inf.readlines()]
188 | lines = list(filter(lambda x: bool(len(x)), lines))
189 | tokenized = [list(filter(lambda x: bool(len(x)), line.strip().split(' '))) for line in lines]
190 | res[d] = tokenized
191 | return res
192 |
193 |
194 | def load_flattened_documents(data_dir: str, docids: Set[str]) -> Dict[str, List[str]]:
195 | """Loads a subset of available documents from disk.
196 |
197 | Returns a tokenized version of the document.
198 | """
199 | unflattened_docs = load_documents(data_dir, docids)
200 | flattened_docs = dict()
201 | for doc, unflattened in unflattened_docs.items():
202 | flattened_docs[doc] = list(chain.from_iterable(unflattened))
203 | return flattened_docs
204 |
205 |
206 | def intern_documents(documents: Dict[str, List[List[str]]], word_interner: Dict[str, int], unk_token: str):
207 | """
208 | Replaces every word with its index in an embeddings file.
209 |
210 | If a word is not found, uses the unk_token instead
211 | """
212 | ret = dict()
213 | unk = word_interner[unk_token]
214 | for docid, sentences in documents.items():
215 | ret[docid] = [[word_interner.get(w, unk) for w in s] for s in sentences]
216 | return ret
217 |
218 |
219 | def intern_annotations(annotations: List[Annotation], word_interner: Dict[str, int], unk_token: str):
220 | ret = []
221 | for ann in annotations:
222 | ev_groups = []
223 | for ev_group in ann.evidences:
224 | evs = []
225 | for ev in ev_group:
226 | evs.append(Evidence(
227 | text=tuple([word_interner.get(t, word_interner[unk_token]) for t in ev.text.split()]),
228 | docid=ev.docid,
229 | start_token=ev.start_token,
230 | end_token=ev.end_token,
231 | start_sentence=ev.start_sentence,
232 | end_sentence=ev.end_sentence))
233 | ev_groups.append(tuple(evs))
234 | ret.append(Annotation(annotation_id=ann.annotation_id,
235 | query=tuple([word_interner.get(t, word_interner[unk_token]) for t in ann.query.split()]),
236 | evidences=frozenset(ev_groups),
237 | classification=ann.classification,
238 | query_type=ann.query_type))
239 | return ret
240 |
241 |
242 | def load_documents_from_file(data_dir: str, docids: Set[str] = None) -> Dict[str, List[List[str]]]:
243 | """Loads a subset of available documents from 'docs.jsonl' file on disk.
244 |
245 | Each document is assumed to be serialized as newline ('\n') separated sentences.
246 | Each sentence is assumed to be space (' ') joined tokens.
247 | """
248 | docs_file = os.path.join(data_dir, 'docs.jsonl')
249 | documents = load_jsonl(docs_file)
250 | documents = {doc['docid']: doc['document'] for doc in documents}
251 | res = dict()
252 | if docids is None:
253 | docids = sorted(list(documents.keys()))
254 | else:
255 | docids = sorted(set(str(d) for d in docids))
256 | for d in docids:
257 | lines = documents[d].split('\n')
258 | tokenized = [line.strip().split(' ') for line in lines]
259 | res[d] = tokenized
260 | return res
261 |
262 |
263 | NEG = 0
264 | POS = 1
265 |
266 | pattern = re.compile('</?(POS)?(NEG)?>')
267 |
268 |
269 | def cache_decorator(*dump_fnames):
270 |     def execution_decorator(func):
271 | def wrapper(*args, **kwargs):
272 | if len(dump_fnames) == 1:
273 | dump_fname = dump_fnames[0]
274 | if not os.path.isfile(dump_fname):
275 | ret = func(*args, **kwargs)
276 | with open(dump_fname, 'wb') as fdump:
277 | pickle.dump(ret, fdump)
278 | return ret
279 |
280 | with open(dump_fname, 'rb') as fdump:
281 | ret = pickle.load(fdump)
282 | return ret
283 |
284 | rets = None
285 | for fname in dump_fnames:
286 | if not os.path.isfile(fname):
287 | rets = func(*args, **kwargs)
288 | break
289 | if rets is not None:
290 | for r, fname in zip(rets, dump_fnames):
291 | with open(fname, 'wb') as fdump:
292 | pickle.dump(r, fdump)
293 | return rets
294 |
295 | rets = []
296 | for fname in dump_fnames:
297 | with open(fname, 'rb') as fdump:
298 | rets.append(pickle.load(fdump))
299 | return tuple(rets)
300 |
301 | return wrapper
302 |
303 |     return execution_decorator
304 |
305 |
306 | def convert_subtoken_ids_to_tokens(ids:List[int],
307 | tokenizer:transformers.BertTokenizer,
308 | token_mapping=None,
309 | exps=None,
310 | raw_sentence=None):
311 | subtokens = tokenizer.convert_ids_to_tokens(ids)
312 | tokens, exps_outputs = [], []
313 |     if exps is not None and not isinstance(exps[0], list):
314 |         exps = [exps]
315 | exps_inputs = [[0] * len(ids)] if exps is None else exps
316 | raw_sentence = subtokens if raw_sentence is None else raw_sentence
317 | subtokens = list(reversed([t[2:] if t.startswith('##') else t for t in subtokens]))
318 | if token_mapping is None:
319 | exps_inputs = list(zip(*(list(reversed(e)) for e in exps_inputs)))
320 | for ref_token in raw_sentence:
321 | t, es = '', [0] * len(exps_inputs[0])
322 | while t != ref_token and len(subtokens) > 0:
323 | t += subtokens.pop()
324 | es = [max(old, new) for old, new in zip(es, exps_inputs.pop())]
325 | tokens.append(t)
326 | exps_outputs.append(es)
327 | if len(subtokens) == 0:
328 | # the last sub-token is incomplete, ditch it directly
329 | if ref_token != tokens[-1]:
330 | tokens = tokens[:-1]
331 | exps_outputs = exps_outputs[:-1]
332 | break
333 | else:
334 | hard_rats, soft_rats = exps
335 | for ref_token_idx, (token_piece_start, token_piece_end) in enumerate(token_mapping):
336 | if token_piece_start >= len(hard_rats):
337 | break
338 | tokens.append(raw_sentence[ref_token_idx])
339 | max_hard_rat = max(hard_rats[token_piece_start: token_piece_end])
340 | max_soft_rat = max(soft_rats[token_piece_start: token_piece_end])
341 | exps_outputs.append((max_hard_rat, max_soft_rat))
342 | if exps is None:
343 | return tokens
344 | return tokens, exps_outputs
345 |
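346 | 
347 | # Usage sketch for cache_decorator (added for illustration; not part of the
348 | # original module). The pickle file name is hypothetical.
349 | if __name__ == '__main__':
350 |     @cache_decorator('demo_cache.pkl')
351 |     def expensive_step():
352 |         return [1, 2, 3]
353 | 
354 |     print(expensive_step())  # computed and written to demo_cache.pkl on the first run
355 |     print(expensive_step())  # loaded back from demo_cache.pkl afterwards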
--------------------------------------------------------------------------------
/expred/train.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict
2 |
3 | import sys
4 |
5 | import argparse
6 | import logging
7 | import wandb as wandb
8 | from typing import List, Dict, Set, Tuple
9 |
10 | import torch
11 | import os
12 | import json
13 | import numpy as np
14 | import random
15 |
16 | from itertools import chain
17 |
18 | from expred import metrics
19 | from expred.params import MTLParams
20 | from expred.models.mlp_mtl import BertMTL, BertClassifier
21 | from expred.tokenizer import BertTokenizerWithMapping
22 | from expred.models.pipeline.mtl_pipeline_utils import decode
23 | from expred.utils import load_datasets, load_documents, write_jsonl, Annotation
24 | from expred.models.pipeline.mtl_token_identifier import train_mtl_token_identifier
25 | from expred.models.pipeline.mtl_evidence_classifier import train_mtl_evidence_classifier
26 | from expred.eraser_utils import get_docids
27 |
28 | BATCH_FIRST = True
29 |
30 |
31 | def initialize_models(conf: dict,
32 | tokenizer: BertTokenizerWithMapping,
33 | batch_first: bool) -> Tuple[BertMTL, BertClassifier, Dict[int, str]]:
34 | """
35 | Does several things:
36 | 1. Create a mapping from label names to ids
37 | 2. Configure and create the multi task learner, the first stage of the model (BertMTL)
38 | 3. Configure and create the evidence classifier, second stage of the model (BertClassifier)
39 | :param conf:
40 | :param tokenizer:
41 | :param batch_first:
42 | :return: BertMTL, BertClassifier, label mapping
43 | """
44 | assert batch_first
45 | max_length = conf['max_length']
46 | # label mapping
47 | labels = dict((y, x) for (x, y) in enumerate(conf['classes']))
48 |
49 | # configure multi task learner
50 |     mtl_params = MTLParams()
51 | mtl_params.num_labels = len(labels)
52 | mtl_params.dim_exp_gru = conf['dim_exp_gru']
53 | mtl_params.dim_cls_linear = conf['dim_cls_linear']
54 | bert_dir = conf['bert_dir']
55 | use_half_precision = bool(conf['mtl_token_identifier'].get('use_half_precision', 1))
56 | evidence_identifier = BertMTL(bert_dir=bert_dir,
57 | tokenizer=tokenizer,
58 | mtl_params=mtl_params,
59 | max_length=max_length,
60 | use_half_precision=use_half_precision)
61 |
62 | # set up the evidence classifier
63 | use_half_precision = bool(conf['evidence_classifier'].get('use_half_precision', 1))
64 | evidence_classifier = BertClassifier(bert_dir=bert_dir,
65 | pad_token_id=tokenizer.pad_token_id,
66 | cls_token_id=tokenizer.cls_token_id,
67 | sep_token_id=tokenizer.sep_token_id,
68 | num_labels=mtl_params.num_labels,
69 | max_length=max_length,
70 | mtl_params=mtl_params,
71 | use_half_precision=use_half_precision)
72 |
73 | return evidence_identifier, evidence_classifier, labels
74 |
75 |
76 | logging.basicConfig(level=logging.DEBUG, format='%(relativeCreated)6d %(threadName)s %(message)s')
77 | logger = logging.getLogger(__name__)
78 |
79 | # let's make this more or less deterministic (not resistent to restarts)
80 | random.seed(12345)
81 | np.random.seed(67890)
82 | torch.manual_seed(10111213)
83 | torch.backends.cudnn.deterministic = True
84 | torch.backends.cudnn.benchmark = False
85 |
86 |
87 | # or, uncomment the following lines to use non-deterministic random seeds instead
88 | # rand_seed_1 = ord(os.urandom(1)) * ord(os.urandom(1))
89 | # rand_seed_2 = ord(os.urandom(1)) * ord(os.urandom(1))
90 | # rand_seed_3 = ord(os.urandom(1)) * ord(os.urandom(1))
91 | # random.seed(rand_seed_1)
92 | # np.random.seed(rand_seed_2)
93 | # torch.manual_seed(rand_seed_3)
94 | # torch.backends.cudnn.deterministic = False
95 | # torch.backends.cudnn.benchmark = True
96 |
97 |
98 | def main(args : List[str]):
99 | # setup the Argument Parser
100 | parser = argparse.ArgumentParser(description=('Trains a pipeline model.\n'
101 | '\n'
102 |                                               'Step 1 is evidence identification; the MTL happens here. It '
103 |                                               'predicts the label of the current sentence and tags its\n '
104 |                                               ' sub-tokens at the same time.\n'
105 |                                               ' Step 2 is evidence classification: a BERT classifier takes the output of the evidence identifier and predicts its \n'
106 |                                               ' sentiment. Unlike in DeYoung et al., this classifier takes an input of the same length as the identifier\'s, but with \n'
107 |                                               ' irrational sub-tokens masked.\n'
108 | '\n'
109 | ' These models should be separated into two separate steps, but at the moment:\n'
110 | ' * prep data (load, intern documents, load json)\n'
111 | ' * convert data for evidence identification - in the case of training data we take all the positives and sample some negatives\n'
112 | ' * side note: this sampling is *somewhat* configurable and is done on a per-batch/epoch basis in order to gain a broader sampling of negative values.\n'
113 | ' * train evidence identification\n'
114 | ' * convert data for evidence classification - take all rationales + decisions and use this as input\n'
115 | ' * train evidence classification\n'
116 | ' * decode first the evidence, then run classification for each split\n'
117 | '\n'
118 | ' '), formatter_class=argparse.RawTextHelpFormatter)
119 | parser.add_argument('--data_dir', dest='data_dir', required=True,
120 | help='Which directory contains a {train,val,test}.jsonl file?')
121 | parser.add_argument('--output_dir', dest='output_dir', required=True,
122 | help='Where shall we write intermediate models + final data to?')
123 | parser.add_argument('--conf', dest='conf', required=True,
124 | help='JSON file for loading arbitrary model parameters (e.g. optimizers, pre-saved files, etc.)')
125 | parser.add_argument('--batch_size', type=int, required=False, default=None,
126 | help='Overrides the batch_size given in the config file. Helpful for debugging')
127 | args = parser.parse_args(args)
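# Illustrative example invocation (the paths are hypothetical; only the flags are taken from the
# parser above):
#   python <this training script> \
#       --data_dir /path/to/eraser/movies \
#       --output_dir /tmp/expred_movies \
#       --conf /path/to/movies_expred.json \
#       --batch_size 4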
128 |
129 | wandb.init(entity="explainable-nlp", project="expred")
130 | # Configure
131 | os.makedirs(args.output_dir, exist_ok=True)
132 |
133 | # loads the config
134 | with open(args.conf, 'r') as fp:
135 | logger.info(f'Loading configuration from {args.conf}')
136 | conf = json.load(fp)
137 | if args.batch_size is not None:
138 | logger.info(
139 | 'Overwriting the batch sizes '
140 | f'(mtl_token_identifier: {conf["mtl_token_identifier"]["batch_size"]}, '
141 | f'evidence_classifier: {conf["evidence_classifier"]["batch_size"]}) '
142 | f'given in the config with the command line argument ({args.batch_size})'
143 | )
144 | conf['mtl_token_identifier']['batch_size'] = args.batch_size
145 | conf['evidence_classifier']['batch_size'] = args.batch_size
146 | logger.info(f'Configuration: {json.dumps(conf, indent=2, sort_keys=True)}')
147 |
148 | # todo add seeds
149 | wandb.config.update(conf)
150 | wandb.config.update(args)
151 |
152 | # load the annotation data
153 | train, val, test = load_datasets(args.data_dir)
154 |
155 | # gets all docids contained in the loaded splits
156 | docids: Set[str] = set(chain.from_iterable(map(lambda ann: get_docids(ann),
157 | chain(train, val, test))))
158 |
159 | documents: Dict[str, List[List[str]]] = load_documents(args.data_dir, docids)
160 | logger.info(f'Loaded {len(documents)} documents')
161 | # this ignores the case where annotations don't align perfectly with token boundaries, but this isn't that important
162 |
163 | tokenizer = BertTokenizerWithMapping.from_pretrained(conf['bert_vocab'])
164 | mtl_token_identifier, evidence_classifier, labels_mapping = \
165 | initialize_models(conf, tokenizer, batch_first=BATCH_FIRST)
166 | # logger.info(f'We have {len(word_interner)} wordpieces')
167 |
168 | # tokenizes and caches tokenized_docs, same for annotations
169 | # todo typo here? slides = slices (words?)
170 | tokenized_docs, tokenized_doc_token_slices = tokenizer.encode_docs(documents, args.output_dir)
171 | indexed_train, indexed_val, indexed_test = [tokenizer.encode_annotations(data) for data in [train, val, test]]
172 |
173 | logger.info('Beginning training of the MTL identifier')
174 | mtl_token_identifier = mtl_token_identifier.cuda()
175 | mtl_token_identifier, mtl_token_identifier_results, \
176 | train_machine_annotated, eval_machine_annotated, test_machine_annotated = \
177 | train_mtl_token_identifier(mtl_token_identifier,
178 | args.output_dir,
179 | indexed_train,
180 | indexed_val,
181 | indexed_test,
182 | labels_mapping=labels_mapping,
183 | interned_documents=tokenized_docs,
184 | source_documents=documents,
185 | token_mapping=tokenized_doc_token_slices,
186 | model_pars=conf,
187 | tensorize_model_inputs=True)
188 | mtl_token_identifier = mtl_token_identifier.cpu()
189 | # evidence identifier ends
190 |
191 | logger.info('Beginning training of the evidence classifier')
192 | evidence_classifier = evidence_classifier.cuda()
193 | optimizer = None
194 | scheduler = None
195 |
196 | # trains the classifier on the masked (based on rationales) documents
197 | evidence_classifier, evidence_class_results = train_mtl_evidence_classifier(evidence_classifier,
198 | args.output_dir,
199 | train_machine_annotated,
200 | eval_machine_annotated,
201 | tokenized_docs,
202 | conf,
203 | optimizer=optimizer,
204 | scheduler=scheduler,
205 | class_interner=labels_mapping,
206 | tensorize_model_inputs=True)
207 | # evidence classifier ends
208 |
209 | logger.info('Beginning final decoding')
210 | mtl_token_identifier = mtl_token_identifier.cuda()
211 | pipeline_batch_size = min(
212 | [conf['evidence_classifier']['batch_size'], conf['mtl_token_identifier']['batch_size']])
213 | pipeline_results, train_decoded, val_decoded, test_decoded = decode(evidence_identifier=mtl_token_identifier,
214 | evidence_classifier=evidence_classifier,
215 | train=indexed_train,
216 | mrs_train=train_machine_annotated,
217 | val=indexed_val,
218 | mrs_eval=eval_machine_annotated,
219 | test=indexed_test,
220 | mrs_test=test_machine_annotated,
221 | source_documents=documents,
222 | interned_documents=tokenized_docs,
223 | token_mapping=tokenized_doc_token_slices,
224 | class_interner=labels_mapping,
225 | tensorize_modelinputs=True,
226 | batch_size=pipeline_batch_size,
227 | tokenizer=tokenizer)
228 | write_jsonl(train_decoded, os.path.join(args.output_dir, 'train_decoded.jsonl'))
229 | write_jsonl(val_decoded, os.path.join(args.output_dir, 'val_decoded.jsonl'))
230 | write_jsonl(test_decoded, os.path.join(args.output_dir, 'test_decoded.jsonl'))
231 | with open(os.path.join(args.output_dir, 'identifier_results.json'), 'w') as ident_output, \
232 | open(os.path.join(args.output_dir, 'classifier_results.json'), 'w') as class_output:
233 | ident_output.write(json.dumps(mtl_token_identifier_results))
234 | class_output.write(json.dumps(evidence_class_results))
235 | for k, v in pipeline_results.items():
236 | if type(v) is dict:
237 | for k1, v1 in v.items():
238 | logging.info(f'Pipeline results for {k}, {k1}={v1}')
239 | else:
240 | logging.info(f'Pipeline results {k}\t={v}')
241 | # decode ends
242 |
243 | scores = metrics.main(
244 | [
245 | '--data_dir', args.data_dir,
246 | '--split', 'test',
247 | '--results', os.path.join(args.output_dir, 'test_decoded.jsonl'),
248 | '--score_file', os.path.join(args.output_dir, 'test_scores.jsonl')
249 | ]
250 | )
251 |
252 | wandb.log(scores)
253 |
254 | wandb.save(os.path.join(args.output_dir, '*.jsonl'))
255 |
256 |
257 |
258 | if __name__ == '__main__':
259 | main(sys.argv[1:])
260 |
--------------------------------------------------------------------------------
/expred/models/pipeline/mtl_token_identifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import logging
4 | import numpy as np
5 | import random
6 | import wandb
7 |
8 | from typing import List, Dict, Tuple, Callable, Union, Any
9 | from collections import OrderedDict, namedtuple, defaultdict
10 | from sklearn.metrics import accuracy_score, classification_report
11 | from itertools import chain
12 | from torch import nn
13 | from torch.nn import Module
14 |
15 | from expred.eraser_utils import chain_sentence_evidences
16 | from expred.models.model_utils import PaddedSequence
17 | from expred.utils import Annotation
18 | from expred.models.pipeline.pipeline_utils import (
19 | SentenceEvidence, score_token_rationales)
20 | from expred.models.pipeline.mtl_pipeline_utils import (
21 | annotations_to_mtl_token_identification,
22 | make_mtl_token_preds_epoch
23 | )
24 | from expred.models.losses import resampling_rebalanced_crossentropy
25 |
26 | AnnotatedDocument = namedtuple('AnnotatedDocument', 'kls evd ann_id query docid index sentence')
27 |
28 |
29 | def _get_sampling_method(params):
30 | if params['sampling_method'] == 'whole_document':
31 | def whole_document_sampler(sentences, _):
32 | return chain_sentence_evidences(sentences)
33 |
34 | return whole_document_sampler
35 | else:
36 | raise NotImplementedError
37 |
38 |
39 | def _prep_data_for_epoch(evidence_data: Tuple[str, Dict[str, Dict[str, List[SentenceEvidence]]]],
40 | sampler: Callable[
41 | [List[SentenceEvidence], Dict[str, List[SentenceEvidence]]], List[SentenceEvidence]]
42 | ) -> List[SentenceEvidence]:
43 | """
44 | Shuffle the annotations and sample from documents (can also be the whole document depending on the sampler)
45 | :param evidence_data:
46 | :param sampler:
47 | :return:
48 | """
49 | output_annotations = []
50 | ann_ids = sorted(evidence_data.keys())
51 | # in place shuffle so we get a different per-epoch ordering
52 | random.shuffle(ann_ids)
53 | for ann_id in ann_ids:
54 | for docid, sentences in evidence_data[ann_id][1].items():
55 | data = sampler(sentences, None)
56 | output_annotations.append((evidence_data[ann_id][0], data))
57 | return output_annotations
58 |
59 |
60 | def train_mtl_token_identifier(mtl_token_identifier: nn.Module,
61 | save_dir: str,
62 | train: List[Annotation],
63 | val: List[Annotation],
64 | test: List[Annotation],
65 | interned_documents: Dict[str, List[List[int]]],
66 | source_documents: Dict[str, List[List[str]]],
67 | token_mapping: Dict[str, List[List[Tuple[int, int]]]],
68 | model_pars: dict,
69 | labels_mapping: Dict[str, int],
70 | optimizer=None,
71 | scheduler=None,
72 | tensorize_model_inputs: bool = True) -> Tuple[
73 | Module, Union[Dict[str, list], Any], List[Tuple[Union[SentenceEvidence, List[SentenceEvidence]], Any, Any]], List[
74 | Tuple[Union[SentenceEvidence, List[SentenceEvidence]], Any, Any]], List[
75 | Tuple[Union[SentenceEvidence, List[SentenceEvidence]], Any, Any]]]:
76 | """Trains a module for token-level rationale identification.
77 | This method tracks loss on the entire validation set, saves intermediate
78 | models, and supports restoring from an unfinished state. The best model on
79 | the validation set is maintained, and the model stops training if a patience
80 | (see below) number of epochs with no improvement is exceeded.
81 | As there are likely too many negative examples to reasonably train a
82 | classifier on everything, every epoch we subsample the negatives.
83 |
84 | :param mtl_token_identifier: token-wise evidence identifier using Multi-Task Learning
85 | :param save_dir: a place to save intermediate and final results and models.
86 | :param train: a List of interned Annotation objects.
87 | :param val: a List of interned Annotation objects.
88 | :param test: a List of interned Annotation objects.
89 | :param interned_documents: a Dict of interned sentences
90 | :param source_documents: a Dict of original sentences, used for sub-token alignment
91 | :param token_mapping: a mapping from original token to sub tokens
92 | :param model_pars: model parameters
93 | :param labels_mapping: a mapping from str labels to their int ids
94 | :param optimizer: what pytorch optimizer to use, if none, initialize Adam
95 | :param scheduler: optional, do we want a scheduler involved in learning?
96 | :param tensorize_model_inputs: should we convert our data to tensors before passing it to the model?
97 | Useful if we have a model that performs its own tokenization (e.g. BERT as a Service)
98 | :returns
99 | the trained MTL evidence token identifier
100 | the intermediate results
101 | machine-annotated train/eval/test datasets
102 | """
103 |
104 |
105 |
106 | # set up output folder structure
107 | logging.info(f'Beginning training with {len(train)} annotations, {len(val)} for validation')
108 | evidence_identifier_output_dir = os.path.join(save_dir, 'evidence_token_identifier')
109 | os.makedirs(save_dir, exist_ok=True)
110 | os.makedirs(evidence_identifier_output_dir, exist_ok=True)
111 |
112 | model_save_file = os.path.join(evidence_identifier_output_dir, 'evidence_token_identifier.pt')
113 | epoch_save_file = os.path.join(evidence_identifier_output_dir, 'evidence_token_identifier_epoch_data.pt')
114 |
115 | # set up training (optimizer, loss (both), sampling_method, ...)
116 | if optimizer is None:
117 | optimizer = torch.optim.Adam(mtl_token_identifier.parameters(), lr=model_pars['mtl_token_identifier']['lr'])
118 | cls_criterion = nn.BCELoss(reduction='none')
119 |
120 | exp_criterion = resampling_rebalanced_crossentropy(seq_reduction='none') # nn.CrossEntropyLoss(reduction='none')
121 | sampling_method = _get_sampling_method(model_pars['mtl_token_identifier'])
122 | batch_size = model_pars['mtl_token_identifier']['batch_size']
123 | max_length = model_pars['max_length']
124 | epochs = model_pars['mtl_token_identifier']['epochs']
125 |
126 | patience = model_pars['mtl_token_identifier']['patience']
127 | max_grad_norm = model_pars['mtl_token_identifier'].get('max_grad_norm', None)
128 | par_lambda = model_pars['mtl_token_identifier']['par_lambda']
129 | # annotation id -> docid -> [SentenceEvidence])
130 | # calculates the classification of the sequence tokens and some other stuff
131 | evidence_train_data = annotations_to_mtl_token_identification(train,
132 | source_documents=source_documents,
133 | interned_documents=interned_documents,
134 | token_mapping=token_mapping)
135 | evidence_val_data = annotations_to_mtl_token_identification(val,
136 | source_documents=source_documents,
137 | interned_documents=interned_documents,
138 | token_mapping=token_mapping)
139 |
140 | evidence_test_data = annotations_to_mtl_token_identification(test,
141 | source_documents=source_documents,
142 | interned_documents=interned_documents,
143 | token_mapping=token_mapping)
144 |
145 | device = next(mtl_token_identifier.parameters()).device
146 |
147 | results = {
148 | 'sampled_epoch_train_losses': [],
149 | 'epoch_val_total_losses': [],
150 | 'epoch_val_cls_losses': [],
151 | 'epoch_val_exp_losses': [],
152 | 'epoch_val_exp_acc': [],
153 | 'epoch_val_exp_f': [],
154 | 'epoch_val_cls_acc': [],
155 | 'epoch_val_cls_f': [],
156 | 'full_epoch_val_rationale_scores': []
157 | }
158 |
159 | # allow restoring an existing training run
160 | start_epoch = 0
161 | best_epoch = -1
162 | best_val_total_loss = float('inf')
163 | best_model_state_dict = None
164 | epoch_data = {}
165 | if os.path.exists(epoch_save_file):
166 | mtl_token_identifier.load_state_dict(torch.load(model_save_file))
167 | epoch_data = torch.load(epoch_save_file)
168 | start_epoch = epoch_data['epoch'] + 1
169 | # handle runs that were already marked as done (patience exceeded or all epochs completed)
170 | if bool(epoch_data.get('done', 0)):
171 | start_epoch = epochs
172 | results = epoch_data['results']
173 | best_epoch = start_epoch
174 | best_model_state_dict = OrderedDict({k: v.cpu() for k, v in mtl_token_identifier.state_dict().items()})
175 | logging.info(f'Training evidence identifier from epoch {start_epoch} until epoch {epochs}')
176 | optimizer.zero_grad()
177 | for epoch in range(start_epoch, epochs):
178 | epoch_train_data = _prep_data_for_epoch(evidence_train_data, sampling_method)
179 | epoch_val_data = _prep_data_for_epoch(evidence_val_data, sampling_method)
180 | sampled_epoch_train_loss = 0
181 | losses = defaultdict(lambda: [])
182 | mtl_token_identifier.train()
183 | logging.info(
184 | f'Training with {len(epoch_train_data) // batch_size} batches with {len(epoch_train_data)} examples')
185 | # one epoch of training
186 | for batch_start in range(0, len(epoch_train_data), batch_size):
187 | batch_elements = epoch_train_data[batch_start:min(batch_start + batch_size, len(epoch_train_data))]
188 | # we sample every time to theoretically get a better representation of instances over the corpus.
189 | # this might just take more time than doing so in advance.
190 | labels, targets, queries, sentences, has_evidence = zip(
191 | *[(s[0], s[1].kls, s[1].query, s[1].sentence, s[1].has_evidence) for s in batch_elements])
192 |
193 | # one hot encoding for classification
194 | labels = [[i == labels_mapping[label] for i in range(len(labels_mapping))] for label in labels]
195 | labels = torch.tensor(labels, dtype=torch.float, device=device)
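# Illustration (hypothetical mapping): with labels_mapping == {'NEG': 0, 'POS': 1}, the label
# 'POS' becomes the one-hot row [0., 1.], which cls_criterion (nn.BCELoss) later compares
# against the classifier's per-class probabilities.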
196 |
197 | ids = [(s[1].ann_id, s[1].docid, s[1].index) for s in batch_elements]
198 |
199 | # truncation
200 | cropped_targets = [[0] * (len(query) + 2) # length of query and overheads such as [cls] and [sep]
201 | + list(target[:(max_length - len(query) - 2)]) for query, target in
202 | zip(queries, targets)]
203 | cropped_targets = PaddedSequence.autopad(
204 | [torch.tensor(t, dtype=torch.float, device=device) for t in cropped_targets],
205 | batch_first=True, device=device)
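# Illustration (made-up numbers): with max_length == 9 and a query of 3 sub-tokens, a document
# target [1, 1, 0, 0, 1, 0] becomes [0, 0, 0, 0, 0] (query plus [CLS]/[SEP] positions)
# + [1, 1, 0, 0] (document targets truncated to max_length - len(query) - 2 = 4), and is then
# padded to the batch's longest sequence by PaddedSequence.autopad.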
206 | targets = [[0] * (len(query) + 2) # length of query and overheads such as [cls] and [sep]
207 | + list(target) for query, target in zip(queries, targets)]
208 | targets = PaddedSequence.autopad([torch.tensor(t, dtype=torch.float, device='cpu') for t in targets],
209 | batch_first=True, device='cpu')
210 | if tensorize_model_inputs:
211 | if all(q is None for q in queries):
212 | queries = [torch.tensor([], dtype=torch.long) for _ in queries]
213 | else:
214 | assert all(q is not None for q in queries)
215 | queries = [torch.tensor(q, dtype=torch.long) for q in queries]
216 | sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
217 |
218 | # prediction
219 | preds = mtl_token_identifier(queries, ids, sentences)
220 | cls_preds, exp_preds, attention_masks = preds
221 | cls_loss = cls_criterion(cls_preds, labels).mean(dim=-1).sum()
222 |
223 | exp_loss_per_instance = exp_criterion(exp_preds, cropped_targets.data.squeeze()).mean(dim=-1)
224 | has_evidence_mask = torch.tensor(has_evidence, dtype=torch.float, device=exp_loss_per_instance.device)
225 |
226 | exp_loss = (exp_loss_per_instance * has_evidence_mask).sum()
227 | loss = cls_loss + par_lambda * exp_loss
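# MTL objective: loss = L_cls + par_lambda * L_exp, i.e. par_lambda (from the config) trades off
# the token-level explanation loss (restricted to instances that actually have evidence) against
# the classification loss.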
228 |
229 | losses['cls_loss'].append(cls_loss.item())
230 | losses['exp_loss'].append(exp_loss.item())
231 | losses['loss'].append(loss.item())
232 |
233 | sampled_epoch_train_loss += loss.item()
234 | loss.backward()
235 | if max_grad_norm:
236 | torch.nn.utils.clip_grad_norm_(mtl_token_identifier.parameters(), max_grad_norm)
237 | optimizer.step()
238 | if scheduler:
239 | scheduler.step()
240 | optimizer.zero_grad()
241 | sampled_epoch_train_loss /= len(epoch_train_data)
242 | results['sampled_epoch_train_losses'].append(sampled_epoch_train_loss)
243 |
244 | mean_losses = {f'train_{key}': np.mean(loss) for key, loss in losses.items()}
245 | mean_losses['epoch'] = epoch
246 | wandb.log(mean_losses)
247 |
248 | logging.info(f'Epoch {epoch} training loss {sampled_epoch_train_loss}')
249 |
250 | # validation
251 | with torch.no_grad():
252 | mtl_token_identifier.eval()
253 | epoch_val_total_loss, epoch_val_cls_loss, epoch_val_exp_loss, \
254 | epoch_val_soft_pred, epoch_val_hard_pred, epoch_val_token_targets, \
255 | epoch_val_pred_labels, epoch_val_labels = \
256 | make_mtl_token_preds_epoch(mtl_token_identifier,
257 | epoch_val_data,
258 | labels_mapping,
259 | token_mapping,
260 | batch_size,
261 | max_length,
262 | par_lambda,
263 | device,
264 | cls_criterion,
265 | exp_criterion,
266 | tensorize_model_inputs)
267 | # epoch_val_soft_pred = list(chain.from_iterable(epoch_val_soft_pred.tolist()))
268 | # epoch_val_hard_pred = list(chain.from_iterable(epoch_val_hard_pred))
269 | # epoch_val_truth = list(chain.from_iterable(epoch_val_truth))
270 | results['epoch_val_total_losses'].append(epoch_val_total_loss)
271 | results['epoch_val_cls_losses'].append(epoch_val_cls_loss)
272 | results['epoch_val_exp_losses'].append(epoch_val_exp_loss)
273 | epoch_val_hard_pred_chained = list(chain.from_iterable(epoch_val_hard_pred))
274 | epoch_val_token_targets_chained = list(chain.from_iterable(epoch_val_token_targets))
275 | results['epoch_val_exp_acc'].append(accuracy_score(epoch_val_token_targets_chained,
276 | epoch_val_hard_pred_chained))
277 | results['epoch_val_exp_f'].append(classification_report(epoch_val_token_targets_chained,
278 | epoch_val_hard_pred_chained,
279 | labels=[0, 1],  # 0 = non-rationale token, 1 = rationale token
280 | output_dict=True))
281 | flattened_epoch_val_pred_labels = [np.argmax(x) for x in epoch_val_pred_labels]
282 | flattened_epoch_val_labels = [np.argmax(x) for x in epoch_val_labels]
283 | results['epoch_val_cls_acc'].append(accuracy_score(flattened_epoch_val_pred_labels,
284 | flattened_epoch_val_labels))
285 | # print(flattened_epoch_val_labels)
286 | # print(flattened_epoch_val_pred_labels)
287 | results['epoch_val_cls_f'].append(classification_report(flattened_epoch_val_labels,
288 | flattened_epoch_val_pred_labels,
289 | labels=[v for _, v in labels_mapping.items()],
290 | output_dict=True))
291 | results['full_epoch_val_rationale_scores'].append(
292 | score_token_rationales(val, source_documents,
293 | epoch_val_data,
294 | token_mapping,
295 | epoch_val_hard_pred,
296 | epoch_val_soft_pred))
297 |
298 | validation_metrics = {metric: values[-1] for metric, values in results.items()}
299 | validation_metrics['epoch'] = epoch
300 | # epoch_val_soft_pred_for_scoring = [[[1 - z, z] for z in y] for y in epoch_val_soft_pred]
301 | # logging.info(
302 | # f'Epoch {epoch} full val loss {epoch_val_total_loss}, accuracy: {results["epoch_val_acc"][-1]}, f: {results["epoch_val_f"][-1]}, rationale scores: look, it\'s already a pain to duplicate this code. What do you want from me.')
303 |
304 | # if epoch_val_loss < best_val_loss:
305 | if epoch_val_total_loss < best_val_total_loss:
306 | logging.debug(f'Epoch {epoch} new best model with val loss {epoch_val_total_loss}')
307 | best_model_state_dict = OrderedDict({k: v.cpu() for k, v in mtl_token_identifier.state_dict().items()})
308 | best_epoch = epoch
309 | best_val_total_loss = epoch_val_total_loss
310 | torch.save(mtl_token_identifier.state_dict(), model_save_file)
311 | epoch_data = {
312 | 'epoch': epoch,
313 | 'results': results,
314 | 'best_val_loss': best_val_total_loss,
315 | 'done': 0
316 | }
317 | torch.save(epoch_data, epoch_save_file)
318 | if epoch - best_epoch > patience:
319 | epoch_data['done'] = 1
320 | torch.save(epoch_data, epoch_save_file)
321 | break
322 |
323 | epoch_data['done'] = 1
324 | epoch_data['results'] = results
325 | torch.save(epoch_data, epoch_save_file)
326 | mtl_token_identifier.load_state_dict(best_model_state_dict)
327 | mtl_token_identifier = mtl_token_identifier.to(device=device)
328 | mtl_token_identifier.eval()
329 |
330 | def prepare_for_cl(input_data, keep_corrected_only=False):
331 | """
332 | Extract rationale prediction from document only.
333 | If keep_corrected_only=True keep only the annotation where the classification was correct
334 | :param input_data:
335 | :param keep_corrected_only:
336 | :return:
337 | """
338 | epoch_input_data = _prep_data_for_epoch(input_data, sampling_method)
339 | _, _, _, soft_pred_for_cl, hard_pred_for_cl, _, \
340 | pred_labels_for_cl, labels_for_cl = \
341 | make_mtl_token_preds_epoch(mtl_token_identifier, epoch_input_data, labels_mapping,
342 | token_mapping, batch_size, max_length, par_lambda,
343 | device, cls_criterion, exp_criterion, tensorize_model_inputs)
344 | hard_pred_for_cl = [h.cpu().tolist() for h in hard_pred_for_cl]
345 | hard_pred_for_cl = [h[len(d[1].query) + 2:] for h, d in zip(hard_pred_for_cl, epoch_input_data)]
346 | soft_pred_for_cl = [s[len(d[1].query) + 2:] for s, d in zip(soft_pred_for_cl, epoch_input_data)]
347 | train_ids = list(range(len(labels_for_cl)))
348 | if keep_corrected_only:
349 | labels_for_cl = [np.argmax(x) for x in labels_for_cl]
350 | pred_labels_for_cl = [np.argmax(x) for x in pred_labels_for_cl]
351 | train_ids = list(filter(lambda i: labels_for_cl[i] == pred_labels_for_cl[i],
352 | range(len(labels_for_cl))))
353 | return [(epoch_input_data[i], soft_pred_for_cl[i], hard_pred_for_cl[i]) for i in train_ids]
354 |
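# Note: each element returned by prepare_for_cl is a triple of
#   ((label, chained SentenceEvidence), soft token predictions, hard token predictions)
# with the query prefix already stripped from the predictions; for the training split only the
# correctly classified annotations are kept (keep_corrected_only=True below).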
355 | train_machine_annotated = prepare_for_cl(evidence_train_data, True)
356 | eval_machine_annotated = prepare_for_cl(evidence_val_data, False)
357 | test_machine_annotated = prepare_for_cl(evidence_test_data, False)
358 | return mtl_token_identifier, results, \
359 | train_machine_annotated, eval_machine_annotated, test_machine_annotated
360 |
--------------------------------------------------------------------------------
/expred/models/pipeline/mtl_pipeline_utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import defaultdict, namedtuple
3 | from itertools import chain
4 | from typing import Any, Dict, List, Tuple
5 |
6 | import torch
7 | import torch.nn as nn
8 | from sklearn.metrics import classification_report, accuracy_score
9 |
10 | from expred.eraser_utils import chain_sentence_evidences, get_docids
11 | from expred.models.model_utils import PaddedSequence
12 | from expred.models.pipeline.pipeline_utils import SentenceEvidence, _grouper, \
13 | score_rationales
14 | from expred.utils import Annotation
15 | from expred.eraser_benchmark import rational_bits_to_ev_generator
16 | from expred.utils import convert_subtoken_ids_to_tokens
17 |
18 |
19 | def mask_annotations_to_evidence_classification(mrs: List[Tuple[Tuple[str, SentenceEvidence], Any]], # mrs for machine rationales
20 | class_interner: dict) -> List[SentenceEvidence]:
21 | """
22 |
23 | :param mrs:
24 | :param class_interner:
25 | :return:
26 | """
27 | ret = []
28 | for mr, _, hard_prediction in mrs:
29 | kls = class_interner[mr[0]]
30 | evidence = mr[1]
31 | sentence = evidence.sentence
32 | query = evidence.query
33 | try:
34 | assert len(hard_prediction) == len(sentence)
35 | except AssertionError:
36 | # the prediction can be shorter than the document (e.g. after truncation to max_length);
37 | # the zip below then masks only the overlapping prefix
38 | logging.warning(f'Rationale/document length mismatch for annotation {evidence.ann_id}, doc {evidence.docid}: '
39 | f'{len(hard_prediction)} predictions vs {len(sentence)} sub-tokens (query length {len(query)})')
40 | masked_sentence = [p * d for p, d in zip(hard_prediction, sentence)]
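# e.g. (made-up ids) hard_prediction [1, 0, 1] and sub-token ids [1037, 2204, 3185] yield
# [1037, 0, 3185]: sub-tokens outside the predicted rationale are zeroed out (for the standard
# BERT vocabulary, id 0 is the [PAD] token)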
41 | ret.append(SentenceEvidence(kls=kls,
42 | query=query,
43 | ann_id=evidence.ann_id,
44 | docid=evidence.docid,
45 | index=-1,
46 | sentence=masked_sentence,
47 | has_evidence=evidence.has_evidence))
48 | return ret
49 |
50 |
51 | def annotations_to_evidence_token_identification(annotations: List[Annotation],
52 | source_documents: Dict[str, List[List[str]]],
53 | interned_documents: Dict[str, List[List[int]]],
54 | token_mapping: Dict[str, List[List[Tuple[int, int]]]]
55 | ) -> Dict[str, Dict[str, List[SentenceEvidence]]]:
56 | """
57 | Calculates the start and end positions for
58 | * tokens
59 | * sentences
60 | * rationales
61 | and create a classification map for the tokens.
62 |
63 |
64 | :param annotations:
65 | :param source_documents:
66 | :param interned_documents:
67 | :param token_mapping:
68 | :return: dict[ann_id]dict[doc_id][SentenceEvidence, ...]
69 | """
70 | # TODO document
71 | # TODO should we simplify to use only source text?
72 | ret = defaultdict(lambda: defaultdict(list)) # annotation id -> docid -> sentences
73 | positive_tokens = 0
74 | negative_tokens = 0
75 | for ann in annotations:
76 | annid = ann.annotation_id
77 | docids = list(set(get_docids(ann)))
78 | sentence_offsets = defaultdict(list) # docid -> [(start, end)]
79 | classes = defaultdict(list) # docid -> [token is yea or nay]
80 | absolute_word_mapping = defaultdict(list) # docid -> [(absolute wordpiece start, absolute wordpiece end)]
81 |
82 | # chain the sentences and record the start/end offsets of each sentence as well as the absolute wordpiece span of each token
83 | # also prepare classes (a list of zeros, one per wordpiece token)
84 | for docid in docids:
85 | start = 0
86 | assert len(source_documents[docid]) == len(interned_documents[docid])
87 | for sentence_id, (whole_token_sent, wordpiece_sent) in enumerate(
88 | zip(source_documents[docid], interned_documents[docid])):
89 | classes[docid].extend([0 for _ in wordpiece_sent])
90 | end = start + len(wordpiece_sent)
91 | sentence_offsets[docid].append((start, end))
92 | absolute_word_mapping[docid].extend([(start + relative_wp_start,
93 | start + relative_wp_end)
94 | for relative_wp_start,
95 | relative_wp_end in token_mapping[docid][sentence_id]])
96 | start = end
97 | # calculate the start and end tokens for the evidence spans and set classes to 1 respectively
98 | for ev in chain.from_iterable(ann.evidences):
99 | if len(ev.text) == 0:
100 | continue
101 | flat_token_map = list(chain.from_iterable(token_mapping[ev.docid]))
102 | if ev.start_token != -1 and ev.start_sentence != -1:
103 | # start, end = token_mapping[ev.docid][ev.start_token][0], token_mapping[ev.docid][ev.end_token][1]
104 | sentence_offset_start = sentence_offsets[ev.docid][ev.start_sentence][0]
105 | sentence_offset_end = sentence_offsets[ev.docid][ev.end_sentence - 1][0]
106 | start = sentence_offset_start + flat_token_map[ev.start_token][0]
107 | end = sentence_offset_end + flat_token_map[ev.end_token - 1][1]
108 | elif ev.start_token == -1 and ev.start_sentence != -1:
109 | start = sentence_offsets[ev.docid][ev.start_sentence][0]
110 | end = sentence_offsets[ev.docid][ev.end_sentence - 1][1]
111 | elif ev.start_token != -1 and ev.start_sentence == -1:
112 | start = absolute_word_mapping[ev.docid][ev.start_token][0]
113 | end = absolute_word_mapping[ev.docid][ev.end_token][1]
114 | else:
115 | continue
116 | for i in range(start, end):
117 | try:
118 | classes[ev.docid][i] = 1
119 | except IndexError:
120 | print(ev)
121 | print(ev.docid)
122 | print(classes)
123 | print(len(classes[ev.docid]))
124 | print(i)
125 | raise
126 | for docid, offsets in sentence_offsets.items():
127 | token_assignments = classes[docid]
128 | positive_tokens += sum(token_assignments)
129 | negative_tokens += len(token_assignments) - sum(token_assignments)
130 | for s, (start, end) in enumerate(offsets):
131 | sent = interned_documents[docid][s]
132 | ret[annid][docid].append(SentenceEvidence(kls=tuple(token_assignments[start:end]),
133 | query=ann.query,
134 | ann_id=ann.annotation_id,
135 | docid=docid,
136 | index=s,
137 | sentence=sent,
138 | has_evidence=len(ann.evidences) > 0))
139 | logging.info(f"Have {positive_tokens} positive wordpiece tokens, {negative_tokens} negative wordpiece tokens")
140 | return ret
141 |
142 |
143 | def annotations_to_mtl_token_identification(annotations: List[Annotation],
144 | source_documents: Dict[str, List[List[str]]],
145 | interned_documents: Dict[str, List[List[int]]],
146 | token_mapping: Dict[str, List[List[Tuple[int, int]]]]
147 | ) -> Dict[str, list]:
148 | """
149 | See annotations_to_evidence_token_identification for more details
150 | :param annotations:
151 | :param source_documents:
152 | :param interned_documents:
153 | :param token_mapping:
154 | :return: dict[ann_id] -> [classification, dict[doc_id] -> [SentenceEvidence, ...]]
155 | """
156 | rets = annotations_to_evidence_token_identification(
157 | annotations,
158 | source_documents,
159 | interned_documents,
160 | token_mapping
161 | )
162 | # adds the final sequence classification
163 | for ann in annotations:
164 | ann_id = ann.annotation_id
165 | ann_kls = ann.classification
166 | rets[ann_id] = [ann_kls, rets[ann_id]]
167 | return rets
168 |
169 |
170 | # for mtl pipeline
171 | def make_mtl_token_preds_batch(classifier: nn.Module,
172 | batch_elements: List[SentenceEvidence],
173 | labels_mapping: Dict[str, int],
174 | token_mapping: Dict[str, List[List[Tuple[int, int]]]],
175 | max_length: int,
176 | par_lambda: int,
177 | device=None,
178 | cls_criterion: nn.Module = None,
179 | exp_criterion: nn.Module = None,
180 | tensorize_model_inputs: bool = True) -> Tuple[float, float, float, List[float], List[int], List[int]]:
181 | batch_elements = [s for s in batch_elements if s is not None]
182 | labels, targets, queries, sentences = zip(*[(s[0], s[1].kls, s[1].query, s[1].sentence)
183 | for s in batch_elements])
184 | labels = [[i == labels_mapping[label] for i in range(len(labels_mapping))] for label in labels]
185 | labels = torch.tensor(labels, dtype=torch.float, device=device)
186 | ids = [(s[1].ann_id, s[1].docid, s[1].index) for s in batch_elements]
187 |
188 | cropped_targets = [[0] * (len(query) + 2) # length of query and overheads such as [cls] and [sep]
189 | + list(target[:(max_length - len(query) - 2)]) for query, target in zip(queries, targets)]
190 | cropped_targets = PaddedSequence.autopad(
191 | [torch.tensor(t, dtype=torch.float, device=device) for t in cropped_targets],
192 | batch_first=True, device=device)
193 |
194 | targets = [[0] * (len(query) + 2) # length of query and overheads such as [cls] and [sep]
195 | + list(target) for query, target in zip(queries, targets)]
196 | targets = PaddedSequence.autopad([torch.tensor(t, dtype=torch.float, device=device) for t in targets],
197 | batch_first=True, device=device)
198 |
199 | if tensorize_model_inputs:
200 | if all(q is None for q in queries):
201 | queries = [torch.tensor([], dtype=torch.long) for _ in queries]
202 | else:
203 | assert all(q is not None for q in queries)
204 | queries = [torch.tensor(q, dtype=torch.long) for q in queries]
205 | sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
206 | preds = classifier(queries, ids, sentences)
207 | cls_preds, exp_preds, attention_masks = preds
208 | cls_loss = cls_criterion(cls_preds, labels).mean(dim=-1).sum()
209 | cls_preds = [x.cpu().tolist() for x in cls_preds]
210 | labels = [x.cpu().tolist() for x in labels]
211 |
212 | exp_loss = exp_criterion(exp_preds, cropped_targets.data.squeeze()).mean(dim=-1).sum()
213 | #print(exp_loss.shape, cls_loss.shape)
214 | exp_preds = [x.cpu() for x in exp_preds]
215 | hard_preds = [torch.round(x).to(dtype=torch.int).cpu() for x in targets.unpad(exp_preds)]
216 | exp_preds = [x.tolist() for x in targets.unpad(exp_preds)]
217 | token_targets = [[y.item() for y in x] for x in targets.unpad(targets.data.cpu())]
218 | total_loss = cls_loss + par_lambda * exp_loss
219 |
220 | return total_loss, cls_loss, exp_loss, \
221 | exp_preds, hard_preds, token_targets, \
222 | cls_preds, labels
223 |
224 |
225 | # for mtl pipeline
226 | def make_mtl_token_preds_epoch(classifier: nn.Module,
227 | data: List[SentenceEvidence],
228 | labels_mapping: Dict[str, int],
229 | token_mapping: Dict[str, List[List[Tuple[int, int]]]],
230 | batch_size: int,
231 | max_length: int,
232 | par_lambda: int,
233 | device=None,
234 | cls_criterion: nn.Module = None,
235 | exp_criterion: nn.Module = None,
236 | tensorize_model_inputs: bool = True):
237 | epoch_total_loss = 0
238 | epoch_cls_loss = 0
239 | epoch_exp_loss = 0
240 | epoch_soft_pred = []
241 | epoch_hard_pred = []
242 | epoch_token_targets = []
243 | epoch_pred_labels = []
244 | epoch_labels = []
245 | batches = _grouper(data, batch_size)
246 | classifier.eval()
247 | #for p in classifier.parameters():
248 | # print(str(p.device))
249 | # print('cuda:0')
250 | # print(str(p.device) == 'cuda:0')
251 | # if str(p.device) != 'cuda:0':
252 | # print(p)
253 | # assert False
254 | for batch in batches:
255 | total_loss, cls_loss, exp_loss, \
256 | soft_preds, hard_preds, token_targets, \
257 | pred_labels, labels = make_mtl_token_preds_batch(classifier,
258 | batch,
259 | labels_mapping,
260 | token_mapping,
261 | max_length,
262 | par_lambda,
263 | device,
264 | cls_criterion=cls_criterion,
265 | exp_criterion=exp_criterion,
266 | tensorize_model_inputs=tensorize_model_inputs)
267 | if total_loss is not None:
268 | epoch_total_loss += total_loss.sum().item()
269 | if cls_loss is not None:
270 | epoch_cls_loss += cls_loss.sum().item()
271 | if exp_loss is not None:
272 | epoch_exp_loss += exp_loss.sum().item()
273 | epoch_hard_pred.extend(hard_preds)
274 | epoch_soft_pred.extend(soft_preds)
275 | epoch_token_targets.extend(token_targets)
276 | epoch_pred_labels.extend(pred_labels)
277 | epoch_labels.extend(labels)
278 | epoch_total_loss /= len(data)
279 | epoch_cls_loss /= len(data)
280 | epoch_exp_loss /= len(data)
281 | return epoch_total_loss, epoch_cls_loss, epoch_exp_loss, \
282 | epoch_soft_pred, epoch_hard_pred, epoch_token_targets, \
283 | epoch_pred_labels, epoch_labels
284 |
285 |
286 | def make_mtl_classification_preds_batch(classifier: nn.Module,
287 | batch_elements: List[SentenceEvidence],
288 | class_interner: dict,
289 | device=None,
290 | criterion: nn.Module = None,
291 | tensorize_model_inputs: bool = True) -> Tuple[float, List[float], List[int], List[int]]:
292 | batch_elements = filter(lambda x: x is not None, batch_elements)
293 | targets, queries, sentences = zip(*[(s.kls, s.query, s.sentence) for s in batch_elements])
294 | ids = [(s.ann_id, s.docid, s.index) for s in batch_elements]
295 | targets = [[i == target for i in range(len(class_interner))] for target in targets]
296 | targets = torch.tensor(targets, dtype=torch.float, device=device)
297 | if tensorize_model_inputs:
298 | queries = [torch.tensor(q, dtype=torch.long) for q in queries]
299 | sentences = [torch.tensor(s, dtype=torch.long) for s in sentences]
300 | preds = classifier(queries, ids, sentences)
301 | targets = targets.to(device=preds.device)
302 | if criterion:
303 | loss = criterion(preds, targets)
304 | else:
305 | loss = None
306 | # .float() because pytorch 1.3 introduces a bug where argmax is unsupported for float16
307 | hard_preds = torch.argmax(preds.float(), dim=-1)
308 | return loss, preds, hard_preds, targets
309 |
310 |
311 | def make_mtl_classification_preds_epoch(classifier: nn.Module,
312 | data: List[SentenceEvidence],
313 | class_interner: dict,
314 | batch_size: int,
315 | device=None,
316 | criterion: nn.Module = None,
317 | tensorize_model_inputs: bool = True):
318 | epoch_loss = 0
319 | epoch_soft_pred = []
320 | epoch_hard_pred = []
321 | epoch_truth = []
322 | batches = _grouper(data, batch_size)
323 | classifier.eval()
324 | for batch in batches:
325 | loss, soft_preds, hard_preds, targets = make_mtl_classification_preds_batch(classifier=classifier,
326 | batch_elements=batch,
327 | class_interner=class_interner,
328 | device=device,
329 | criterion=criterion,
330 | tensorize_model_inputs=tensorize_model_inputs)
331 | if loss is not None:
332 | epoch_loss += loss.sum().item()
333 | epoch_hard_pred.extend(hard_preds)
334 | epoch_soft_pred.extend(soft_preds.cpu())
335 | epoch_truth.extend(targets)
336 | epoch_loss /= len(data)
337 | epoch_hard_pred = [x.item() for x in epoch_hard_pred]
338 | epoch_truth = [x.argmax().item() for x in epoch_truth]
339 | return epoch_loss, epoch_soft_pred, epoch_hard_pred, epoch_truth
340 |
341 |
342 | def convert_to_global_token_mapping(token_mapping):
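# Converts per-sentence (start, end) wordpiece spans into document-level spans by shifting each
# sentence's spans by the running offset of the preceding sentences.
# Illustration (made-up spans): [[(0, 1), (1, 3)], [(0, 2), (2, 3)]] -> [(0, 1), (1, 3), (3, 5), (5, 6)]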
343 | ret = []
344 | sent_offset = 0
345 | for sent in token_mapping:
346 | for mapping in sent:
347 | ret.append((mapping[0]+sent_offset, mapping[1]+sent_offset))
348 | sent_offset = ret[-1][-1]
349 | return ret
350 |
351 |
352 | def decode(evidence_identifier: nn.Module,
353 | evidence_classifier: nn.Module,
354 | train: List[Annotation],
355 | val: List[Annotation],
356 | test: List[Annotation],
357 | source_documents,
358 | token_mapping,
359 | mrs_train,
360 | mrs_eval,
361 | mrs_test,
362 | class_interner: Dict[str, int],
363 | batch_size: int,
364 | tensorize_modelinputs: bool,
365 | interned_documents: Dict[str, List[List[int]]] = None,
366 | tokenizer=None) -> dict:
367 | device = None
368 | class_labels = [k for k, v in sorted(class_interner.items(), key=lambda x: x[1])]
369 | if interned_documents is None:
370 | interned_documents = source_documents
371 |
372 | def prep(data: List[Annotation], mrs) -> List[Tuple[SentenceEvidence, SentenceEvidence]]:
373 | identification_data = annotations_to_mtl_token_identification(annotations=data,
374 | source_documents=source_documents,
375 | interned_documents=interned_documents,
376 | token_mapping=token_mapping)
377 |
378 | identification_data = {ann_id: [v[0], {docid: chain_sentence_evidences(sentences)
379 | for docid, sentences in v[1].items()}]
380 | for ann_id, v in identification_data.items()}
381 | classification_data = mask_annotations_to_evidence_classification(mrs=mrs,
382 | class_interner=class_interner)
383 | #ann_doc_sents = defaultdict(lambda: defaultdict(dict)) # ann id -> docid -> sent idx -> sent data
384 | ret = []
385 | for sent_ev in classification_data:
386 | id_data = identification_data[sent_ev.ann_id][1][sent_ev.docid]
387 | ret.append((id_data, sent_ev))
388 | assert id_data.ann_id == sent_ev.ann_id
389 | assert id_data.docid == sent_ev.docid
390 | #assert id_data.index == sent_ev.index
391 | assert len(ret) == len(classification_data)
392 | return ret
393 |
394 | def decode_batch(data: List[Tuple[SentenceEvidence, SentenceEvidence]],
395 | mrs,
396 | name: str,
397 | score: bool = False,
398 | annotations: List[Annotation] = None,
399 | tokenizer=None) -> dict:
400 | """Identifies evidence statements and then makes classifications based on it.
401 |
402 | Args:
403 | data: a paired list of SentenceEvidences, differing only in the kls field.
404 | The first corresponds to whether or not something is evidence, and the second corresponds to an evidence class
405 | name: a name for a results dict
406 | """
407 |
408 | num_uniques = len(set((x.ann_id, x.docid) for x, _ in data))
409 | logging.info(f'Decoding dataset {name} with {len(data)} sentences, {num_uniques} annotations')
410 | identifier_data, classifier_data = zip(*data)
411 | results = dict()
412 | IdentificationClassificationResult = namedtuple('IdentificationClassificationResult',
413 | 'identification_data classification_data soft_identification hard_identification soft_classification hard_classification')
414 | with torch.no_grad():
415 | # make predictions for the evidence_identifier
416 | evidence_identifier.eval()
417 | evidence_classifier.eval()
418 | _, soft_identification_preds, hard_identification_preds = zip(*mrs)
419 | assert len(soft_identification_preds) == len(data)
420 | identification_results = defaultdict(list)
421 | for id_data, cls_data, soft_id_pred, hard_id_pred in zip(identifier_data, classifier_data,
422 | soft_identification_preds,
423 | hard_identification_preds):
424 | res = IdentificationClassificationResult(identification_data=id_data,
425 | classification_data=cls_data,
426 | # 1 is p(evidence|sent,query)
427 | soft_identification=soft_id_pred,
428 | hard_identification=hard_id_pred,
429 | soft_classification=None,
430 | hard_classification=False)
431 | identification_results[(id_data.ann_id, id_data.docid)].append(res)  # in the original ERASER pipeline each sentence
432 | # is stored separately, so for each (ann_id, docid) key there is a list of identification results, one
433 | # per sentence. In our approach the whole document is chained together from the beginning and
434 | # rationales are predicted at token-level granularity
435 |
436 | best_identification_results = {key: max(value, key=lambda x: x.soft_identification) for key, value in
437 | identification_results.items()}
438 | logging.info(
439 | f'Selected the best sentence for {len(identification_results)} examples from a total of {len(soft_identification_preds)} sentences')
440 | ids, classification_data = zip(
441 | *[(k, v.classification_data) for k, v in best_identification_results.items()])
442 | _, soft_classification_preds, hard_classification_preds, classification_truth = \
443 | make_mtl_classification_preds_epoch(classifier=evidence_classifier,
444 | data=classification_data,
445 | class_interner=class_interner,
446 | batch_size=batch_size,
447 | device=device,
448 | tensorize_model_inputs=tensorize_modelinputs)
449 | classification_results = dict()
450 | for eyeD, soft_class, hard_class in zip(ids, soft_classification_preds, hard_classification_preds):
451 | input_id_result = best_identification_results[eyeD]
452 | res = IdentificationClassificationResult(identification_data=input_id_result.identification_data,
453 | classification_data=input_id_result.classification_data,
454 | soft_identification=input_id_result.soft_identification,
455 | hard_identification=input_id_result.hard_identification,
456 | soft_classification=soft_class,
457 | hard_classification=hard_class)
458 | classification_results[eyeD] = res
459 |
460 | if score:
461 | truth = []
462 | pred = []
463 | for res in classification_results.values():
464 | truth.append(res.classification_data.kls)
465 | pred.append(res.hard_classification)
466 | # results[f'{name}_f1'] = classification_report(classification_truth, pred, target_names=class_labels, output_dict=True)
467 | results[f'{name}_f1'] = classification_report(classification_truth, hard_classification_preds,
468 | target_names=class_labels,
469 | labels=list(range(len(class_labels))), output_dict=True)
470 | results[f'{name}_acc'] = accuracy_score(classification_truth, hard_classification_preds)
471 | results[f'{name}_rationale'] = score_rationales(annotations, interned_documents, identifier_data,
472 | soft_identification_preds)
473 |
474 | # turn the above results into a format suitable for scoring via the rationale scorer
475 | # n.b. the sentence-level evidence predictions (hard and soft) are
476 | # broadcast to the token level for scoring. The comprehensiveness class
477 | # score is also a lie since the pipeline model above is faithful by
478 | # design.
479 | decoded = dict()
480 | decoded_scores = defaultdict(list)
481 | for (ann_id, docid), pred in classification_results.items():
482 | #sentence_prediction_scores = [x.soft_identification for x in identification_results[(ann_id, docid)]]
483 | hard_rationale_predictions = list(chain.from_iterable(x.hard_identification for x in identification_results[(ann_id, docid)]))
484 | soft_rationale_predictions = list(chain.from_iterable(x.soft_identification for x in identification_results[(ann_id, docid)]))
485 | subtoken_ids = list(chain.from_iterable(interned_documents[docid]))
486 | raw_document = []
487 | for word in chain.from_iterable(source_documents[docid]):
488 | token_ids_origin = tokenizer.encode(word, add_special_tokens=False)
489 | if token_ids_origin[0] == tokenizer.unk_token_id:
490 | raw_document.append('[UNK]')
491 | else:
492 | tokenized = ''.join(tokenizer.basic_tokenizer.tokenize(word))  # normalizes the word via the basic tokenizer so it aligns with the wordpieces, e.g. for the IPA transcription ˈlʊdvɪɡ_væn_ˈbeɪˌtoʊvən
493 | raw_document.append(tokenized)
494 | global_token_mapping = convert_to_global_token_mapping(token_mapping[docid])
495 | tokens, exp_outputs = convert_subtoken_ids_to_tokens(subtoken_ids,
496 | tokenizer=tokenizer,
497 | token_mapping=global_token_mapping,
498 | exps=(hard_rationale_predictions,
499 | soft_rationale_predictions),
500 | raw_sentence=raw_document)
501 | hard_rationale_predictions, soft_rationale_predictions = list(zip(*exp_outputs))
502 | # if docid == 'Ludwig_van_Beethoven':
503 | # #print(len(hard_rationale_predictions))
504 | # print(len(soft_rationale_predictions))
505 | ev_generator = rational_bits_to_ev_generator(tokens,
506 | docid,
507 | hard_rationale_predictions)
508 | hard_rationale_predictions = [ev for ev in ev_generator]
509 |
510 | if ann_id not in decoded:
511 | decoded[ann_id] = {
512 | "annotation_id": ann_id,
513 | "rationales": [],
514 | "classification": class_labels[pred.hard_classification],
515 | "classification_scores": {class_labels[i]: s.item() for i, s in
516 | enumerate(pred.soft_classification)},
517 | # TODO this should turn into the data distribution for the predicted class
518 | # "comprehensiveness_classification_scores": 0.0,
519 | "truth": pred.classification_data.kls,
520 | }
521 | decoded[ann_id]['rationales'].append({
522 | "docid": docid,
523 | "hard_rationale_predictions": hard_rationale_predictions,
524 | "soft_rationale_predictions": soft_rationale_predictions,
525 | })
526 | decoded_scores[ann_id].append(pred.soft_classification)
527 |
528 | return results, list(decoded.values())
529 |
530 | test_results, test_decoded = decode_batch(prep(test, mrs_test), mrs_test, 'test', score=False, tokenizer=tokenizer)
531 | val_results, val_decoded = dict(), []
532 | train_results, train_decoded = dict(), []
533 | # val_results, val_decoded = decode_batch(prep(val, mrs_eval), mrs_eval, 'val', score=True, annotations=val, tokenizer=tokenizer)
534 | # train_results, train_decoded = decode_batch(prep(train, mrs_train), mrs_train, 'train', score=True, annotations=train, tokenizer=tokenizer)
535 | return dict(**train_results, **val_results, **test_results), train_decoded, val_decoded, test_decoded
536 |
--------------------------------------------------------------------------------