├── W3C ├── __init__.py ├── extract_threads.py └── anonymize.py ├── Avocado ├── __init__.py ├── standardize.py ├── extact_threads.py └── anonymize_files │ └── paths.txt ├── .gitignore ├── rouge ├── requirements.txt ├── __init__.py ├── run.sh ├── tokenize_test.py ├── test_util.py ├── oss │ └── oss_release.sh ├── setup.py ├── tokenize.py ├── create_pyrouge_files.py ├── rouge.py ├── README.md ├── io_test.py ├── scoring.py ├── io.py ├── scoring_test.py ├── rouge_scorer.py └── rouge_scorer_test.py ├── bert_score ├── README.md ├── __init__.py ├── rescale_baseline │ └── en │ │ ├── distilroberta-base.tsv │ │ ├── distilbert-base-uncased.tsv │ │ ├── distilbert-base-multilingual-cased.tsv │ │ ├── distilbert-base-uncased-distilled-squad.tsv │ │ ├── roberta-base.tsv │ │ ├── xlm-mlm-en-2048.tsv │ │ ├── xlm-roberta-base.tsv │ │ ├── xlnet-base-cased.tsv │ │ ├── albert-base-v2.tsv │ │ ├── albert-xxlarge-v2.tsv │ │ ├── albert-base-v1.tsv │ │ ├── bert-base-multilingual-cased.tsv │ │ ├── bert-base-uncased.tsv │ │ ├── albert-xxlarge-v1.tsv │ │ ├── bert-base-cased-finetuned-mrpc.tsv │ │ ├── xlm-mlm-100-1280.tsv │ │ ├── roberta-large.tsv │ │ ├── roberta-large-mnli.tsv │ │ ├── xlm-roberta-large.tsv │ │ ├── albert-large-v2.tsv │ │ ├── bert-large-uncased.tsv │ │ ├── xlnet-large-cased.tsv │ │ ├── albert-large-v1.tsv │ │ ├── albert-xlarge-v1.tsv │ │ └── albert-xlarge-v2.tsv ├── score.py └── scorer.py ├── requirements.txt ├── LICENSE ├── metrics.py ├── run.py ├── data.py ├── README.md └── utils.py /W3C/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Avocado/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | Avocado/emails.json 3 | W3C/raw_data 4 | train/ -------------------------------------------------------------------------------- /rouge/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | nltk 3 | numpy 4 | six>=1.14 5 | -------------------------------------------------------------------------------- /bert_score/README.md: -------------------------------------------------------------------------------- 1 | From https://github.com/Tiiiger/bert_score/tree/master/bert_score -------------------------------------------------------------------------------- /bert_score/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.5" 2 | from .score import * 3 | from .scorer import * 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.20.1 2 | tqdm==4.59.0 3 | ujson==4.0.2 4 | nltk==3.5 5 | spacy==3.0.5 6 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/distilroberta-base.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.42608285,0.4272089,0.42462298 3 | 1,0.7367886,0.7370362,0.736573 4 | 2,0.79922664,0.799593,0.7991632 5 | 3,0.8329021,0.8333321,0.83291864 6 | 4,0.8442,0.84462386,0.84425896 7 | 5,0.84732,0.84759504,0.8473319 8 | 6,0.89334005,0.8935088,0.8933471 9 | 
-------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/distilbert-base-uncased.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.2884445,0.2884457,0.28333962 3 | 1,0.39316687,0.3931663,0.39123002 4 | 2,0.42905498,0.4290923,0.42735597 5 | 3,0.5222444,0.52227175,0.52129734 6 | 4,0.6019937,0.6019904,0.6014007 7 | 5,0.6666034,0.66660464,0.66620487 8 | 6,0.51401854,0.51404256,0.5131456 9 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/distilbert-base-multilingual-cased.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.27245584,0.27247205,0.26611173 3 | 1,0.45394143,0.453942,0.45178676 4 | 2,0.5374658,0.5374726,0.53619426 5 | 3,0.61241305,0.61244136,0.6116679 6 | 4,0.63282156,0.632836,0.63219804 7 | 5,0.8164157,0.81645757,0.81623197 8 | 6,0.4648941,0.4649093,0.4638737 9 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/distilbert-base-uncased-distilled-squad.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.28725642,0.2872663,0.28207442 3 | 1,0.37234208,0.37233955,0.37046063 4 | 2,0.403689,0.4037149,0.4020736 5 | 3,0.5399291,0.53997463,0.53930676 6 | 4,0.6591859,0.65919137,0.65882134 7 | 5,0.65313077,0.6531313,0.65279835 8 | 6,0.74920315,0.7491901,0.7487158 9 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/roberta-base.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.4043224,0.40432808,0.40218553 3 | 1,0.6423126,0.6422804,0.6414617 4 | 2,0.768273,0.7682535,0.76791227 5 | 3,0.7803166,0.78030443,0.7800415 6 | 4,0.7839782,0.78397924,0.7836174 7 | 5,0.7959116,0.7959033,0.79557085 8 | 6,0.80936664,0.80936354,0.80908644 9 | 7,0.81720984,0.81721514,0.816965 10 | 8,0.80465585,0.80464727,0.8043641 11 | 9,0.7911581,0.79115206,0.7908595 12 | 10,0.8146725,0.8146619,0.814463 13 | 11,0.8243949,0.8244051,0.82420003 14 | 12,0.8557132,0.85571885,0.8555707 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/xlm-mlm-en-2048.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.48034036,0.48027167,0.4755281 3 | 1,0.68549955,0.68547165,0.68418026 4 | 2,0.7502881,0.7502652,0.7497456 5 | 3,0.7662417,0.7662214,0.7659151 6 | 4,0.7910623,0.7910466,0.79085386 7 | 5,0.8090659,0.8090618,0.80895317 8 | 6,0.82148397,0.8214852,0.821408 9 | 7,0.8091143,0.8091184,0.8090199 10 | 8,0.77966934,0.7796406,0.77937865 11 | 9,0.75278246,0.7527972,0.7524639 12 | 10,0.72071564,0.7207407,0.7202978 13 | 11,0.7175687,0.7176211,0.7170889 14 | 12,0.22130837,0.22130068,0.21938775 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/xlm-roberta-base.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.31767526,0.31771243,0.31208947 3 | 1,0.45930108,0.45930612,0.4573549 4 | 2,0.6739723,0.6739605,0.67332643 5 | 3,0.7428563,0.7428622,0.74252146 6 | 4,0.7270618,0.7270706,0.7267292 7 | 5,0.7459538,0.7459533,0.74563044 8 | 6,0.7416182,0.74162334,0.74136156 9 | 7,0.7766629,0.7766664,0.7764565 10 | 
8,0.7827196,0.78271383,0.78251594 11 | 9,0.81658614,0.8165717,0.81639785 12 | 10,0.83839214,0.83837646,0.8382293 13 | 11,0.8711623,0.8711581,0.87106025 14 | 12,0.9843661,0.98436636,0.9843645 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/xlnet-base-cased.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.29910204,0.29919305,0.29052314 3 | 1,0.29633516,0.29640594,0.2915415 4 | 2,0.28782755,0.28787795,0.28492415 5 | 3,0.29966587,0.2996727,0.29745364 6 | 4,0.32897076,0.32897395,0.3263186 7 | 5,0.34247187,0.3424195,0.34024557 8 | 6,0.61728173,0.61718243,0.6160013 9 | 7,0.6704566,0.6703779,0.66936857 10 | 8,0.8596307,0.8595696,0.859391 11 | 9,0.8611796,0.8611522,0.8610164 12 | 10,0.89382625,0.8938215,0.8937337 13 | 11,0.97762144,0.9776183,0.97761476 14 | 12,0.93146294,0.93134,0.93100053 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/albert-base-v2.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.43284354,0.4329465,0.42670736 3 | 1,0.4085349,0.40857056,0.4041539 4 | 2,0.42302486,0.42304876,0.41986418 5 | 3,0.43835327,0.43837532,0.43578437 6 | 4,0.46398157,0.4640153,0.46179092 7 | 5,0.487097,0.48714137,0.48507443 8 | 6,0.50701046,0.5070602,0.50516284 9 | 7,0.5251579,0.5252073,0.52346826 10 | 8,0.5432063,0.5432638,0.5416856 11 | 9,0.56169736,0.56174135,0.56031275 12 | 10,0.58207834,0.58211654,0.58080167 13 | 11,0.5087994,0.5088567,0.50630754 14 | 12,0.4822224,0.48224902,0.4795803 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/albert-xxlarge-v2.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.4414845,0.4415628,0.4378333 3 | 1,0.26729813,0.26729846,0.26443842 4 | 2,0.25006709,0.25006858,0.2470538 5 | 3,0.22912578,0.22914563,0.22677879 6 | 4,0.23676835,0.23678702,0.23474906 7 | 5,0.23712093,0.23712862,0.23520498 8 | 6,0.2357785,0.23579709,0.2339876 9 | 7,0.2375271,0.2375658,0.2357691 10 | 8,0.23694733,0.2369875,0.23519956 11 | 9,0.24043696,0.24048997,0.23847668 12 | 10,0.25991938,0.25997588,0.257621 13 | 11,0.3076668,0.30775174,0.30460533 14 | 12,0.5213576,0.52133,0.5192018 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/albert-base-v1.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.42279568,0.42285842,0.4198645 3 | 1,0.38239375,0.3824535,0.3795375 4 | 2,0.35127786,0.35131463,0.34854048 5 | 3,0.3402314,0.34027407,0.33761653 6 | 4,0.34001094,0.3400646,0.33745667 7 | 5,0.34310105,0.34314916,0.34054983 8 | 6,0.3478834,0.34792796,0.34530792 9 | 7,0.3523316,0.35237584,0.34973368 10 | 8,0.35546654,0.35550496,0.35283387 11 | 9,0.35682797,0.35686156,0.3541417 12 | 10,0.3572713,0.35730729,0.35451323 13 | 11,0.35916516,0.35920846,0.35632935 14 | 12,0.3620535,0.3621047,0.35911387 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/bert-base-multilingual-cased.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.31651747,0.3166142,0.31180394 3 | 1,0.38737702,0.38744056,0.38455048 4 | 2,0.37912813,0.37916443,0.37648088 5 | 3,0.46451283,0.46451145,0.46312103 6 | 
4,0.5066057,0.50659287,0.5054953 7 | 5,0.5804824,0.5804496,0.5797646 8 | 6,0.63067275,0.630636,0.63018715 9 | 7,0.54218787,0.5421653,0.5414328 10 | 8,0.5240471,0.5240057,0.5232123 11 | 9,0.6320527,0.6320019,0.63146895 12 | 10,0.69633687,0.6962761,0.6958725 13 | 11,0.7193143,0.7192363,0.7188216 14 | 12,0.3473233,0.34732684,0.34655094 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/bert-base-uncased.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.3231512,0.32322776,0.31853873 3 | 1,0.32517454,0.32522815,0.32197207 4 | 2,0.3708038,0.37080705,0.36834884 5 | 3,0.36287847,0.36286885,0.36059204 6 | 4,0.3786389,0.37860426,0.3767926 7 | 5,0.4018232,0.401791,0.40032896 8 | 6,0.38439456,0.38434005,0.38282546 9 | 7,0.37114623,0.3710986,0.36949417 10 | 8,0.37231025,0.37226102,0.37049443 11 | 9,0.35375935,0.3537393,0.35219112 12 | 10,0.38161838,0.3816211,0.37991408 13 | 11,0.4421448,0.4421776,0.44040316 14 | 12,0.40192786,0.40191513,0.40038353 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/albert-xxlarge-v1.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.44518736,0.44525033,0.44190475 3 | 1,0.26892486,0.26893654,0.26619813 4 | 2,0.25225964,0.25227055,0.2495048 5 | 3,0.23626596,0.23626427,0.23414151 6 | 4,0.24108262,0.24108647,0.23914734 7 | 5,0.2402725,0.24029303,0.23852193 8 | 6,0.24204335,0.24206877,0.24038398 9 | 7,0.24432875,0.24436904,0.2427339 10 | 8,0.24470611,0.24472676,0.24312295 11 | 9,0.24761276,0.24763304,0.2458257 12 | 10,0.26654655,0.26657295,0.26450548 13 | 11,0.30993807,0.309992,0.3073111 14 | 12,0.46560258,0.46563277,0.463768 15 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/bert-base-cased-finetuned-mrpc.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.32524315,0.32527947,0.32047534 3 | 1,0.3697738,0.3697855,0.36682808 4 | 2,0.3912412,0.39124438,0.38884974 5 | 3,0.38678017,0.3867508,0.3849363 6 | 4,0.4306143,0.43059555,0.4291982 7 | 5,0.47680253,0.47676748,0.4757307 8 | 6,0.4937383,0.4937078,0.49275663 9 | 7,0.47395828,0.47392154,0.47275484 10 | 8,0.48822877,0.48818707,0.48712534 11 | 9,0.55345184,0.55342007,0.5525519 12 | 10,0.6535154,0.6534775,0.6529064 13 | 11,0.76415604,0.7641147,0.76378924 14 | 12,0.72067815,0.7206308,0.72023565 15 | -------------------------------------------------------------------------------- /rouge/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | 17 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/xlm-mlm-100-1280.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.2929519,0.29297927,0.28788087 3 | 1,0.32307193,0.32305866,0.31955993 4 | 2,0.33333376,0.33329934,0.3307059 5 | 3,0.34018472,0.34019333,0.3369147 6 | 4,0.35193846,0.35196185,0.34877294 7 | 5,0.41633913,0.41635182,0.41389906 8 | 6,0.52230054,0.5223191,0.5208747 9 | 7,0.57117224,0.5711975,0.57016635 10 | 8,0.55626523,0.55628437,0.55513597 11 | 9,0.5035621,0.5035617,0.5023768 12 | 10,0.43660313,0.4366135,0.43496045 13 | 11,0.37350416,0.37354943,0.3712711 14 | 12,0.3694557,0.36947483,0.36708415 15 | 13,0.38296118,0.38296735,0.38057274 16 | 14,0.3801941,0.38019708,0.37771493 17 | 15,0.39073846,0.39073724,0.38804337 18 | 16,0.27941948,0.2793937,0.27774334 19 | -------------------------------------------------------------------------------- /rouge/run.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | set -e 17 | set -x 18 | 19 | virtualenv -p python3 . 
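# Activate the virtualenv created above, install the requirements, and run each test module: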
20 | source ./bin/activate 21 | 22 | pip install -r rouge/requirements.txt 23 | python -m rouge.io_test 24 | python -m rouge.rouge_scorer_test 25 | python -m rouge.scoring_test 26 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/roberta-large.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.3712891,0.37132213,0.36826715 3 | 1,0.67176163,0.6717439,0.6703483 4 | 2,0.70031923,0.7003052,0.69969934 5 | 3,0.7080897,0.7081011,0.707698 6 | 4,0.6976306,0.69762677,0.69710517 7 | 5,0.7187199,0.71873325,0.71828526 8 | 6,0.74678195,0.74678224,0.74642223 9 | 7,0.7772428,0.7772184,0.77691925 10 | 8,0.8021733,0.8021747,0.8019093 11 | 9,0.8067641,0.80678225,0.8065291 12 | 10,0.8366976,0.8367098,0.8364913 13 | 11,0.8163513,0.816369,0.8161064 14 | 12,0.8175406,0.8175611,0.81728977 15 | 13,0.82106245,0.8210674,0.82080233 16 | 14,0.81487834,0.8148861,0.8145652 17 | 15,0.8243552,0.8243522,0.8240494 18 | 16,0.8341641,0.8341684,0.833912 19 | 17,0.83150584,0.8314941,0.83122575 20 | 18,0.8314624,0.83146274,0.8311686 21 | 19,0.82761073,0.8276117,0.8273196 22 | 20,0.799873,0.79988,0.79956234 23 | 21,0.8082163,0.80819315,0.8079286 24 | 22,0.83196104,0.83195347,0.83174026 25 | 23,0.8408042,0.8408027,0.8405716 26 | 24,0.96022236,0.96021587,0.960168 27 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/roberta-large-mnli.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.36816803,0.36820343,0.3650997 3 | 1,0.6424572,0.64243424,0.6408211 4 | 2,0.62199366,0.6219771,0.62105906 5 | 3,0.65479594,0.65479946,0.6542115 6 | 4,0.66220766,0.66219413,0.66147035 7 | 5,0.6841878,0.6841976,0.6835943 8 | 6,0.6993157,0.6993184,0.698729 9 | 7,0.7363659,0.7363538,0.73597246 10 | 8,0.76699406,0.76697797,0.7666572 11 | 9,0.76385623,0.76387703,0.76359564 12 | 10,0.7751121,0.7751162,0.7748585 13 | 11,0.7607176,0.7607192,0.7604293 14 | 12,0.75846714,0.75850517,0.7582122 15 | 13,0.7660639,0.766093,0.7658386 16 | 14,0.76723933,0.7672636,0.76692307 17 | 15,0.76183504,0.7618548,0.7615043 18 | 16,0.77503896,0.7750635,0.77476084 19 | 17,0.7572284,0.75724494,0.7568846 20 | 18,0.72981,0.72983533,0.7294623 21 | 19,0.6901594,0.69018,0.6896288 22 | 20,0.6456024,0.6456534,0.6447707 23 | 21,0.6733705,0.6734108,0.672755 24 | 22,0.7964235,0.79642963,0.7961781 25 | 23,0.83942956,0.839427,0.8393037 26 | 24,0.87867236,0.8787309,0.8781039 27 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/xlm-roberta-large.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.38918123,0.38920417,0.3852401 3 | 1,0.66835684,0.6683084,0.6677018 4 | 2,0.7323929,0.7323684,0.7321559 5 | 3,0.7391762,0.7391537,0.73889536 6 | 4,0.7922834,0.79227173,0.7921484 7 | 5,0.79589903,0.795871,0.7957138 8 | 6,0.8166894,0.816673,0.8165898 9 | 7,0.8223533,0.8223572,0.82228154 10 | 8,0.834576,0.8345772,0.8344947 11 | 9,0.8377803,0.83777326,0.8376894 12 | 10,0.8380223,0.8380033,0.83791 13 | 11,0.8415803,0.84157884,0.8414282 14 | 12,0.84659237,0.8466055,0.84632146 15 | 13,0.8437288,0.84372836,0.84340864 16 | 14,0.846515,0.84650415,0.8461781 17 | 15,0.8514585,0.8514379,0.85112184 18 | 16,0.84461045,0.8446081,0.8442589 19 | 17,0.85291016,0.8529066,0.8525485 20 | 18,0.8582745,0.8582787,0.85787606 21 | 19,0.85327464,0.8532746,0.85287833 22 | 
20,0.86624545,0.86624,0.86592185 23 | 21,0.8854349,0.88543147,0.88515806 24 | 22,0.8891757,0.8891605,0.88892245 25 | 23,0.88805044,0.88803035,0.88777393 26 | 24,0.9840399,0.98404247,0.984038 27 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/albert-large-v2.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.43137488,0.4314412,0.4271023 3 | 1,0.47189355,0.47192886,0.46977237 4 | 2,0.4965904,0.49659666,0.49521467 5 | 3,0.4952368,0.4952206,0.49390256 6 | 4,0.49991024,0.4998804,0.49865857 7 | 5,0.5061125,0.5060827,0.50490576 8 | 6,0.52520007,0.5251885,0.5241151 9 | 7,0.5463337,0.54633546,0.54536676 10 | 8,0.56268036,0.56267744,0.5618048 11 | 9,0.5788636,0.5788671,0.5780607 12 | 10,0.59798187,0.5979915,0.5972454 13 | 11,0.6093569,0.6093737,0.60867995 14 | 12,0.61832786,0.6183305,0.6176837 15 | 13,0.6298888,0.62988657,0.6292773 16 | 14,0.63760334,0.6376027,0.6370052 17 | 15,0.6402277,0.6402217,0.63963217 18 | 16,0.6457506,0.6457368,0.64517874 19 | 17,0.6488497,0.6488231,0.6482803 20 | 18,0.6473536,0.6473276,0.6467711 21 | 19,0.65181977,0.6517948,0.6512418 22 | 20,0.65941834,0.6593918,0.65884435 23 | 21,0.65883756,0.65882397,0.65822756 24 | 22,0.6599824,0.6599794,0.6593097 25 | 23,0.6140344,0.6140205,0.6131047 26 | 24,0.54314095,0.54311645,0.5419062 27 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/bert-large-uncased.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.33945993,0.33952734,0.3353803 3 | 1,0.46529758,0.46534532,0.4629573 4 | 2,0.5190359,0.51904607,0.5170987 5 | 3,0.55551875,0.5555247,0.5540426 6 | 4,0.47806495,0.4780755,0.47663376 7 | 5,0.39333034,0.3933407,0.391598 8 | 6,0.30678865,0.30683848,0.30446944 9 | 7,0.40164435,0.40167126,0.39997557 10 | 8,0.44429466,0.4443099,0.44277325 11 | 9,0.5114804,0.5114661,0.5102474 12 | 10,0.53322667,0.5332073,0.5323144 13 | 11,0.56793964,0.56791747,0.56725395 14 | 12,0.56360143,0.5635814,0.5629889 15 | 13,0.5358492,0.5358346,0.53522795 16 | 14,0.42079058,0.42078197,0.41975206 17 | 15,0.3509417,0.3509411,0.34957188 18 | 16,0.4534342,0.45341223,0.45231807 19 | 17,0.46370843,0.46370083,0.46265444 20 | 18,0.4278576,0.42786714,0.42646673 21 | 19,0.38974905,0.3897353,0.3877319 22 | 20,0.3966205,0.3966191,0.3942883 23 | 21,0.4981153,0.49813268,0.4955151 24 | 22,0.5868029,0.58685154,0.584482 25 | 23,0.7136535,0.7137033,0.7118858 26 | 24,0.5152624,0.5152391,0.5146088 27 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/xlnet-large-cased.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.41637358,0.41643414,0.41258112 3 | 1,0.32545134,0.32545993,0.3204785 4 | 2,0.29599807,0.29601985,0.29176536 5 | 3,0.21799843,0.2180424,0.21441601 6 | 4,0.2619272,0.261958,0.25913864 7 | 5,0.30362618,0.30360785,0.30147976 8 | 6,0.31371272,0.3136575,0.31170228 9 | 7,0.3085695,0.30850938,0.30676135 10 | 8,0.3251663,0.32509723,0.32402074 11 | 9,0.34611195,0.34610417,0.3449464 12 | 10,0.33172518,0.3316963,0.32996267 13 | 11,0.32673666,0.32671896,0.3252777 14 | 12,0.3015574,0.30154356,0.29979268 15 | 13,0.33127543,0.33126998,0.33017284 16 | 14,0.33191463,0.33192313,0.3307891 17 | 15,0.3753324,0.3753503,0.374231 18 | 16,0.37750244,0.37751338,0.37648135 19 | 17,0.3678608,0.3678761,0.36674905 20 | 
18,0.305072,0.3050984,0.3042137 21 | 19,0.42524177,0.4253285,0.42387673 22 | 20,0.59149736,0.59153783,0.5901478 23 | 21,0.6070587,0.607099,0.6057612 24 | 22,0.80884385,0.80882186,0.8085461 25 | 23,0.9555436,0.9555404,0.95551467 26 | 24,0.96873486,0.9687297,0.9685215 27 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/albert-large-v1.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.48447838,0.48450485,0.4821886 3 | 1,0.5124409,0.51243365,0.5109167 4 | 2,0.49396634,0.49394318,0.49285302 5 | 3,0.48355308,0.48351732,0.48258644 6 | 4,0.48206407,0.48202685,0.4811013 7 | 5,0.48171225,0.48167655,0.48073986 8 | 6,0.48402956,0.48400134,0.48304388 9 | 7,0.48760605,0.48758495,0.4866279 10 | 8,0.49034056,0.4903293,0.4893756 11 | 9,0.4919946,0.49199188,0.4910255 12 | 10,0.49351045,0.4935107,0.49251547 13 | 11,0.4953505,0.49535286,0.4943231 14 | 12,0.49792922,0.4979353,0.49686712 15 | 13,0.50119936,0.5012099,0.5001017 16 | 14,0.50464475,0.50465906,0.5035164 17 | 15,0.5072171,0.50723296,0.5060587 18 | 16,0.50804037,0.50805837,0.506836 19 | 17,0.50674427,0.5067624,0.5054734 20 | 18,0.5028615,0.5028785,0.50150096 21 | 19,0.4957624,0.49577576,0.49427336 22 | 20,0.48470628,0.48471764,0.48304176 23 | 21,0.46942177,0.4694329,0.46755382 24 | 22,0.45182654,0.45184082,0.44979697 25 | 23,0.4372368,0.43725976,0.43516964 26 | 24,0.43032366,0.4303518,0.42831102 27 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/albert-xlarge-v1.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.37603918,0.37612942,0.37049496 3 | 1,0.31145602,0.3114958,0.3073803 4 | 2,0.25227228,0.2522994,0.24795091 5 | 3,0.22015819,0.22017719,0.21600199 6 | 4,0.21572605,0.21576598,0.21187688 7 | 5,0.21390381,0.21393314,0.21024637 8 | 6,0.21366087,0.21368802,0.21022928 9 | 7,0.2149553,0.21497151,0.2116843 10 | 8,0.21902423,0.21904334,0.215865 11 | 9,0.22598784,0.22601976,0.22294162 12 | 10,0.23651579,0.23656204,0.2335378 13 | 11,0.2508,0.25083283,0.24782418 14 | 12,0.26735264,0.26740175,0.2642045 15 | 13,0.2851571,0.2852036,0.28140694 16 | 14,0.30159834,0.3016559,0.2969648 17 | 15,0.31582344,0.31589058,0.31032172 18 | 16,0.33028397,0.3303347,0.32389277 19 | 17,0.34479943,0.34483773,0.33757344 20 | 18,0.3576801,0.35770583,0.34980485 21 | 19,0.36997133,0.36996147,0.3615338 22 | 20,0.3813416,0.38132015,0.37257645 23 | 21,0.3904368,0.39041746,0.38146585 24 | 22,0.4026223,0.40261322,0.39356884 25 | 23,0.41755676,0.41755086,0.4090774 26 | 24,0.40913486,0.40914643,0.40243107 27 | -------------------------------------------------------------------------------- /bert_score/rescale_baseline/en/albert-xlarge-v2.tsv: -------------------------------------------------------------------------------- 1 | LAYER,P,R,F 2 | 0,0.379094,0.37919718,0.37330297 3 | 1,0.27352002,0.27357075,0.26852632 4 | 2,0.24191533,0.24194317,0.23669504 5 | 3,0.2238661,0.22388461,0.21928357 6 | 4,0.22812894,0.22815062,0.22410771 7 | 5,0.22398795,0.22402358,0.22023973 8 | 6,0.22606015,0.22609216,0.22241953 9 | 7,0.22955626,0.22957715,0.2261971 10 | 8,0.23346025,0.23349406,0.230283 11 | 9,0.23933677,0.23937275,0.23639005 12 | 10,0.24947925,0.2495169,0.24674372 13 | 11,0.25879192,0.25879982,0.25623834 14 | 12,0.26840612,0.2684224,0.2659429 15 | 13,0.28223696,0.2822432,0.27990422 16 | 14,0.3007411,0.30081397,0.298456 17 | 
15,0.32065493,0.32073346,0.31820792 18 | 16,0.3489667,0.34909493,0.34612358 19 | 17,0.37499505,0.37513632,0.37153322 20 | 18,0.39365283,0.3937659,0.3894278 21 | 19,0.3985198,0.39858896,0.39375183 22 | 20,0.40377426,0.4038127,0.3987301 23 | 21,0.4162669,0.41631454,0.41127917 24 | 22,0.4385093,0.43853307,0.43359485 25 | 23,0.50211877,0.5021498,0.49820283 26 | 24,0.6450441,0.6450727,0.64176905 27 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation and UNC NLP. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /rouge/tokenize_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """Tests for tokenize.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | from absl.testing import absltest 24 | from rouge import tokenize 25 | 26 | 27 | class TokenizeTest(absltest.TestCase): 28 | 29 | def test_give_me_a_name(self): 30 | self.assertEqual(['one', 'two', 'three'], 31 | tokenize.tokenize('one Two three', None)) 32 | self.assertEqual(['one', 'two', 'three'], 33 | tokenize.tokenize('one\n Two \nthree', None)) 34 | 35 | 36 | if __name__ == '__main__': 37 | absltest.main() 38 | -------------------------------------------------------------------------------- /rouge/test_util.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Test utils for ROUGE.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | _TESTDATA_PREFIX = os.path.join(os.path.dirname(__file__), "testdata") 25 | 26 | TARGETS_FILE = os.path.join(_TESTDATA_PREFIX, "target.txt") 27 | 28 | PREDICTIONS_FILE = os.path.join(_TESTDATA_PREFIX, "prediction.txt") 29 | 30 | LARGE_TARGETS_FILE = os.path.join(_TESTDATA_PREFIX, "target_large.txt") 31 | 32 | LARGE_PREDICTIONS_FILE = os.path.join(_TESTDATA_PREFIX, "prediction_large.txt") 33 | 34 | DELIMITED_FILE = os.path.join(_TESTDATA_PREFIX, "delimited.txt") 35 | 36 | PYROUGE_DIR = os.path.join(_TESTDATA_PREFIX, "pyrouge_files") 37 | 38 | 39 | def get_text(fname): 40 | with open(fname) as f: 41 | return f.read() 42 | -------------------------------------------------------------------------------- /rouge/oss/oss_release.sh: -------------------------------------------------------------------------------- 1 | # Copyright 2020 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | #!/bin/bash 16 | 17 | set -v # print commands as they're executed 18 | set -e # fail and exit on any command erroring 19 | 20 | GIT_COMMIT_ID=${1:-""} 21 | [[ -z $GIT_COMMIT_ID ]] && echo "Must provide a commit" && exit 1 22 | 23 | TMP_DIR=$(mktemp -d) 24 | pushd $TMP_DIR 25 | 26 | echo "Cloning google-research and checking out commit $GIT_COMMIT_ID" 27 | git clone https://github.com/google-research/google-research 28 | cd google-research/rouge 29 | git checkout $GIT_COMMIT_ID 30 | sed -i 's/from rouge/from rouge_score/' *.py 31 | 32 | python -m pip install wheel twine pyopenssl 33 | 34 | # Build the distribution 35 | echo "Building distribution" 36 | python setup.py sdist 37 | python setup.py bdist_wheel --universal 38 | 39 | # Publish to PyPI 40 | echo "Publishing to PyPI" 41 | twine upload dist/* 42 | 43 | # Cleanup 44 | popd 45 | rm -rf $TMP_DIR 46 | -------------------------------------------------------------------------------- /rouge/setup.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import setuptools 17 | 18 | with open("README.md", "r") as fh: 19 | long_description = fh.read() 20 | 21 | setuptools.setup( 22 | name="rouge_score", 23 | version="0.0.4", 24 | author="Google LLC", 25 | author_email="no-reply@google.com", 26 | description="Pure python implementation of ROUGE-1.5.5.", 27 | long_description=long_description, 28 | long_description_content_type="text/markdown", 29 | url="https://github.com/google-research/google-research/tree/master/rouge", 30 | packages=["rouge_score"], 31 | package_dir={"rouge_score": ""}, 32 | classifiers=[ 33 | "Programming Language :: Python :: 3", 34 | "License :: OSI Approved :: Apache Software License", 35 | "Operating System :: OS Independent", 36 | ], 37 | install_requires=[ 38 | "absl-py", 39 | "nltk", 40 | "numpy", 41 | "six>=1.14.0", 42 | ], 43 | python_requires=">=2.7", 44 | ) 45 | -------------------------------------------------------------------------------- /rouge/tokenize.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """A library for tokenizing text.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import re 24 | import six 25 | 26 | 27 | def tokenize(text, stemmer): 28 | """Tokenize input text into a list of tokens. 29 | 30 | This approach aims to replicate the approach taken by Chin-Yew Lin in 31 | the original ROUGE implementation. 32 | 33 | Args: 34 | text: A text blob to tokenize. 35 | stemmer: An optional stemmer. 36 | 37 | Returns: 38 | A list of string tokens extracted from input text. 39 | """ 40 | 41 | # Convert everything to lowercase. 42 | text = text.lower() 43 | # Replace any non-alpha-numeric characters with spaces. 44 | text = re.sub(r"[^a-z0-9]+", " ", six.ensure_str(text)) 45 | 46 | tokens = re.split(r"\s+", text) 47 | if stemmer: 48 | # Only stem words more than 3 characters long. 49 | tokens = [stemmer.stem(x) if len(x) > 3 else x for x in tokens] 50 | 51 | # One final check to drop any empty or invalid tokens. 52 | tokens = [x for x in tokens if re.match(r"^[a-z0-9]+$", six.ensure_str(x))] 53 | 54 | return tokens 55 | 56 | 57 | if __name__ == '__main__': 58 | print(tokenize("I love you. 
I love", None)) 59 | -------------------------------------------------------------------------------- /rouge/create_pyrouge_files.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python3 17 | """For creating files from {target,prediction}.txt that can be processed 18 | by pyrouge to compare with scores in scoring_test.py. 19 | 20 | create_pyrouge_files -- --testdata_dir=`pwd`/testdata 21 | 22 | # testConfidenceIntervalsAgainstRouge155WithStemming result 23 | pyrouge_evaluate_plain_text_files \ 24 | -s /tmp/lkj -sfp "prediction.(.*).txt" \ 25 | -m /tmp/lkj -mfp target.#ID#.txt 26 | 27 | pyrouge_evaluate_plain_text_files \ 28 | -s /tmp/lkj -sfp "prediction_multi.(.*).txt" \ 29 | -m /tmp/lkj -mfp target_multi.#ID#.txt 30 | """ 31 | 32 | from __future__ import absolute_import 33 | from __future__ import division 34 | 35 | from __future__ import print_function 36 | 37 | import os 38 | 39 | from absl import app 40 | from absl import flags 41 | 42 | FLAGS = flags.FLAGS 43 | 44 | flags.DEFINE_string('testdata_dir', '', 'testdata path') 45 | flags.DEFINE_string('output', '/tmp/lkj', 'testdata path') 46 | 47 | 48 | def main(argv): 49 | if len(argv) > 1: 50 | raise app.UsageError('Too many command-line arguments.') 51 | 52 | # One line per target 53 | with open(os.path.join(FLAGS.testdata_dir, 'target_large.txt')) as f: 54 | targets = f.readlines() 55 | with open(os.path.join(FLAGS.testdata_dir, 'prediction_large.txt')) as f: 56 | predictions = f.readlines() 57 | 58 | def write_files(prefix, items): 59 | for i, t in enumerate(items): 60 | out = '%s.%d.txt' % (prefix, i) 61 | with open(os.path.join(FLAGS.output, out), 'w') as f: 62 | f.write(t) 63 | write_files('target', targets) 64 | write_files('prediction', predictions) 65 | 66 | # Delete this block 67 | def write_files2(prefix, items): 68 | index = 0 69 | f = None 70 | for i, t in enumerate(items): 71 | # Write 4 lines per file 72 | if i % 4 == 0: 73 | if f: 74 | f.close() 75 | f = open( 76 | os.path.join(FLAGS.output, '%s.%d.txt' % (prefix, index)), 77 | 'w') 78 | index += 1 79 | f.write(t) 80 | f.close() 81 | write_files2('target_multi', targets) 82 | write_files2('prediction_multi', predictions) 83 | 84 | 85 | if __name__ == '__main__': 86 | app.run(main) 87 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | from rouge import rouge_scorer 4 | from rouge import scoring 5 | from bert_score import score 6 | import transformers 7 | transformers.tokenization_utils.logger.setLevel(logging.ERROR) 8 | transformers.configuration_utils.logger.setLevel(logging.ERROR) 9 | transformers.modeling_utils.logger.setLevel(logging.ERROR) 10 | 11 | logger = logging.getLogger(__name__) 12 | 
13 | 14 | def rouge(targets, predictions, score_keys=None, use_stemmer=True): 15 | """Computes ROUGE scores with bootstrap confidence intervals. 16 | Args: 17 | targets: list of strings 18 | predictions: list of strings 19 | score_keys: list of strings with the keys to compute. 20 | use_stemmer: whether to apply Porter stemming before scoring. 21 | Returns: 22 | dict with score_key: rouge score across all targets and predictions 23 | """ 24 | 25 | if score_keys is None: 26 | score_keys = ["rouge1", "rouge2", "rougeL", "rougeLsum"] 27 | scorer = rouge_scorer.RougeScorer(score_keys, use_stemmer=use_stemmer) 28 | aggregator = scoring.BootstrapAggregator() 29 | 30 | for prediction, target in zip(predictions, targets): 31 | aggregator.add_scores(scorer.score(target=target, prediction=prediction)) 32 | result = aggregator.aggregate() 33 | res_str = '\n'.join(["%s = %.2f, 95%% confidence [%.2f, %.2f]" % 34 | (key, result[key].mid.fmeasure * 100, result[key].low.fmeasure * 100, 35 | result[key].high.fmeasure * 100,) for key in score_keys]) 36 | 37 | logger.info(res_str) 38 | return {key: result[key].mid.fmeasure * 100 for key in score_keys}, res_str 39 | 40 | 41 | def extScore(sources, predictions, use_stemmer=True): 42 | """Computes an abstractiveness score: 1 - the mean ROUGE-Lsum precision of 43 | each prediction against its source (higher = less copied from the source). 44 | Args: 45 | sources: list of strings 46 | predictions: list of strings 47 | Returns: 48 | dict with the score across all sources and predictions 49 | """ 50 | scorer = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=use_stemmer) 51 | precisions = [] 52 | for prediction, source in zip(predictions, sources): 53 | res = scorer.score(target=source, prediction=prediction) 54 | precisions.append(res["rougeLsum"].precision) 55 | precision = 1. - np.mean(precisions) 56 | res = {"ext_rouge2_prec": precision*100}  # key name kept for compatibility; the value uses ROUGE-Lsum precision 57 | res_str = '\n'.join(["%s = %.2f" % 58 | (key, res[key]) for key in res]) 59 | 60 | logger.info(res_str) 61 | return res, res_str 62 | 63 | 64 | def bertScore(refs, cands, rescale_with_baseline=True): 65 | """Computes the mean BERTScore F1 (optionally rescaled with a baseline) of candidates against references.""" 66 | P, R, F1 = score(cands, refs, lang="en", rescale_with_baseline=rescale_with_baseline) 67 | res = {"bertScore": F1.mean().item()*100} 68 | res_str = '\n'.join(["%s = %.2f" % 69 | (key, res[key]) for key in res]) 70 | logger.info(res_str) 71 | return res, res_str 72 | -------------------------------------------------------------------------------- /rouge/rouge.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | r"""Main routine to calculate ROUGE scores across text files. 17 | 18 | Designed to replicate scores computed by the ROUGE perl implementation as 19 | closely as possible. 20 | 21 | Output is a text file in CSV format.
22 | 23 | Sample usage: 24 | 25 | python -m rouge.rouge --rouge_types=rouge1,rouge2,rougeL \ 26 | --target_filepattern=*.targets \ 27 | --prediction_filepattern=*.decodes \ 28 | --output_filename=scores.csv \ 29 | --use_stemmer 30 | 31 | Which is equivalent to calling the perl ROUGE script as: 32 | 33 | ROUGE-1.5.5.pl -m -e ./data -n 2 -a /tmp/rouge/settings.xml 34 | 35 | Where settings.xml provides target and decode text. 36 | """ 37 | 38 | from __future__ import absolute_import 39 | from __future__ import division 40 | from __future__ import print_function 41 | 42 | from absl import app 43 | from absl import flags 44 | from rouge import io 45 | from rouge import rouge_scorer 46 | from rouge import scoring 47 | 48 | flags.DEFINE_string("target_filepattern", None, 49 | "Files containing target text.") 50 | flags.DEFINE_string("prediction_filepattern", None, 51 | "Files containing prediction text.") 52 | flags.DEFINE_string("output_filename", None, 53 | "File in which to write calculated ROUGE scores as a CSV.") 54 | flags.DEFINE_string("delimiter", "\n", 55 | "Record delimiter in files.") 56 | flags.DEFINE_list("rouge_types", ["rouge1", "rouge2", "rougeL"], 57 | "List of ROUGE types to calculate.") 58 | flags.DEFINE_boolean("use_stemmer", False, 59 | "Whether to use Porter stemmer to remove common suffixes.") 60 | flags.DEFINE_boolean("aggregate", True, 61 | "Write aggregates if this is set to True") 62 | 63 | FLAGS = flags.FLAGS 64 | 65 | 66 | def main(argv): 67 | if len(argv) > 1: 68 | raise app.UsageError("Too many command-line arguments.") 69 | scorer = rouge_scorer.RougeScorer(FLAGS.rouge_types, FLAGS.use_stemmer) 70 | aggregator = scoring.BootstrapAggregator() if FLAGS.aggregate else None 71 | io.compute_scores_and_write_to_csv( 72 | FLAGS.target_filepattern, 73 | FLAGS.prediction_filepattern, 74 | FLAGS.output_filename, 75 | scorer, 76 | aggregator, 77 | delimiter=FLAGS.delimiter) 78 | 79 | 80 | if __name__ == "__main__": 81 | flags.mark_flag_as_required("target_filepattern") 82 | flags.mark_flag_as_required("prediction_filepattern") 83 | flags.mark_flag_as_required("output_filename") 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /rouge/README.md: -------------------------------------------------------------------------------- 1 | From https://github.com/google-research/google-research/tree/master/rouge 2 | 3 | The original code assumes sentences are separated by newlines. We changed it 4 | to split sentences via nltk.sent_tokenize. 5 | 6 | # Python ROUGE Implementation 7 | 8 | ## Overview 9 | 10 | This is a native python implementation of ROUGE, designed to replicate results 11 | from the original perl package. 12 | 13 | ROUGE was originally introduced in the paper: 14 | 15 | Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In 16 | Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004), 17 | Barcelona, Spain, July 25 - 26, 2004. 18 | 19 | ## ROUGE for Python 20 | 21 | There are ROUGE implementations available for Python; however, some are not 22 | native python due to their dependency on the perl script, and others produce 23 | results that differ from the original implementation. This makes it 24 | difficult to compare directly with known results. 25 | 26 | This package is designed to replicate perl results. A minimal call is sketched below.
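As a quick, illustrative sketch (hypothetical strings, using the same `rouge_scorer` API as the pip example further down; note that this fork splits sentences with nltk.sent_tokenize rather than newlines):

```python
from rouge_score import rouge_scorer

# rougeL ignores sentence boundaries; rougeLsum scores sentence-by-sentence
# and combines the per-sentence LCS results (union-LCS).
scorer = rouge_scorer.RougeScorer(['rougeL', 'rougeLsum'], use_stemmer=True)
scores = scorer.score('The cat sat. It was on the mat.',
                      'The cat sat on the mat.')
print(scores['rougeL'].fmeasure, scores['rougeLsum'].fmeasure)
```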
It implements: 27 | 28 | * ROUGE-N (N-gram) scoring 29 | * ROUGE-L (Longest Common Subsequence) scoring 30 | * Text normalization 31 | * Bootstrap resampling for confidence interval calculation 32 | * Optional Porter stemming to remove plurals and word suffixes such as ing, 33 | ion, and ment. 34 | 35 | Note that not all options provided by the original perl ROUGE script are 36 | supported, but the subset of options that are implemented should replicate the 37 | original functionality. 38 | 39 | ## Stopword removal 40 | 41 | The original ROUGE perl script implemented optional stopword removal (using the 42 | -s parameter). However, there were ~600 stopwords used by ROUGE, borrowed from 43 | another now defunct package. This word list contained many words that may not be 44 | suited to some tasks, such as day and month names and numbers. It also had no 45 | clear license for redistribution. Since we are unable to replicate this 46 | functionality precisely, we do not include stopword removal. 47 | 48 | ## Two flavors of ROUGE-L 49 | In the ROUGE paper, two flavors of ROUGE-L are described: 50 | 51 | 1. sentence-level: Compute the longest common subsequence (LCS) between two pieces of 52 | text. Newlines are ignored. This is called `rougeL` in this package. 53 | 2. summary-level: Newlines in the text are interpreted as sentence boundaries, 54 | the LCS is computed between each pair of reference and candidate sentences, 55 | and the results are combined via union-LCS. This is called `rougeLsum` in this 56 | package. This is the ROUGE-L reported in *[Get To The Point: Summarization with 57 | Pointer-Generator Networks](https://arxiv.org/abs/1704.04368)*, for example. 58 | 59 | ## How to run 60 | 61 | This package compares target files (containing one example per line) with 62 | prediction files in the same format. It can be launched as follows (from 63 | google-research/): 64 | 65 | ```shell 66 | python -m rouge.rouge \ 67 | --target_filepattern=*.targets \ 68 | --prediction_filepattern=*.decodes \ 69 | --output_filename=scores.csv \ 70 | --use_stemmer=true 71 | ``` 72 | 73 | ## Using pip 74 | ``` 75 | pip install -r rouge/requirements.txt 76 | pip install rouge-score 77 | ``` 78 | 79 | Then in python: 80 | 81 | ```python 82 | from rouge_score import rouge_scorer 83 | 84 | scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True) 85 | scores = scorer.score('The quick brown fox jumps over the lazy dog', 86 | 'The quick brown dog jumps on the log.') 87 | ``` 88 | 89 | ## License 90 | 91 | Licensed under the 92 | [Apache 2.0](https://github.com/google-research/google-research/blob/master/LICENSE) 93 | License. 94 | 95 | ## Disclaimer 96 | 97 | This is not an official Google product. 98 | -------------------------------------------------------------------------------- /rouge/io_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Tests for rouge input/output library.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import tempfile 23 | 24 | from absl.testing import absltest 25 | from rouge import io 26 | from rouge import rouge_scorer 27 | from rouge import scoring 28 | from rouge import test_util 29 | 30 | 31 | class IoTest(absltest.TestCase): 32 | 33 | def testProducesValidOutput(self): 34 | with tempfile.NamedTemporaryFile() as output_file: 35 | output_filename = output_file.name 36 | scorer = rouge_scorer.RougeScorer(["rouge1"], False) 37 | io.compute_scores_and_write_to_csv(test_util.TARGETS_FILE, 38 | test_util.PREDICTIONS_FILE, 39 | output_filename, scorer, 40 | scoring.BootstrapAggregator()) 41 | with open(output_filename) as f: 42 | csv_lines = f.readlines() 43 | output_types = tuple((line.split(",")[0] for line in csv_lines)) 44 | self.assertEqual(output_types[0], "score_type") 45 | self.assertSameElements(output_types[1:], 46 | ["rouge1-P", "rouge1-R", "rouge1-F"]) 47 | 48 | def testUnAggregated(self): 49 | with tempfile.NamedTemporaryFile() as output_file: 50 | output_filename = output_file.name 51 | scorer = rouge_scorer.RougeScorer(["rouge1"], False) 52 | io.compute_scores_and_write_to_csv(test_util.TARGETS_FILE, 53 | test_util.PREDICTIONS_FILE, 54 | output_filename, scorer, None) 55 | with open(output_filename) as f: 56 | csv_lines = f.readlines() 57 | ids = tuple((line.split(",")[0] for line in csv_lines)) 58 | self.assertEqual(ids[0], "id") 59 | self.assertLen(csv_lines, 3) 60 | 61 | def testDelimitedFile(self): 62 | with tempfile.NamedTemporaryFile() as output_file: 63 | output_filename = output_file.name 64 | scorer = rouge_scorer.RougeScorer(["rouge1"], False) 65 | io.compute_scores_and_write_to_csv( 66 | test_util.DELIMITED_FILE, 67 | test_util.DELIMITED_FILE, 68 | output_filename, 69 | scorer, 70 | None, 71 | delimiter=":") 72 | with open(output_filename) as f: 73 | csv_lines = f.readlines() 74 | ids = tuple((line.split(",")[0] for line in csv_lines)) 75 | self.assertEqual(ids[0], "id") 76 | self.assertLen(csv_lines, 5) 77 | 78 | def testAssertsOnInvalidInputFiles(self): 79 | scorer = rouge_scorer.RougeScorer(["rouge1"], False) 80 | with self.assertRaises(ValueError): 81 | io.compute_scores_and_write_to_csv("invalid*", "invalid*", "invalid", 82 | scorer, scoring.BootstrapAggregator()) 83 | 84 | def testAssertsOnInvalidRougeTypes(self): 85 | scorer = rouge_scorer.RougeScorer(["rougex"], False) 86 | with self.assertRaises(ValueError): 87 | io.compute_scores_and_write_to_csv(test_util.TARGETS_FILE, 88 | test_util.PREDICTIONS_FILE, "", scorer, 89 | scoring.BootstrapAggregator()) 90 | 91 | 92 | if __name__ == "__main__": 93 | absltest.main() 94 | -------------------------------------------------------------------------------- /Avocado/standardize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import ujson as json 4 | import nltk 5 | nltk.download('punkt') 6 | 7 | 8 | def prepare_one(lines): 9 | thread_id = None 10 | subject = None 11 | numemail = 0 12 | froms = [''] * 10 13 | tos = [''] * 10 14 | emails = [''] * 10 15 | i = 0 16 | while i < len(lines): 17 | line = lines[i].strip() 18 | if line: 19 | if line.startswith("THREAD"): 20 | thread_id = int(line.split(' ')[1]) 21 | elif
line.startswith("subject:"): 22 | subject = line.replace("subject:", "").strip() 23 | # if "docomo" in subject: 24 | # print(thread_id) 25 | # exit() 26 | elif line.startswith("# of emails:"): 27 | numemail = int(line.replace("# of emails:", "").strip()) 28 | elif line.startswith("Email"): 29 | index = int(line.split(' ')[1].strip()) 30 | elif line.startswith("From:"): 31 | froms[index] = line.replace('From:', '').strip() 32 | elif line.startswith("To:"): 33 | tos[index] = line.replace('To:', '').strip() 34 | elif line.startswith("Content:"): 35 | i += 1 36 | line = lines[i].strip() 37 | content = [] 38 | while i < len(lines) and (not line.startswith("===")): 39 | if line: 40 | content.append(line) 41 | i += 1 42 | line = lines[i].strip() 43 | emails[index] = ' '.join(content) 44 | i += 1 45 | return thread_id, subject, numemail, froms, tos, emails 46 | 47 | 48 | def standardize(): 49 | with open("summaries/EmailSum_data.json", 'r') as f: 50 | data = json.load(f)["data"] 51 | train = data["train"] 52 | test = data["test"] 53 | with open("summaries/one_more_reference.json", 'r') as f: 54 | test1 = json.load(f)["test"] 55 | np.random.seed(42) 56 | np.random.shuffle(train) 57 | dev = train[:249] 58 | train = train[249:] 59 | 60 | if not os.path.exists("exp_data"): 61 | os.mkdir("exp_data") 62 | if not os.path.exists("exp_data/data_email_short"): 63 | os.mkdir("exp_data/data_email_short") 64 | if not os.path.exists("exp_data/data_email_long"): 65 | os.mkdir("exp_data/data_email_long") 66 | 67 | for dset, sset in [(train, "train"), (dev, "dev"), (test, "test")]: 68 | short_source, short_target, short_id = [], [], [] 69 | long_source, long_target, long_id = [], [], [] 70 | short_target1, long_target1 = [], [] 71 | for thread in dset: 72 | thread_id = thread["thread_id"] 73 | numemail = thread["num_of_emails"] 74 | short_summary = thread["short_summary"]["content"] 75 | long_summary = thread["long_summary"]["content"] 76 | with open(f"Avocado_threads/{numemail}/{thread_id}", 'r') as f: 77 | lines = f.readlines() 78 | _, subject, _, froms, tos, emails = prepare_one(lines) 79 | source = '|||'.join([f"{f}: {e}" for f, e in zip(["Subject"] + froms, [subject] + emails) if f and e]) 80 | short_source.append(source) 81 | short_target.append(short_summary) 82 | short_id.append(thread_id) 83 | long_source.append(source) 84 | long_target.append(long_summary) 85 | long_id.append(thread_id) 86 | if sset == "test": 87 | short_target1.append(test1[thread_id]["short_summary"]["content"]) 88 | long_target1.append(test1[thread_id]["long_summary"]["content"]) 89 | with open(f"exp_data/data_email_short/{sset}.source", 'w') as f: 90 | f.write('\n'.join(short_source)) 91 | with open(f"exp_data/data_email_short/{sset}.target", 'w') as f: 92 | f.write('\n'.join(short_target)) 93 | with open(f"exp_data/data_email_short/{sset}.id", 'w') as f: 94 | f.write('\n'.join(short_id)) 95 | with open(f"exp_data/data_email_long/{sset}.source", 'w') as f: 96 | f.write('\n'.join(long_source)) 97 | with open(f"exp_data/data_email_long/{sset}.target", 'w') as f: 98 | f.write('\n'.join(long_target)) 99 | with open(f"exp_data/data_email_long/{sset}.id", 'w') as f: 100 | f.write('\n'.join(long_id)) 101 | if sset == "test": 102 | with open(f"exp_data/data_email_short/{sset}.target1", 'w') as f: 103 | f.write('\n'.join(short_target1)) 104 | with open(f"exp_data/data_email_long/{sset}.target1", 'w') as f: 105 | f.write('\n'.join(long_target1)) 106 | 107 | 108 | if __name__ == '__main__': 109 | standardize() 
-------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import ujson as json 4 | 5 | parser = argparse.ArgumentParser() 6 | 7 | parser.add_argument( 8 | "--task", 9 | default=None, 10 | type=str, 11 | required=True, 12 | help="Task name", 13 | ) 14 | parser.add_argument( 15 | "--data_dir", 16 | default=None, 17 | type=str, 18 | required=True, 19 | help="data directory", 20 | ) 21 | parser.add_argument( 22 | "--model_type", 23 | default="t5", 24 | type=str, 25 | help="Model type name", 26 | ) 27 | parser.add_argument( 28 | "--model", 29 | default="t5-base", 30 | type=str, 31 | help="Model name", 32 | ) 33 | parser.add_argument( 34 | "--seed", 35 | default=42, 36 | type=int, 37 | help="Seed", 38 | ) 39 | parser.add_argument( 40 | "--min_output_length", 41 | default=5, 42 | type=int, 43 | help="minimum output length during decoding", 44 | ) 45 | parser.add_argument( 46 | "--max_turn_length", 47 | default=10, 48 | type=int, 49 | help="maximum turn length", 50 | ) 51 | parser.add_argument( 52 | "--max_per_turn_length", default=1, type=int, 53 | help="The maximum length of input per turn.", 54 | ) 55 | parser.add_argument( 56 | "--max_input_length", 57 | default=512, 58 | type=int, 59 | help="maximum input length at each step", 60 | ) 61 | parser.add_argument( 62 | "--max_output_length", # this is not the same as the max_output_length in t5.py 63 | default=56, 64 | type=int, 65 | help="maximum output length at each step", 66 | ) 67 | parser.add_argument( 68 | "--dec_max_output_length", # this is the same as the max_output_length in t5.py 69 | default=200, 70 | type=int, 71 | help="maximum output length during decoding", 72 | ) 73 | parser.add_argument( 74 | "--train_batch_size", 75 | default=4, 76 | type=int, 77 | help="Train batch size", 78 | ) 79 | parser.add_argument( 80 | "--eval_batch_size", 81 | default=4, 82 | type=int, 83 | help="Eval batch size", 84 | ) 85 | parser.add_argument( 86 | "--gradient_accumulation_steps", 87 | default=32, 88 | type=int, 89 | help="Gradient accumulation steps", 90 | ) 91 | parser.add_argument( 92 | "--num_train_epochs", 93 | default=70, 94 | type=int, 95 | help="Number of training epochs", 96 | ) 97 | parser.add_argument("--learning_rate", default=5e-4, type=float, help="The initial learning rate for Adam.") 98 | parser.add_argument("--memory_type", default="base", type=str, help="The type of model, base or ht5") 99 | parser.add_argument("--scheduler", default="linear", type=str, help="Scheduler name") 100 | parser.add_argument("--optimizer", default="adam", type=str, help="The optimizer to use") 101 | parser.add_argument("--warmup_steps", default=0., type=float, help="Linear warmup over warmup_steps.") 102 | parser.add_argument("--zero_shot", action="store_true", help="Do not update parameters (zero-shot evaluation).") 103 | parser.add_argument("--from_pretrain", action="store_true", help="Load from pretrained model") 104 | parser.add_argument("--pretrained_checkpoint", default=None, type=str, help="Path to the pretrained checkpoint") 105 | parser.add_argument("--from_scratch", action="store_true", help="Train from scratch (do not load pretrained weights)") 106 | parser.add_argument("--test_only", action="store_true", help="Only run testing") 107 | parser.add_argument("--two_ref", action="store_true", help="Evaluate with two references") 108 | args = parser.parse_args() 109 | 110 | data_dir = args.data_dir 111 | output_dir = 
f"train/{args.model_type}_{args.task}_seed{args.seed}_{args.memory_type}_{args.model}" 112 | 113 | os.system(f"python3 t5.py --model_type {args.model_type} --model_name_or_path {args.model} " 114 | f"{f'--from_pretrain --pretrained_checkpoint {args.pretrained_checkpoint}' if args.from_pretrain else ''} " 115 | f"{'--from_scratch' if args.from_scratch else ''} " 116 | f"--task_name {args.task} {'' if args.test_only else '--do_train --do_eval'} --do_test " 117 | f"--data_dir {data_dir} " 118 | f"--data_file {data_dir}/train_{args.max_input_length}_{args.max_output_length}.t5 " 119 | f"--eval_file {data_dir}/train.target --source_file {data_dir}/train.source " 120 | f"--dev_data_file {data_dir}/dev_{args.max_input_length}_{args.max_output_length}.t5 " 121 | f"--dev_eval_file {data_dir}/dev.target --dev_source_file {data_dir}/dev.source " 122 | f"--test_data_file {data_dir}/test_{args.max_input_length}_{args.max_output_length}.t5 " 123 | f"--test_eval_file {data_dir}/test.target --test_source_file {data_dir}/test.source " 124 | f"{f'--test_eval_file1 {data_dir}/test.target1' if args.two_ref else ''} " 125 | f"--cache_dir train/cache --max_input_length {args.max_input_length} --do_lower_case " 126 | f"--min_output_length {args.min_output_length} --max_output_length {args.dec_max_output_length} " 127 | f"--max_turn_length {args.max_turn_length} " 128 | f"--per_gpu_eval_batch_size={args.eval_batch_size} --per_gpu_train_batch_size={args.train_batch_size} " 129 | f"--learning_rate {args.learning_rate} --warmup_steps {args.warmup_steps} " 130 | f"--scheduler {args.scheduler} --optimizer {args.optimizer} " 131 | f"--gradient_accumulation_steps {args.gradient_accumulation_steps} --num_train_epochs {args.num_train_epochs} " 132 | f"--output_dir {output_dir} --memory_type {args.memory_type} " 133 | f"--overwrite_output_dir --evaluate_during_training --seed {args.seed} ") 134 | -------------------------------------------------------------------------------- /rouge/scoring.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Library for scoring and evaluation of text samples. 18 | 19 | Aggregation functions use bootstrap resampling to compute confidence intervals 20 | as per the original ROUGE perl implementation. 
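For example, with the default confidence_interval of 0.95, the reported
low/mid/high values are the 2.5th percentile, the mean, and the 97.5th
percentile of the bootstrapped sample means.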
21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import abc 28 | import collections 29 | 30 | import numpy as np 31 | import six 32 | from six.moves import range 33 | 34 | 35 | class Score( 36 | collections.namedtuple("Score", ["precision", "recall", "fmeasure"])): 37 | """Tuple containing precision, recall, and f-measure values.""" 38 | 39 | 40 | class BaseScorer(object): 41 | """Base class for Scorer objects.""" 42 | 43 | @abc.abstractmethod 44 | def score(self, target, prediction): 45 | """Calculates score between the target and prediction. 46 | 47 | Args: 48 | target: Text containing the target (ground truth) text. 49 | prediction: Text containing the predicted text. 50 | 51 | Returns: 52 | A dict mapping each score_type (string) to Score object. 53 | """ 54 | 55 | 56 | class AggregateScore( 57 | collections.namedtuple("AggregateScore", ["low", "mid", "high"])): 58 | """Tuple containing confidence intervals for scores.""" 59 | 60 | 61 | class BootstrapAggregator(object): 62 | """Aggregates scores to provide confidence intervals. 63 | 64 | Sample usage: 65 | scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL']) 66 | aggregator = BootstrapAggregator() 67 | aggregator.add_scores(scorer.score("one two three", "one two")) 68 | aggregator.add_scores(scorer.score("one two five six", "seven eight")) 69 | result = aggregator.aggregate() 70 | print(result) 71 | {'rougeL': AggregateScore( 72 | low=Score(precision=0.0, recall=0.0, fmeasure=0.0), 73 | mid=Score(precision=0.5, recall=0.33, fmeasure=0.40), 74 | high=Score(precision=1.0, recall=0.66, fmeasure=0.80)), 75 | 'rouge1': AggregateScore( 76 | low=Score(precision=0.0, recall=0.0, fmeasure=0.0), 77 | mid=Score(precision=0.5, recall=0.33, fmeasure=0.40), 78 | high=Score(precision=1.0, recall=0.66, fmeasure=0.80))} 79 | """ 80 | 81 | def __init__(self, confidence_interval=0.95, n_samples=1000): 82 | """Initializes a BootstrapAggregator object. 83 | 84 | Args: 85 | confidence_interval: Confidence interval to compute on the mean as a 86 | decimal. 87 | n_samples: Number of samples to use for bootstrap resampling. 88 | 89 | Raises: 90 | ValueError: If invalid argument is given. 91 | """ 92 | 93 | if confidence_interval < 0 or confidence_interval > 1: 94 | raise ValueError("confidence_interval must be in range [0, 1]") 95 | if n_samples <= 0: 96 | raise ValueError("n_samples must be positive") 97 | 98 | self._n_samples = n_samples 99 | self._confidence_interval = confidence_interval 100 | self._scores = collections.defaultdict(list) 101 | 102 | def add_scores(self, scores): 103 | """Adds a sample for future aggregation. 104 | 105 | Args: 106 | scores: Dict mapping score_type strings to a namedtuple object/class 107 | representing a score. 108 | """ 109 | 110 | for score_type, score in six.iteritems(scores): 111 | self._scores[score_type].append(score) 112 | 113 | def aggregate(self): 114 | """Aggregates scores previously added using add_scores. 115 | 116 | Returns: 117 | A dict mapping score_type to AggregateScore objects. 118 | """ 119 | 120 | result = {} 121 | for score_type, scores in six.iteritems(self._scores): 122 | # Stack scores into a 2-d matrix of (sample, measure). 123 | score_matrix = np.vstack(tuple(scores)) 124 | # Percentiles are returned as (interval, measure). 125 | percentiles = self._bootstrap_resample(score_matrix) 126 | # Extract the three intervals (low, mid, high). 
127 | intervals = tuple( 128 | (scores[0].__class__(*percentiles[j, :]) for j in range(3))) 129 | result[score_type] = AggregateScore( 130 | low=intervals[0], mid=intervals[1], high=intervals[2]) 131 | return result 132 | 133 | def _bootstrap_resample(self, matrix): 134 | """Performs bootstrap resampling on a matrix of scores. 135 | 136 | Args: 137 | matrix: A 2-d matrix of (sample, measure). 138 | 139 | Returns: 140 | A 2-d matrix of (bounds, measure). There are three bounds: low (row 0), 141 | mid (row 1) and high (row 2). Mid is always the mean, while low and high 142 | bounds are specified by self._confidence_interval (which defaults to 0.95 143 | meaning it will return the 2.5th and 97.5th percentiles for a 95% 144 | confidence interval on the mean). 145 | """ 146 | 147 | # Matrix of (bootstrap sample, measure). 148 | sample_mean = np.zeros((self._n_samples, matrix.shape[1])) 149 | for i in range(self._n_samples): 150 | sample_idx = np.random.choice( 151 | np.arange(matrix.shape[0]), size=matrix.shape[0]) 152 | sample = matrix[sample_idx, :] 153 | sample_mean[i, :] = np.mean(sample, axis=0) 154 | 155 | # Take percentiles on the estimate of the mean using bootstrap samples. 156 | # Final result is a (bounds, measure) matrix. 157 | percentile_delta = (1 - self._confidence_interval) / 2 158 | q = 100 * np.array([percentile_delta, 0.5, 1 - percentile_delta]) 159 | return np.percentile(sample_mean, q, axis=0) 160 | 161 | 162 | def fmeasure(precision, recall): 163 | """Computes f-measure given precision and recall values.""" 164 | 165 | if precision + recall > 0: 166 | return 2 * precision * recall / (precision + recall) 167 | else: 168 | return 0.0 169 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from os import path 3 | import numpy as np 4 | import ujson as json 5 | from tqdm import tqdm 6 | 7 | from torch.utils.data import TensorDataset 8 | from transformers import T5Tokenizer 9 | import nltk 10 | nltk.data.path.append('./nltk_data') 11 | nltk.download('punkt', download_dir='./nltk_data') 12 | 13 | 14 | def process_file_t5(filename, tokenizer, lower=True): 15 | examples, eval = [], {} 16 | total = 0 17 | input_lens = [] 18 | output_lens = [] 19 | with open(f"{filename}.source", 'r') as fa, open(f"{filename}.target", 'r') as fs: 20 | while True: 21 | total += 1 22 | article = fa.readline() 23 | summary = fs.readline() 24 | if not article or not summary: 25 | break 26 | article = article.strip() 27 | summary = summary.strip() 28 | if lower: 29 | article = article.lower() 30 | summary = summary.lower() 31 | if not article or not summary: 32 | article = 'null' 33 | summary = 'null' 34 | inputs = article.split('|||') 35 | entities, input_tokens, speaker_tokens = {}, [], [] 36 | turns = [] 37 | # add prompt 38 | prompt = tokenizer.tokenize("summarize:") 39 | input_tokens.extend(prompt) 40 | turns.extend([0] * len(prompt)) 41 | for i, input in enumerate(inputs): 42 | input = tokenizer.tokenize(input.strip()) 43 | turns.extend([i + 1] * len(input)) 44 | input_tokens.extend(input) 45 | # summary 46 | output_tokens = tokenizer.tokenize(summary) 47 | input_lens.append(len(input_tokens)) 48 | output_lens.append(len(output_tokens)) 49 | example = {"input": input_tokens, "output": output_tokens, "turn": turns, "id": total} 50 | eval[str(total)] = (article, summary) 51 | examples.append(example) 52 | np.random.shuffle(examples) 53 | print("{} examples in 
total".format(len(examples))) 54 | print("max_input: {}, max_output: {}.".format(max(input_lens), max(output_lens))) 55 | print("avg_input: {}, avg_output: {}.".format(np.mean(input_lens), np.mean(output_lens))) 56 | return examples, eval 57 | 58 | 59 | def build_features_t5(examples, data_type, out_file, tokenizer, max_input_length, max_output_length): 60 | print("Processing {} examples...".format(data_type)) 61 | total = 0 62 | input_inputs, output_inputs, turn_inputs, ids = [], [], [], [] 63 | 64 | for example in tqdm(examples): 65 | total += 1 66 | 67 | # task input 68 | input_input = tokenizer.convert_tokens_to_ids(example["input"])[:max_input_length] 69 | input_inputs.append(input_input + [0] * (max_input_length - len(input_input))) 70 | 71 | # turn input 72 | turn_input = [turn for turn in example["turn"]][:max_input_length] 73 | turn_inputs.append(turn_input + [0] * (max_input_length - len(turn_input))) 74 | 75 | # task output 76 | output = tokenizer.convert_tokens_to_ids(example["output"]) 77 | output_input = [0] + output[:max_output_length - 2] + [tokenizer.eos_token_id] 78 | output_inputs.append(output_input + [0] * (max_output_length - len(output_input))) 79 | 80 | ids.append(int(example["id"])) 81 | 82 | input_inputs = torch.tensor(input_inputs, dtype=torch.long) 83 | turn_inputs = torch.tensor(turn_inputs, dtype=torch.long) 84 | output_inputs = torch.tensor(output_inputs, dtype=torch.long) 85 | ids = torch.tensor(ids, dtype=torch.long) 86 | 87 | dataset = TensorDataset(input_inputs, output_inputs, turn_inputs, ids) 88 | torch.save({"dataset": dataset}, out_file) 89 | print("Built {} instances of features in total".format(total)) 90 | 91 | 92 | def prepare_t5(tokenizer, data_dir, max_input_length, max_output_length, lower=True): 93 | train_file = f"{data_dir}/train" 94 | dev_file = f"{data_dir}/dev" 95 | test_file = f"{data_dir}/test" 96 | train_out = f"{data_dir}/train_{max_input_length}_{max_output_length}.t5" 97 | dev_out = f"{data_dir}/dev_{max_input_length}_{max_output_length}.t5" 98 | test_out = f"{data_dir}/test_{max_input_length}_{max_output_length}.t5" 99 | # process files 100 | if path.exists(train_file + '.source'): 101 | print(f"prepare {train_out}") 102 | train_examples, train_eval = process_file_t5(train_file, tokenizer, lower=lower) 103 | build_features_t5(train_examples, "train", train_out, tokenizer, max_input_length=max_input_length, 104 | max_output_length=max_output_length) 105 | if path.exists(dev_file + '.source'): 106 | print(f"prepare {dev_out}") 107 | dev_examples, dev_eval = process_file_t5(dev_file, tokenizer, lower=lower) 108 | build_features_t5(dev_examples, "dev", dev_out, tokenizer, max_input_length=max_input_length, 109 | max_output_length=max_output_length) 110 | if path.exists(test_file + '.source'): 111 | print(f"prepare {test_out}") 112 | test_examples, test_eval = process_file_t5(test_file, tokenizer, lower=lower) 113 | build_features_t5(test_examples, "test", test_out, tokenizer, max_input_length=max_input_length, 114 | max_output_length=max_output_length) 115 | 116 | 117 | if __name__ == '__main__': 118 | import argparse 119 | 120 | parser = argparse.ArgumentParser() 121 | 122 | parser.add_argument( 123 | "--data_dir", 124 | default=None, 125 | type=str, 126 | required=True, 127 | help="The input data dir. 
Should contain the *.source and *.target files for the task.", 128 | ) 129 | parser.add_argument( 130 | "--cache_dir", 131 | default="", 132 | type=str, 133 | help="Where do you want to store the pre-trained models downloaded from s3", 134 | ) 135 | parser.add_argument( 136 | "--max_input_length", 137 | default=512, 138 | type=int, 139 | help="The maximum total input sequence length after tokenization. Sequences longer " 140 | "than this will be truncated, sequences shorter will be padded.", 141 | ) 142 | parser.add_argument( 143 | "--max_output_length", 144 | default=56, 145 | type=int, 146 | help="The maximum total output sequence length after tokenization. Sequences longer " 147 | "than this will be truncated, sequences shorter will be padded.", 148 | ) 149 | parser.add_argument("--lower", default=False, type=bool, help="Lower case") 150 | args = parser.parse_args() 151 | 152 | tokenizer = T5Tokenizer.from_pretrained("t5-base", cache_dir=args.cache_dir) 153 | prepare_t5(tokenizer, args.data_dir, args.max_input_length, args.max_output_length, args.lower) 154 | -------------------------------------------------------------------------------- /W3C/extract_threads.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | import numpy as np 5 | import datetime 6 | import ujson as json 7 | from tqdm import tqdm 8 | 9 | 10 | def unzip_files(): 11 | os.system("tar -xzvf raw_data/w3c-emails-part1.tar.gz") 12 | os.system("tar -xzvf raw_data/w3c-emails-part2.tar.gz") 13 | os.system("mv lists-* w3c-emails") 14 | os.system("tar -xzvf raw_data/w3c-emails-part3.tar.gz") 15 | os.system("mv lists-* w3c-emails") 16 | 17 | 18 | def group_emails_by_subjects(): 19 | print("====Group emails by their subjects and save to subjects.json====") 20 | subjects = {} 21 | files = os.listdir("w3c-emails") 22 | model1 = re.compile(r'[a-zA-Z0-9]') 23 | model2 = re.compile(r'[^\x00-\x7f]') 24 | for file in tqdm(files): 25 | with codecs.open(f"w3c-emails/{file}", 'r', encoding='utf-8', errors='ignore')as f: 26 | while True: 27 | line = f.readline() 28 | if not line: 29 | break 30 | line = line.strip() 31 | if line.startswith("subject"): 32 | subject = line.split('=')[1].replace('"', '').lower() 33 | if '?????' 
in subject: 34 | break 35 | subject = subject.replace("re:", '').replace("aw:", '').replace("fw:", '').replace( 36 | "fwd:", '').replace('re[2]:', '').replace('re (3):', '').replace('(fwd)', '').strip() 37 | subject = ' '.join([item for item in subject.split(' ') if item]).strip() 38 | if not subject: 39 | break 40 | if not model1.search(subject): 41 | break 42 | if model2.search(subject): 43 | break 44 | if subject not in subjects: 45 | subjects[subject] = [] 46 | subjects[subject].append(file) 47 | break 48 | subjects = {subject: subjects[subject] for subject in subjects if len(subjects[subject]) > 2} 49 | print(f"======{len(subjects)} subjects in total======") 50 | # there should be 15183 subjects 51 | with open("subjects.txt", 'w') as f: 52 | f.write('\n'.join(sorted([subject for subject in subjects]))) 53 | with open("subjects.json", 'w') as f: 54 | json.dump(subjects, f) 55 | 56 | 57 | def extract_threads(): 58 | print("====Extract threads and save to W3C.json====") 59 | with open("subjects.json", 'r') as f: 60 | subjects = json.load(f) 61 | threads = {} 62 | lens = [] 63 | for subject in tqdm(subjects): 64 | thread = [] 65 | for file in subjects[subject]: 66 | with codecs.open(f"w3c-emails/{file}", 'r', encoding='utf-8', errors='ignore')as f: 67 | email = {"to":[], "content": []} 68 | while True: 69 | line = f.readline() 70 | if not line: 71 | break 72 | line = line.strip() 73 | if not line: 74 | continue 75 | if line.startswith("docno="): 76 | email["docno"] = line.split('=')[1].replace('"', '').strip() 77 | elif line.startswith("received="): 78 | email["received"] = line.split('=')[1].replace('"', '').strip() 79 | items = [item for item in email["received"].split(' ') if item] 80 | if items[-1] in ["EST", "EDT"]: 81 | items = items[:-1] 82 | items = items[:3] + [''.join(items[3:-1]), items[-1]] 83 | date = ' '.join(items) 84 | email["received_time"] = str(datetime.datetime.timestamp(datetime.datetime.strptime(date, "%c"))) 85 | elif line.startswith("isoreceived="): 86 | email["isoreceived"] = line.split('=')[1].replace('"', '').strip() 87 | elif line.startswith("name="): 88 | email["name"] = line.split('=')[1].replace('"', '').strip() 89 | elif line.startswith("email="): 90 | email["email"] = line.split('=')[1].replace('"', '').strip() 91 | elif line.startswith("subject="): 92 | email["subject"] = line.split('=')[1].replace('"', '').strip() 93 | elif line.startswith("To:") or line.startswith("to:") or line.startswith("TO:"): 94 | email["to"].extend(line.replace("To:", '').replace("to:", '').replace("TO:", '').replace( 95 | '"', '').replace("'", '').strip().split(',')) 96 | elif line.startswith("CC:") or line.startswith("Cc:") or line.startswith("cc:"): 97 | email["to"].extend(line.replace("CC:", '').replace("Cc:", '').replace("cc:", '').replace( 98 | '"', '').replace("'", '').strip().split(',')) 99 | elif line.startswith("sent=") or line.startswith("isosent=") or line.startswith("id=") \ 100 | or line.startswith("charset=") or line.startswith("expires=") \ 101 | or line.startswith("inreplyto="): 102 | continue 103 | else: 104 | email["content"].append(line) 105 | try: 106 | assert "received_time" in email 107 | assert len(email["to"]) > 0 108 | except: 109 | continue 110 | thread.append(email) 111 | 112 | if len(thread) == 0: 113 | continue 114 | 115 | thread = sorted(thread, key=lambda x: float(x["received_time"])) 116 | 117 | new_thread = [] 118 | emails = {} 119 | persons = set() 120 | for email in thread: 121 | my = email["email"].lower() 122 | tos = [my] 123 | key = ' 
'.join(tos + [email["received_time"]]) 124 | if key in emails: 125 | continue 126 | for to in email["to"]: 127 | if '<' in to and '>' in to: 128 | to = to.split('<')[1].split('>')[0].lower() 129 | to = re.sub(r'\(\S+\)', '', to) 130 | tos.append(to.lower()) 131 | if persons and len(set(tos) & persons) == 0: 132 | break 133 | new_thread.append(email) 134 | persons.update(tos) 135 | emails[key] = email 136 | 137 | content = set() 138 | for email in new_thread: 139 | content.add(' '.join(email["content"]).lower()) 140 | if len(content) == 1: 141 | continue 142 | 143 | if 2 < len(new_thread) <= 50: 144 | lens.append(len(new_thread)) 145 | threads[subject] = new_thread 146 | 147 | print(f"======{len(lens)} threads in total, the average/max thread lengths are {np.mean(lens)}/{max(lens)}======") 148 | # it should print 13794 threads in total, the average/max thread lengths are 6.445411048281862/50" 149 | with open("W3C.json", 'w') as f: 150 | json.dump(threads, f) 151 | 152 | 153 | if __name__ == '__main__': 154 | unzip_files() 155 | group_emails_by_subjects() 156 | extract_threads() -------------------------------------------------------------------------------- /rouge/io.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Library for reading/writing input and score files.""" 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import glob 24 | 25 | from absl import logging 26 | import six 27 | from six.moves import zip 28 | from six.moves import zip_longest 29 | 30 | 31 | 32 | def compute_scores_and_write_to_csv(target_filepattern, 33 | prediction_filepattern, 34 | output_filename, 35 | scorer, 36 | aggregator, 37 | delimiter="\n"): 38 | """Runs aggregate score calculations and outputs results to a CSV file. 39 | 40 | Args: 41 | target_filepattern: Pattern for files containing target text. 42 | prediction_filepattern: Pattern for files containing prediction text. 43 | output_filename: Name of file to write results to. 44 | scorer: A BaseScorer object to compute scores. 45 | aggregator: An aggregator to aggregate scores. If None, outputs are 46 | per-example scores. 47 | delimiter: Record delimiter. 
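
  Example (file names are illustrative):
    scorer = rouge_scorer.RougeScorer(["rouge1"])
    compute_scores_and_write_to_csv("targets.txt", "predictions.txt",
                                    "scores.csv", scorer,
                                    scoring.BootstrapAggregator())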
48 | """ 49 | 50 | target_filenames = _glob(target_filepattern) 51 | prediction_filenames = _glob(prediction_filepattern) 52 | scores = _compute_scores(target_filenames, prediction_filenames, scorer, 53 | delimiter) 54 | if aggregator: 55 | for score in scores: 56 | aggregator.add_scores(score) 57 | _write_aggregates_to_csv(output_filename, aggregator.aggregate()) 58 | else: 59 | _write_scores_to_csv(output_filename, scores) 60 | 61 | 62 | def _glob(filepattern): 63 | return glob.glob(filepattern) # pylint: disable=unreachable 64 | 65 | 66 | def _open(filepattern, mode="r"): 67 | return open(filepattern, mode) # pylint: disable=unreachable 68 | 69 | 70 | def _record_gen(filename, delimiter): 71 | """Opens file and yields records separated by delimiter.""" 72 | with _open(filename) as f: 73 | records = f.read().split(six.ensure_str(delimiter)) 74 | if records[-1]: 75 | # Need a final delimiter at end of file to be able to detect an empty last 76 | # record. 77 | logging.warn("Expected delimiter at end of file") 78 | else: 79 | records = records[:-1] 80 | for record in records: 81 | yield record 82 | 83 | 84 | def _compute_scores(target_filenames, prediction_filenames, scorer, delimiter): 85 | """Computes aggregates scores across the given target and prediction files. 86 | 87 | Args: 88 | target_filenames: List of filenames from which to read target lines. 89 | prediction_filenames: List of filenames from which to read prediction lines. 90 | scorer: A BaseScorer object to compute scores. 91 | delimiter: string delimiter between each record in input files 92 | Returns: 93 | A list of dicts mapping score_type to Score objects. 94 | Raises: 95 | ValueError: If invalid targets or predictions are provided. 96 | """ 97 | 98 | if (len(target_filenames) < 1 or 99 | len(target_filenames) != len(prediction_filenames)): 100 | raise ValueError("Must have equal and positive number of target and " 101 | "prediction files. Found: %d target files, %d prediction " 102 | "files." % (len(target_filenames), 103 | len(prediction_filenames))) 104 | 105 | scores = [] 106 | for target_filename, prediction_filename in zip( 107 | sorted(target_filenames), sorted(prediction_filenames)): 108 | logging.info("Reading targets from %s.", target_filename) 109 | logging.info("Reading predictions from %s.", prediction_filename) 110 | targets = _record_gen(target_filename, delimiter) 111 | preds = _record_gen(prediction_filename, delimiter) 112 | for target_rec, prediction_rec in zip_longest(targets, preds): 113 | if target_rec is None or prediction_rec is None: 114 | raise ValueError("Must have equal number of lines across target and " 115 | "prediction files. Mismatch between files: %s, %s." % 116 | (target_filename, prediction_filename)) 117 | scores.append(scorer.score(target_rec, prediction_rec)) 118 | 119 | return scores 120 | 121 | 122 | def _write_aggregates_to_csv(output_filename, aggregates): 123 | """Writes aggregate scores to an output CSV file. 124 | 125 | Output file is a comma separated where each line has the format: 126 | score_type-(P|R|F),low_ci,mean,high_ci 127 | 128 | P/R/F indicates whether the score is a precision, recall or f-measure. 129 | 130 | Args: 131 | output_filename: Name of file to write results to. 132 | aggregates: A dict mapping each score_type to a AggregateScore object. 
133 | """ 134 | 135 | logging.info("Writing results to %s.", output_filename) 136 | with _open(output_filename, "w") as output_file: 137 | output_file.write("score_type,low,mid,high\n") 138 | for score_type, aggregate in sorted(aggregates.items()): 139 | output_file.write("%s-R,%f,%f,%f\n" % 140 | (score_type, aggregate.low.recall, aggregate.mid.recall, 141 | aggregate.high.recall)) 142 | output_file.write("%s-P,%f,%f,%f\n" % 143 | (score_type, aggregate.low.precision, 144 | aggregate.mid.precision, aggregate.high.precision)) 145 | output_file.write("%s-F,%f,%f,%f\n" % 146 | (score_type, aggregate.low.fmeasure, 147 | aggregate.mid.fmeasure, aggregate.high.fmeasure)) 148 | logging.info("Finished writing results.") 149 | 150 | 151 | def _write_scores_to_csv(output_filename, scores): 152 | """Writes scores for each individual example to an output CSV file. 153 | 154 | Output file is a comma-separated file where each line has the format: 155 | id,score1,score2,score3,... 156 | 157 | The header row indicates the type of each score column. 158 | 159 | Args: 160 | output_filename: Name of file to write results to. 161 | scores: A list of dicts mapping each score_type to a Score object. 162 | """ 163 | 164 | if len(scores) < 1: 165 | logging.warn("No scores to write") 166 | return 167 | rouge_types = sorted(scores[0].keys()) 168 | 169 | logging.info("Writing results to %s.", output_filename) 170 | with _open(output_filename, "w") as out_file: 171 | out_file.write("id") 172 | for rouge_type in rouge_types: 173 | out_file.write(",{t}-P,{t}-R,{t}-F".format(t=rouge_type)) 174 | out_file.write("\n") 175 | for i, result in enumerate(scores): 176 | out_file.write("%d" % i) 177 | for rouge_type in rouge_types: 178 | out_file.write(",%f,%f,%f" % 179 | (result[rouge_type].precision, result[rouge_type].recall, 180 | result[rouge_type].fmeasure)) 181 | out_file.write("\n") 182 | logging.info("Finished writing results.") 183 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EmailSum (ACL 2021) 2 | 3 | This repository contains the data and code for the following paper: 4 | 5 | [EmailSum: Abstractive Email Thread Summarization]() 6 | 7 | ``` 8 | @inproceedings{zhang2021emailsum, 9 | title={EmailSum: Abstractive Email Thread Summarization}, 10 | author={Zhang, Shiyue and Celikyilmaz, Asli and Gao, Jianfeng and Bansal, Mohit}, 11 | booktitle={Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics}, 12 | year={2021} 13 | } 14 | ``` 15 | 16 | ## Data 17 | 18 | We only release the summaries we collected, along with scripts to extract the email threads 19 | from the flat email corpora (Avocado or W3C), because Avocado is distributed under license by the Linguistic Data Consortium. 20 | 21 | ### Requirements 22 | 23 | * Python 3 24 | * requirements.txt 25 | * Download Avocado Research Email Collection from [LDC](https://catalog.ldc.upenn.edu/LDC2015T03) 26 | 27 | ### Avocado 28 | We collected the summaries of 2,549 Avocado email threads (see Avocado/summaries/EmailSum_data.json). 29 | After submission, we collected one more reference for each of the 500 email threads in the test set 30 | (see Avocado/summaries/one_more_reference.json). 
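The summary JSON can be inspected directly; below is a minimal sketch that loads it using only the fields standardize.py relies on (thread_id, num_of_emails, short_summary, long_summary):
```
import ujson as json

with open("Avocado/summaries/EmailSum_data.json") as f:
    data = json.load(f)["data"]   # {"train": [...], "test": [...]}

thread = data["train"][0]
print(thread["thread_id"], thread["num_of_emails"])
print(thread["short_summary"]["content"])  # crowd-sourced short summary
print(thread["long_summary"]["content"])   # crowd-sourced long summary
```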
31 | 32 | * First, cd Avocado/ 33 | 34 | * Download "emails.json" from [here](https://drive.google.com/file/d/1OK1fjBn269N3Cx8QUAFga1fWn9nNE_N8/view?usp=sharing) 35 | and put it under Avocado/ 36 | 37 | * Extract threads, assuming $ROOT_DIR contains the LDC2015T03 (i.e., $ROOT_DIR/LDC2015T03/Data/avocado-1.0.2) 38 | ``` 39 | python extract_threads.py --root_dir $ROOT_DIR 40 | ``` 41 | You will get "Avocado.json" which contains all extracted threads. 42 | 43 | 44 | * Anonymize & Filter 45 | ``` 46 | python anonymize.py 47 | ``` 48 | After this step, you can see cleaned threads under "Avocado_threads/". 49 | 50 | 51 | * Prepare Train/Dev/Test files 52 | ``` 53 | python standardize.py 54 | ``` 55 | After this step, you can see experimental files under "exp_data/". 56 | There are two sub-directories: "data_email_short" and "data_email_long" for short 57 | and long summaries, respectively. 58 | Each line of the *.source file is one email thread, in which 59 | emails are separated by "|||". 60 | 61 | 62 | ### W3C 63 | We provide the code for extracting threads from the W3C email corpus for semi-supervised learning. 64 | 65 | * First, cd "W3C/" 66 | 67 | * Download raw data files from [here](https://drive.google.com/drive/folders/1ZPGdzvauoEN4qqsZ2ZxD4EkWobCZPZhZ?usp=sharing) 68 | and put them under "W3C/raw_data/" 69 | 70 | * Extract threads 71 | ``` 72 | python extract_threads.py 73 | ``` 74 | You will get "W3C.json" which contains all extracted threads. 75 | 76 | * Anonymize & Filter 77 | ``` 78 | python anonymize.py 79 | ``` 80 | After this step, you can see all cleaned threads under "W3C_threads/". 81 | 82 | 83 | ## Model 84 | 85 | ### Requirements 86 | 87 | * Python 3 88 | * PyTorch 1.7, transformers==2.11.0 89 | 90 | ### Test pre-trained models 91 | 92 | * Download pre-trained models from [here](https://drive.google.com/drive/folders/1KXnoGpyzcwESfLN9JmzM0gDz_75WVkaj?usp=sharing), decompress, and put them under "train/". 93 | 94 | Note that we conduct model selection for each metric, so there are multiple 95 | best checkpoints, e.g., "checkpoint-rouge1" is the best ROUGE1 checkpoint selected by ROUGE1 on the 96 | development set. "best_ckpt.json" contains the best scores on the development set. 
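With the default arguments, run.py writes each model to train/{model_type}_{task}_seed{seed}_{memory_type}_{model}, so a decompressed model directory should look roughly like this (illustrative layout):
```
train/t5_email_short_seed42_base_t5-base/
├── checkpoint-rouge1/   # best ROUGE1 checkpoint, selected on the dev set
├── ...                  # one best checkpoint per selection metric
└── best_ckpt.json       # best development-set scores
```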
97 | 98 | * Prepare data 99 | 100 | After you get "Avocado/exp_data/data_email_short" and "Avocado/exp_data/data_email_long", run 101 | ``` 102 | python3 data.py --data_dir Avocado/exp_data/data_email_long --cache_dir train/cache --max_output_length 128 103 | python3 data.py --data_dir Avocado/exp_data/data_email_short --cache_dir train/cache --max_output_length 56 104 | 105 | ``` 106 | 107 | * Test 108 | 109 | T5 baselines 110 | ``` 111 | python3 run.py --task email_long --data_dir Avocado/exp_data/data_email_long/ --test_only --max_output_length 128 112 | python3 run.py --task email_short --data_dir Avocado/exp_data/data_email_short/ --test_only --max_output_length 56 113 | ``` 114 | 115 | Hierarchical T5 116 | ``` 117 | python3 run.py --task email_long --memory_type ht5 --data_dir Avocado/exp_data/data_email_long/ --test_only --max_output_length 128 118 | python3 run.py --task email_short --memory_type ht5 --data_dir Avocado/exp_data/data_email_short/ --test_only --max_output_length 56 119 | ``` 120 | 121 | Semi-supervised models 122 | ``` 123 | python3 run.py --task email_long_w3c --data_dir Avocado/exp_data/data_email_long/ --test_only --max_output_length 128 124 | python3 run.py --task email_short_together --data_dir Avocado/exp_data/data_email_short/ --test_only --max_output_length 56 125 | ``` 126 | 127 | The testing scores will be saved in "best_ckpt_test.json". 128 | We provide "best_ckpt_test_verification.json" for verifying results; you should obtain almost the same numbers. 129 | 130 | We also provide "best_ckpt_test_old.json", which contains our previously tested scores (reported in the paper). 131 | You are likely to get slightly different numbers from "best_ckpt_test_old.json" because we added a few more data 132 | cleaning and anonymization rules. The pre-processed *.source files will be 133 | slightly different from the ones we used before. 134 | 135 | * Test with two references 136 | 137 | Just add "--two_ref", e.g., 138 | ``` 139 | python3 run.py --task email_long --data_dir Avocado/exp_data/data_email_long/ --test_only --two_ref --max_output_length 128 140 | ``` 141 | 142 | The testing scores will be saved in "best_ckpt_test_2ref.json". 143 | We provide "best_ckpt_test_2ref_verification.json" for verifying results; you should obtain almost the same numbers. 144 | 145 | 146 | ### Benchmark Results 147 | **One-reference** results: 148 | 149 | | EmailSum **Short** | rouge1 | rouge2 | rougeL | rougeLsum | BERTScore | 150 | | :------------- | :----------: | -----------: | -----------: | -----------: | -----------: | 151 | | T5 base | 36.61 | 10.58 | 28.29 | 32.77 | 33.92 | 152 | | HT5 | 36.30 | 10.74 | 28.52 | 33.33 | 33.49 | 153 | | Semi-sup. (together)| 36.99 | 11.22 | 28.71 | 33.70 | 33.91 | 154 | 155 | | EmailSum **Long** | rouge1 | rouge2 | rougeL | rougeLsum | BERTScore | 156 | | :------------- | :----------: | -----------: | -----------: | -----------: | -----------: | 157 | | T5 base | 43.87 | 14.10 | 30.50 | 39.91 | 32.07 | 158 | | HT5 | 44.44 | 14.51 | 30.86 | 40.24 | 32.31 | 159 | | Semi-sup. (w3c)| 44.58 | 14.64 | 31.40 | 40.73 | 32.80 | 160 | 161 | 162 | **Two-reference** results (averaging the results of the two references): 163 | 164 | | EmailSum **Short** | rouge1 | rouge2 | rougeL | rougeLsum | BERTScore | 165 | | :------------- | :----------: | -----------: | -----------: | -----------: | -----------: | 166 | | T5 base | 35.22 | 9.60 | 27.08 | 31.22 | 32.45 | 167 | | HT5 | 34.81 | 9.82 | 27.28 | 31.74 | 32.42 | 168 | | Semi-sup. 
(together)| 35.52 | 10.35 | 27.29 | 33.11 | 32.24 | 169 | 170 | | EmailSum **Long** | rouge1 | rouge2 | rougeL | rougeLsum | BERTScore | 171 | | :------------- | :----------: | -----------: | -----------: | -----------: | -----------: | 172 | | T5 base | 43.41 | 13.81 | 29.97 | 39.32 | 31.58 | 173 | | HT5 | 43.86 | 14.06 | 30.17 | 39.64 | 31.84 | 174 | | Semi-sup. (w3c)| 43.99 | 14.18 | 30.56 | 40.12 | 32.04 | 175 | 176 | Interestingly, we always get lower scores when comparing against the 2nd reference, which we collected after 177 | paper submission; this is why the two-reference results are consistently worse than the one-reference ones. 178 | A likely cause is that a different set of Turkers annotated the second reference, 179 | which introduces a domain shift. 180 | 181 | ### Train 182 | 183 | Just drop "--test_only", e.g., 184 | ``` 185 | python3 run.py --task email_long --data_dir Avocado/exp_data/data_email_long/ --max_output_length 128 186 | ``` 187 | 188 | 189 | 190 | 191 | -------------------------------------------------------------------------------- /rouge/scoring_test.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Tests for rouge scoring and aggregation. 18 | 19 | Checks both correctness and consistency with values from the perl ROUGE 20 | implementation, which this package replicates. 21 | """ 22 | 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import os 28 | 29 | from absl.testing import absltest 30 | import numpy as np 31 | from six.moves import range 32 | from six.moves import zip 33 | from rouge import rouge_scorer 34 | from rouge import scoring 35 | from rouge import test_util 36 | 37 | # Delta for matching against ground truth rouge values. Must be relatively 38 | # high compared to the individual rouge tests since bootstrap sampling 39 | # introduces randomness. 40 | _DELTA = 0.002 41 | 42 | # Use a fixed random seed, or tests may fail with nonzero probability. 43 | _RANDOM_SEED = 123 44 | 45 | 46 | class BootstrapAggregatorTest(absltest.TestCase): 47 | 48 | def setUp(self): 49 | super(BootstrapAggregatorTest, self).setUp() 50 | np.random.seed(_RANDOM_SEED) 51 | with open(test_util.LARGE_TARGETS_FILE) as f: 52 | self.targets = f.readlines() 53 | with open(test_util.LARGE_PREDICTIONS_FILE) as f: 54 | self.predictions = f.readlines() 55 | 56 | def assertSimilarAggregates(self, precision, recall, fmeasure, aggregate, 57 | delta=_DELTA): 58 | """Helper method for asserting matching aggregate scores. 59 | 60 | Args: 61 | precision: Tuple of (low, mid, high) precision scores. 62 | recall: Tuple of (low, mid, high) recall scores. 63 | fmeasure: Tuple of (low, mid, high) fmeasure scores. 64 | aggregate: An AggregateScore object. 65 | delta: Tolerance delta for matching values. 
66 | """ 67 | 68 | self.assertAlmostEqual(precision[0], aggregate.low.precision, delta=delta) 69 | self.assertAlmostEqual(precision[1], aggregate.mid.precision, delta=delta) 70 | self.assertAlmostEqual(precision[2], aggregate.high.precision, delta=delta) 71 | self.assertAlmostEqual(recall[0], aggregate.low.recall, delta=delta) 72 | self.assertAlmostEqual(recall[1], aggregate.mid.recall, delta=delta) 73 | self.assertAlmostEqual(recall[2], aggregate.high.recall, delta=delta) 74 | self.assertAlmostEqual(fmeasure[0], aggregate.low.fmeasure, delta=delta) 75 | self.assertAlmostEqual(fmeasure[1], aggregate.mid.fmeasure, delta=delta) 76 | self.assertAlmostEqual(fmeasure[2], aggregate.high.fmeasure, delta=delta) 77 | 78 | def testConsistentPercentiles(self): 79 | aggregator = scoring.BootstrapAggregator(confidence_interval=0.9) 80 | aggregator.add_scores({ 81 | "rouge1": scoring.Score(precision=1, recall=1 / 3, fmeasure=1 / 2) 82 | }) 83 | aggregator.add_scores({ 84 | "rouge1": scoring.Score(precision=0, recall=0, fmeasure=0) 85 | }) 86 | aggregator.add_scores({ 87 | "rouge1": scoring.Score(precision=1, recall=1, fmeasure=1) 88 | }) 89 | result = aggregator.aggregate() 90 | 91 | self.assertSimilarAggregates((1 / 3, 2 / 3, 3 / 3), 92 | (1 / 9, 4 / 9, 7 / 9), 93 | (1 / 6, 3 / 6, 5 / 6), 94 | result["rouge1"], delta=1e-8) 95 | 96 | def testLargeConfidence(self): 97 | aggregator = scoring.BootstrapAggregator(confidence_interval=0.0) 98 | aggregator.add_scores({ 99 | "rouge1": scoring.Score(precision=1, recall=1 / 3, fmeasure=1 / 2) 100 | }) 101 | aggregator.add_scores({ 102 | "rouge1": scoring.Score(precision=0, recall=0, fmeasure=0) 103 | }) 104 | aggregator.add_scores({ 105 | "rouge1": scoring.Score(precision=1, recall=1, fmeasure=1) 106 | }) 107 | result = aggregator.aggregate() 108 | 109 | self.assertSimilarAggregates((2 / 3, 2 / 3, 2 / 3), 110 | (4 / 9, 4 / 9, 4 / 9), 111 | (3 / 6, 3 / 6, 3 / 6), 112 | result["rouge1"], delta=1e-8) 113 | 114 | def testMultipleRougeTypes(self): 115 | scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=False) 116 | aggregator = scoring.BootstrapAggregator() 117 | for target, prediction in zip(self.targets[:5], self.predictions[:5]): 118 | aggregator.add_scores(scorer.score(target, prediction)) 119 | result = aggregator.aggregate() 120 | 121 | self.assertSameElements(list(result.keys()), ["rouge1", "rougeL"]) 122 | 123 | def testConfidenceIntervalsAgainstRouge155(self): 124 | scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=False) 125 | aggregator = scoring.BootstrapAggregator() 126 | for target, prediction in zip(self.targets, self.predictions): 127 | aggregator.add_scores(scorer.score(target, prediction)) 128 | result = aggregator.aggregate() 129 | 130 | self.assertSimilarAggregates((0.48695, 0.49879, 0.51131), 131 | (0.31106, 0.31950, 0.32849), 132 | (0.37614, 0.38554, 0.39581), 133 | result["rouge1"]) 134 | 135 | def testConfidenceIntervalsAgainstRouge155WithStemming(self): 136 | scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True) 137 | aggregator = scoring.BootstrapAggregator() 138 | for target, prediction in zip(self.targets, self.predictions): 139 | aggregator.add_scores(scorer.score(target, prediction)) 140 | result = aggregator.aggregate() 141 | 142 | self.assertSimilarAggregates((0.51027, 0.52434, 0.53788), 143 | (0.32563, 0.33580, 0.34548), 144 | (0.39380, 0.40524, 0.41661), 145 | result["rouge1"]) 146 | self.assertSimilarAggregates((0.50759, 0.52104, 0.53382), # P 147 | (0.32418, 0.33377, 0.34362), # R 148 | 
(0.39157, 0.40275, 0.41383), # F 149 | result["rougeL"]) 150 | 151 | def testConfidenceIntervalsAgainstRouge155WithStemmingMultiLine(self): 152 | scorer = rouge_scorer.RougeScorer( 153 | ["rouge1", "rouge2", "rougeLsum"], use_stemmer=True) 154 | aggregator = scoring.BootstrapAggregator() 155 | t_files = [os.path.join(test_util.PYROUGE_DIR, 'target_multi.%d.txt' % i) for i in range(0, 250)] 156 | p_files = [os.path.join(test_util.PYROUGE_DIR, 'prediction_multi.%d.txt' % i) for i in range(0, 250)] 157 | 158 | targets = [test_util.get_text(x) for x in t_files] 159 | predictions = [test_util.get_text(x) for x in p_files] 160 | assert len(targets) == len(predictions) 161 | assert len(targets) == 250 162 | for target, prediction in zip(targets, predictions): 163 | aggregator.add_scores(scorer.score(target, prediction)) 164 | result = aggregator.aggregate() 165 | 166 | # DIR = testdata/pyrouge_evaluate_plain_text_files 167 | # pyrouge_evaluate_plain_text_files -s $DIR -sfp "prediction_multi.(.*).txt" 168 | # -m $DIR -mfp target_multi.#ID#.txt 169 | self.assertSimilarAggregates((0.58963, 0.59877, 0.60822), # P 170 | (0.37327, 0.38091, 0.38914), # R 171 | (0.45607, 0.46411, 0.47244), # F 172 | result["rouge1"]) 173 | self.assertSimilarAggregates((0.35429, 0.36516, 0.37665), # P 174 | (0.22341, 0.23109, 0.23916), # R 175 | (0.27312, 0.28209, 0.29133), # F 176 | result["rouge2"]) 177 | self.assertSimilarAggregates((0.58604, 0.59491, 0.60444), # P 178 | (0.37084, 0.37846, 0.38671), # R 179 | (0.45305, 0.46113, 0.46946), # F 180 | result["rougeLsum"]) 181 | 182 | 183 | if __name__ == "__main__": 184 | absltest.main() 185 | -------------------------------------------------------------------------------- /rouge/rouge_scorer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Computes rouge scores between two text blobs. 18 | 19 | Implementation replicates the functionality in the original ROUGE package. See: 20 | 21 | Lin, Chin-Yew. ROUGE: a Package for Automatic Evaluation of Summaries. In 22 | Proceedings of the Workshop on Text Summarization Branches Out (WAS 2004), 23 | Barcelona, Spain, July 25 - 26, 2004. 24 | 25 | Default options are equivalent to running: 26 | ROUGE-1.5.5.pl -e data -n 2 -a settings.xml 27 | 28 | Or with use_stemmer=True: 29 | ROUGE-1.5.5.pl -m -e data -n 2 -a settings.xml 30 | 31 | In these examples settings.xml lists input files and formats. 
32 | """ 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import collections 39 | import re 40 | 41 | import nltk 42 | nltk.download('punkt', download_dir='./nltk_data') 43 | nltk.data.path.append('./nltk_data') 44 | from nltk import sent_tokenize 45 | from nltk.stem import porter 46 | import six 47 | from six.moves import map 48 | from six.moves import range 49 | from rouge import scoring 50 | from rouge import tokenize 51 | 52 | 53 | class RougeScorer(scoring.BaseScorer): 54 | """Calculate rouge scores between two blobs of text. 55 | 56 | Sample usage: 57 | scorer = RougeScorer(['rouge1', 'rougeL'], use_stemmer=True) 58 | scores = scorer.score('The quick brown fox jumps over the lazy dog', 59 | 'The quick brown dog jumps on the log.') 60 | """ 61 | 62 | def __init__(self, rouge_types, use_stemmer=False): 63 | """Initializes a new RougeScorer. 64 | 65 | Valid rouge types that can be computed are: 66 | rougen (e.g. rouge1, rouge2): n-gram based scoring. 67 | rougeL: Longest common subsequence based scoring. 68 | 69 | Args: 70 | rouge_types: A list of rouge types to calculate. 71 | use_stemmer: Bool indicating whether Porter stemmer should be used to 72 | strip word suffixes to improve matching. 73 | Returns: 74 | A dict mapping rouge types to Score tuples. 75 | """ 76 | 77 | self.rouge_types = rouge_types 78 | self._stemmer = porter.PorterStemmer() if use_stemmer else None 79 | 80 | def score(self, target, prediction): 81 | """Calculates rouge scores between the target and prediction. 82 | 83 | Args: 84 | target: Text containing the target (ground truth) text. 85 | prediction: Text containing the predicted text. 86 | Returns: 87 | A dict mapping each rouge type to a Score object. 88 | Raises: 89 | ValueError: If an invalid rouge type is encountered. 90 | """ 91 | 92 | target_tokens = tokenize.tokenize(target, self._stemmer) 93 | prediction_tokens = tokenize.tokenize(prediction, self._stemmer) 94 | result = {} 95 | 96 | for rouge_type in self.rouge_types: 97 | if rouge_type == "rougeL": 98 | # Rouge from longest common subsequences. 99 | scores = _score_lcs(target_tokens, prediction_tokens) 100 | elif rouge_type == "rougeLsum": 101 | # Summary-level LCS over sentences. 102 | def get_sents(text): 103 | # Split text into sentences with NLTK's sent_tokenize. 104 | sents = sent_tokenize(text) 105 | # sents = six.ensure_str(text).split("\n") 106 | sents = [x for x in sents if len(x)] 107 | return sents 108 | 109 | target_tokens_list = [ 110 | tokenize.tokenize(s, self._stemmer) for s in get_sents(target)] 111 | prediction_tokens_list = [ 112 | tokenize.tokenize(s, self._stemmer) for s in get_sents(prediction)] 113 | scores = _summary_level_lcs(target_tokens_list, 114 | prediction_tokens_list) 115 | elif re.match(r"rouge[0-9]$", six.ensure_str(rouge_type)): 116 | # Rouge from n-grams. 117 | n = int(rouge_type[5:]) 118 | if n <= 0: 119 | raise ValueError("rougen requires positive n: %s" % rouge_type) 120 | target_ngrams = _create_ngrams(target_tokens, n) 121 | prediction_ngrams = _create_ngrams(prediction_tokens, n) 122 | scores = _score_ngrams(target_ngrams, prediction_ngrams) 123 | else: 124 | raise ValueError("Invalid rouge type: %s" % rouge_type) 125 | result[rouge_type] = scores 126 | 127 | return result 128 | 129 | 130 | def _create_ngrams(tokens, n): 131 | """Creates ngrams from the given list of tokens. 132 | 133 | Args: 134 | tokens: A list of tokens from which ngrams are created. 
135 | n: Number of tokens to use, e.g. 2 for bigrams. 136 | Returns: 137 | A dictionary mapping each ngram to the number of occurrences. 138 | """ 139 | 140 | ngrams = collections.Counter() 141 | for ngram in (tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)): 142 | ngrams[ngram] += 1 143 | return ngrams 144 | 145 | 146 | def _score_lcs(target_tokens, prediction_tokens): 147 | """Computes LCS (Longest Common Subsequence) rouge scores. 148 | 149 | Args: 150 | target_tokens: Tokens from the target text. 151 | prediction_tokens: Tokens from the predicted text. 152 | Returns: 153 | A Score object containing computed scores. 154 | """ 155 | 156 | if not target_tokens or not prediction_tokens: 157 | return scoring.Score(precision=0, recall=0, fmeasure=0) 158 | 159 | # Compute length of LCS from the bottom up in a table (DP approach). 160 | lcs_table = _lcs_table(target_tokens, prediction_tokens) 161 | lcs_length = lcs_table[-1][-1] 162 | 163 | precision = lcs_length / len(prediction_tokens) 164 | recall = lcs_length / len(target_tokens) 165 | fmeasure = scoring.fmeasure(precision, recall) 166 | 167 | return scoring.Score(precision=precision, recall=recall, fmeasure=fmeasure) 168 | 169 | 170 | def _lcs_table(ref, can): 171 | """Create 2-d LCS score table.""" 172 | rows = len(ref) 173 | cols = len(can) 174 | lcs_table = [[0] * (cols + 1) for _ in range(rows + 1)] 175 | for i in range(1, rows + 1): 176 | for j in range(1, cols + 1): 177 | if ref[i - 1] == can[j - 1]: 178 | lcs_table[i][j] = lcs_table[i - 1][j - 1] + 1 179 | else: 180 | lcs_table[i][j] = max(lcs_table[i - 1][j], lcs_table[i][j - 1]) 181 | return lcs_table 182 | 183 | 184 | def _backtrack_norec(t, ref, can): 185 | """Read out LCS.""" 186 | i = len(ref) 187 | j = len(can) 188 | lcs = [] 189 | while i > 0 and j > 0: 190 | if ref[i - 1] == can[j - 1]: 191 | lcs.insert(0, i-1) 192 | i -= 1 193 | j -= 1 194 | elif t[i][j - 1] > t[i - 1][j]: 195 | j -= 1 196 | else: 197 | i -= 1 198 | return lcs 199 | 200 | 201 | def _summary_level_lcs(ref_sent, can_sent): 202 | """ROUGE: Summary-level LCS, section 3.2 in ROUGE paper. 203 | 204 | Args: 205 | ref_sent: list of tokenized reference sentences 206 | can_sent: list of tokenized candidate sentences 207 | 208 | Returns: 209 | summary level ROUGE score 210 | """ 211 | if not ref_sent or not can_sent: 212 | return scoring.Score(precision=0, recall=0, fmeasure=0) 213 | 214 | m = sum(map(len, ref_sent)) 215 | n = sum(map(len, can_sent)) 216 | if not n or not m: 217 | return scoring.Score(precision=0, recall=0, fmeasure=0) 218 | 219 | # get token counts to prevent double counting 220 | token_cnts_r = collections.Counter() 221 | token_cnts_c = collections.Counter() 222 | for s in ref_sent: 223 | # s is a list of tokens 224 | token_cnts_r.update(s) 225 | for s in can_sent: 226 | token_cnts_c.update(s) 227 | 228 | hits = 0 229 | for r in ref_sent: 230 | lcs = _union_lcs(r, can_sent) 231 | # Prevent double-counting: 232 | # The paper describes just computing hits += len(_union_lcs()), 233 | # but the implementation prevents double counting. We also 234 | # implement this as in version 1.5.5. 
235 | for t in lcs: 236 | if token_cnts_c[t] > 0 and token_cnts_r[t] > 0: 237 | hits += 1 238 | token_cnts_c[t] -= 1 239 | token_cnts_r[t] -= 1 240 | 241 | recall = hits / m 242 | precision = hits / n 243 | fmeasure = scoring.fmeasure(precision, recall) 244 | return scoring.Score(precision=precision, recall=recall, fmeasure=fmeasure) 245 | 246 | 247 | def _union_lcs(ref, c_list): 248 | """Find union LCS between a ref sentence and list of candidate sentences. 249 | 250 | Args: 251 | ref: list of tokens 252 | c_list: list of list of indices for LCS into reference summary 253 | 254 | Returns: 255 | List of tokens in ref representing union LCS. 256 | """ 257 | lcs_list = [lcs_ind(ref, c) for c in c_list] 258 | return [ref[i] for i in _find_union(lcs_list)] 259 | 260 | 261 | def _find_union(lcs_list): 262 | """Finds union LCS given a list of LCS.""" 263 | return sorted(list(set().union(*lcs_list))) 264 | 265 | 266 | def lcs_ind(ref, can): 267 | """Returns one of the longest lcs.""" 268 | t = _lcs_table(ref, can) 269 | return _backtrack_norec(t, ref, can) 270 | 271 | 272 | def _score_ngrams(target_ngrams, prediction_ngrams): 273 | """Compute n-gram based rouge scores. 274 | 275 | Args: 276 | target_ngrams: A Counter object mapping each ngram to number of 277 | occurrences for the target text. 278 | prediction_ngrams: A Counter object mapping each ngram to number of 279 | occurrences for the prediction text. 280 | Returns: 281 | A Score object containing computed scores. 282 | """ 283 | 284 | intersection_ngrams_count = 0 285 | for ngram in six.iterkeys(target_ngrams): 286 | intersection_ngrams_count += min(target_ngrams[ngram], 287 | prediction_ngrams[ngram]) 288 | target_ngrams_count = sum(target_ngrams.values()) 289 | prediction_ngrams_count = sum(prediction_ngrams.values()) 290 | 291 | precision = intersection_ngrams_count / max(prediction_ngrams_count, 1) 292 | recall = intersection_ngrams_count / max(target_ngrams_count, 1) 293 | fmeasure = scoring.fmeasure(precision, recall) 294 | 295 | return scoring.Score(precision=precision, recall=recall, fmeasure=fmeasure) 296 | -------------------------------------------------------------------------------- /Avocado/extact_threads.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import codecs 4 | import datetime 5 | import numpy as np 6 | import ujson as json 7 | import xml.etree.ElementTree as ET 8 | from tqdm import tqdm 9 | 10 | 11 | def unzip_files(): 12 | files = os.listdir(f"{ROOT_DIR}/LDC2015T03/Data/avocado-1.0.2/data/text/") 13 | for file in files: 14 | os.system(f"unzip {ROOT_DIR}/LDC2015T03/Data/avocado-1.0.2/data/text/{file} " 15 | f"-d avocado_text") 16 | 17 | 18 | def get_emails(): 19 | print("====Get all emails and save to emails.json====") 20 | files = os.listdir(f"{ROOT_DIR}/LDC2015T03/Data/avocado-1.0.2/data/custodians") 21 | emails = {} 22 | dedup_emails = set() 23 | for file in tqdm(files): 24 | parsedXML = ET.parse(f"{ROOT_DIR}/LDC2015T03/Data/avocado-1.0.2/data/custodians/{file}") 25 | custodian = parsedXML.getroot() 26 | custodian = {item.tag: item for item in custodian} 27 | items = custodian["items"] 28 | for item in items: 29 | subitems = {subitem.tag: subitem for subitem in item} 30 | if "files" not in subitems or "metadata" not in subitems: 31 | continue 32 | for file in subitems["files"]: 33 | path = file.attrib['path'] 34 | items = path.split('/') 35 | if items[0] == "text" and 'EM' in items[2]: 36 | filename = items[2] 37 | break 38 | emails[filename] = 
39 | dedup_emails.add(filename)
40 | if "relationships" in subitems:
41 | for relation in subitems["relationships"]:
42 | if relation.tag == "duplicate_of":
43 | dedup_emails.add(f'{relation.attrib["id"]}.txt')
44 | elif relation.tag == "reply_to":
45 | emails[filename]["reply_to"] = f'{relation.attrib["id"]}.txt'
46 | for field in subitems["metadata"]:
47 | if field.attrib["name"] == "arrival_date":
48 | emails[filename]["date"] = field.text
49 | elif field.attrib["name"] == "outlook_sender_name":
50 | emails[filename]["sender"] = field.text
51 | elif field.attrib["name"] == "processed_subject":
52 | emails[filename]["processed_subject"] = field.text
53 | elif field.attrib["name"] == "subject":
54 | emails[filename]["subject"] = field.text
55 | print(f"======{len(emails)} emails in total======")
56 | # there should be 937958 emails
57 | # !!! Please use our provided emails.json to get the exact threads we used because the order of emails matters.
58 | with open("emails_new.json", 'w') as f:
59 | json.dump(emails, f)
60 | 
61 | 
62 | def group_emails_by_subjects():
63 | print("====Group emails by their subjects and save to subjects.json====")
64 | subjects = {}
65 | # !!! Please use our provided emails.json to get the exact threads we used because the order of emails matters.
66 | with open("emails.json", 'r') as f:
67 | emails = json.load(f)
68 | model1 = re.compile(r'[a-zA-Z0-9]')
69 | model2 = re.compile(r'[^\x00-\x7f]')
70 | for email in tqdm(emails):
71 | directory = email.split('-')[0]
72 | with codecs.open(f"avocado_text/{directory}/{email}", 'r', encoding='utf-8', errors='ignore') as f:
73 | while True:
74 | line = f.readline()
75 | if not line:
76 | break
77 | line = line.strip()
78 | if line.startswith("Received:"):
79 | break
80 | if line.startswith("Subject:"):
81 | subject = line.replace("Subject:", '').lower()
82 | if '?????' in subject:
83 | break
84 | subject = subject.replace("re:", '').replace("aw:", '').replace("fw:", '').replace(
85 | "fwd:", '').replace('re[2]:', '').replace('re (3):', '').replace('(fwd)', '').strip()
86 | subject = ' '.join([item for item in subject.split(' ') if item]).strip()
87 | if not subject:
88 | break
89 | if not model1.search(subject):
90 | break
91 | if model2.search(subject):
92 | break
93 | if subject not in subjects:
94 | subjects[subject] = []
95 | subjects[subject].append(email)
96 | break
97 | subjects = {subject: subjects[subject] for subject in subjects if len(subjects[subject]) > 2}
98 | print(f"======{len(subjects)} subjects in total======")
99 | # there should be 51367 subjects
100 | with open("subjects.txt", 'w') as f:
101 | f.write('\n'.join(sorted([subject for subject in subjects])))
102 | with open("subjects.json", 'w') as f:
103 | json.dump(subjects, f)
104 | 
105 | 
106 | def extract_threads():
107 | print("====Extract threads and save to Avocado.json====")
108 | threads = {}
109 | lens = []
110 | with open("subjects.json", 'r') as f:
111 | subjects = json.load(f)
112 | datestr = "%d %b %Y %H:%M:%S UTC"
113 | for subject in tqdm(subjects):
114 | thread = []
115 | for file in subjects[subject]:
116 | name = None
117 | email = None
118 | sub = None
119 | date = None
120 | received_time = None
121 | tos = []
122 | content = []
123 | directory = file.split('-')[0]
124 | with open(f"avocado_text/{directory}/{file}", 'r') as f:
125 | while True:
126 | line = f.readline()
127 | if not line:
128 | break
129 | line = line.strip()
130 | if not line:
131 | continue
132 | elif "-----Original Message-----" in line or '-----Ursprüngliche Nachricht-----' in line:
133 | break
134 | elif line.startswith("From:"):
135 | items = line.replace("From:", '').split('<')
136 | if len(items) < 2:
137 | break
138 | email = items[1].replace('>', '').strip()
139 | name = items[0].replace('"', '').strip()
140 | elif (line.startswith("To:") or line.startswith("to:") or line.startswith("TO:")):
141 | tos.extend(line.replace("To:", '').replace("to:", '').replace("TO:", '').replace(
142 | '"', '').replace("'", '').strip().split(','))
143 | elif (line.startswith("CC:") or line.startswith("Cc:") or line.startswith("cc:")):
144 | tos.extend(line.replace("CC:", '').replace("Cc:", '').replace("cc:", '').replace(
145 | '"', '').replace("'", '').strip().split(','))
146 | elif (line.startswith("Bcc:") or line.startswith("bcc:") or line.startswith("BCC:")):
147 | tos.extend(line.replace("Bcc:", '').replace("bcc:", '').replace("BCC:", '').replace(
148 | '"', '').replace("'", '').strip().split(','))
149 | elif line.startswith("Subject:"):
150 | sub = line.replace("Subject:", '').strip()
151 | elif line.startswith("Date:"):
152 | try:
153 | date = line.replace("Date:", '').strip()
154 | received_time = str(datetime.datetime.timestamp(datetime.datetime.strptime(date, datestr)))
155 | except Exception:  # skip emails with malformed or out-of-range dates
156 | break
157 | elif (line.startswith("Message-ID:") or
158 | line.startswith("In-Reply-To:") or line.startswith("MIME-Version:") or
159 | line.startswith("Reply-To:") or line.startswith("Content-Type:") or
160 | line.startswith("Sender:")):
161 | continue
162 | elif "===============================" in line:
163 | continue
164 | else:
165 | content.append(line)
166 | if not email or not tos or not content or not date or not received_time:
167 | continue
168 | else:
169 | # received_time = str(time.mktime(datetime.datetime.strptime(date, datestr).timetuple()))
170 | thread.append({"file": file, "name": name, "email": email, "subject": sub, "date": date,
email, "subject": sub, "date": date, 171 | "received_time": received_time, "to": tos, "content": content}) 172 | if len(thread) == 0: 173 | continue 174 | 175 | thread = sorted(thread, key=lambda x: float(x["received_time"])) 176 | 177 | new_thread = [] 178 | emails = {} 179 | persons = set() 180 | for email in thread: 181 | my = email["email"].lower() 182 | tos = [my] 183 | key = ' '.join(tos + [email["received_time"]]) 184 | if key in emails: 185 | continue 186 | for to in email["to"]: 187 | if '<' in to and '>' in to: 188 | to = to.split('<')[1].split('>')[0].lower() 189 | to = re.sub(r'\(\S+\)', '', to) 190 | tos.append(to.lower()) 191 | if persons and len(set(tos) & persons) == 0: 192 | break 193 | new_thread.append(email) 194 | persons.update(tos) 195 | emails[key] = email 196 | 197 | # remove repeat content email threads 198 | content = set() 199 | for email in new_thread: 200 | content.add(' '.join(email["content"]).lower()) 201 | if len(content) == 1: 202 | continue 203 | 204 | if 2 < len(new_thread) <= 50: # 2 < thread length <= 50 205 | lens.append(len(new_thread)) 206 | threads[subject] = new_thread 207 | 208 | print(f"======{len(lens)} threads in total, the average/max thread lengths are {np.mean(lens)}/{max(lens)}======") 209 | # it should print "28416 threads in total, the average/max thread lengths are 5.34667088963964/50" 210 | with open("Avocado.json", 'w') as f: 211 | json.dump(threads, f) 212 | 213 | 214 | if __name__ == '__main__': 215 | import argparse 216 | 217 | parser = argparse.ArgumentParser() 218 | parser.add_argument("--root_dir", default="./", type=str, help="the directory contains LDC2015T03") 219 | args = parser.parse_args() 220 | ROOT_DIR = args.root_dir 221 | 222 | if not os.path.exists("avocado_text"): 223 | os.mkdir("avocado_text") 224 | unzip_files() 225 | group_emails_by_subjects() 226 | extract_threads() 227 | -------------------------------------------------------------------------------- /bert_score/score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import pathlib 5 | import torch 6 | import matplotlib.pyplot as plt 7 | from mpl_toolkits.axes_grid1 import make_axes_locatable 8 | import numpy as np 9 | import pandas as pd 10 | 11 | from collections import defaultdict 12 | from transformers import AutoTokenizer 13 | 14 | from .utils import ( 15 | get_model, 16 | get_idf_dict, 17 | bert_cos_score_idf, 18 | get_bert_embedding, 19 | lang2model, 20 | model2layers, 21 | get_hash, 22 | cache_scibert, 23 | sent_encode, 24 | ) 25 | 26 | CACHE_DIR = "train/cache" 27 | __all__ = ["score", "plot_example"] 28 | 29 | 30 | def score( 31 | cands, 32 | refs, 33 | model_type=None, 34 | num_layers=None, 35 | verbose=False, 36 | idf=False, 37 | device=None, 38 | batch_size=64, 39 | nthreads=4, 40 | all_layers=False, 41 | lang=None, 42 | return_hash=False, 43 | rescale_with_baseline=False, 44 | ): 45 | """ 46 | BERTScore metric. 47 | 48 | Args: 49 | - :param: `cands` (list of str): candidate sentences 50 | - :param: `refs` (list of str or list of list of str): reference sentences 51 | - :param: `model_type` (str): bert specification, default using the suggested 52 | model for the target langauge; has to specify at least one of 53 | `model_type` or `lang` 54 | - :param: `num_layers` (int): the layer of representation to use. 
55 | default using the number of layers tuned on WMT16 correlation data
56 | - :param: `verbose` (bool): turn on intermediate status update
57 | - :param: `idf` (bool or dict): use idf weighting, can also be a precomputed idf_dict
58 | - :param: `device` (str): the device on which the contextual embedding model will be allocated.
59 | If this argument is None, the model lives on cuda:0 if cuda is available.
60 | - :param: `nthreads` (int): number of threads
61 | - :param: `batch_size` (int): bert score processing batch size
62 | - :param: `lang` (str): language of the sentences; has to specify
63 | at least one of `model_type` or `lang`. `lang` needs to be
64 | specified when `rescale_with_baseline` is True.
65 | - :param: `return_hash` (bool): return hash code of the setting
66 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
67 | 
68 | Return:
69 | - :param: `(P, R, F)`: each is of shape (N); N = number of input
70 | candidate reference pairs. If returning a hashcode, the
71 | output will be ((P, R, F), hashcode). If a candidate has
72 | multiple references, the returned score of this candidate is
73 | the *best* score among all references.
74 | """
75 | assert len(cands) == len(refs), "Different number of candidates and references"
76 | 
77 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified"
78 | 
79 | ref_group_boundaries = None
80 | if not isinstance(refs[0], str):
81 | ref_group_boundaries = []
82 | ori_cands, ori_refs = cands, refs
83 | cands, refs = [], []
84 | count = 0
85 | for cand, ref_group in zip(ori_cands, ori_refs):
86 | cands += [cand] * len(ref_group)
87 | refs += ref_group
88 | ref_group_boundaries.append((count, count + len(ref_group)))
89 | count += len(ref_group)
90 | 
91 | if rescale_with_baseline:
92 | assert lang is not None, "Need to specify Language when rescaling with baseline"
93 | 
94 | if model_type is None:
95 | lang = lang.lower()
96 | model_type = lang2model[lang]
97 | if num_layers is None:
98 | num_layers = model2layers[model_type]
99 | 
100 | if model_type.startswith("scibert"):
101 | tokenizer = AutoTokenizer.from_pretrained(cache_scibert(model_type), cache_dir=CACHE_DIR)
102 | else:
103 | tokenizer = AutoTokenizer.from_pretrained(model_type, cache_dir=CACHE_DIR)
104 | 
105 | model = get_model(model_type, num_layers, all_layers)
106 | if device is None:
107 | device = "cuda" if torch.cuda.is_available() else "cpu"
108 | model.to(device)
109 | 
110 | if not idf:
111 | idf_dict = defaultdict(lambda: 1.0)
112 | # set idf for [SEP] and [CLS] to 0
113 | idf_dict[tokenizer.sep_token_id] = 0
114 | idf_dict[tokenizer.cls_token_id] = 0
115 | elif isinstance(idf, dict):
116 | if verbose:
117 | print("using predefined IDF dict...")
118 | idf_dict = idf
119 | else:
120 | if verbose:
121 | print("preparing IDF dict...")
122 | start = time.perf_counter()
123 | idf_dict = get_idf_dict(refs, tokenizer, nthreads=nthreads)
124 | if verbose:
125 | print("done in {:.2f} seconds".format(time.perf_counter() - start))
126 | 
127 | if verbose:
128 | print("calculating scores...")
129 | start = time.perf_counter()
130 | all_preds = bert_cos_score_idf(
131 | model,
132 | refs,
133 | cands,
134 | tokenizer,
135 | idf_dict,
136 | verbose=verbose,
137 | device=device,
138 | batch_size=batch_size,
139 | all_layers=all_layers,
140 | ).cpu()
141 | 
142 | if ref_group_boundaries is not None:
143 | max_preds = []
144 | for beg, end in ref_group_boundaries:
145 | max_preds.append(all_preds[beg:end].max(dim=0)[0])
146 | all_preds = torch.stack(max_preds, dim=0)
147 | 
148 | if rescale_with_baseline:
149 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
150 | if os.path.isfile(baseline_path):
151 | if not all_layers:
152 | baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float()
153 | else:
154 | baselines = torch.from_numpy(pd.read_csv(baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
155 | 
156 | all_preds = (all_preds - baselines) / (1 - baselines)
157 | else:
158 | print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr)
159 | 
160 | out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F
161 | 
162 | if verbose:
163 | time_diff = time.perf_counter() - start
164 | print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec")
165 | 
166 | if return_hash:
167 | return tuple([out, get_hash(model_type, num_layers, idf, rescale_with_baseline)])
168 | 
169 | return out
170 | 
171 | 
172 | def plot_example(
173 | candidate, reference, model_type=None, num_layers=None, lang=None, rescale_with_baseline=False, fname=""
174 | ):
175 | """
176 | Plot the token-level BERTScore similarity matrix for a candidate-reference pair.
177 | 
178 | Args:
179 | - :param: `candidate` (str): a candidate sentence
180 | - :param: `reference` (str): a reference sentence
181 | - :param: `model_type` (str): bert specification, default using the suggested
182 | model for the target language; has to specify at least one of
183 | `model_type` or `lang`
184 | - :param: `num_layers` (int): the layer of representation to use;
185 | default using the number of layers tuned on WMT16 correlation data
186 | - :param: `lang` (str): language of the sentences; has to specify
187 | at least one of `model_type` or `lang`. `lang` needs to be
188 | specified when `rescale_with_baseline` is True.
189 | - :param: `rescale_with_baseline` (bool): rescale bertscore with pre-computed baseline
190 | - :param: `fname` (str): path to save the output plot;
191 | if empty, the figure is shown but not saved
192 | """
193 | assert isinstance(candidate, str)
194 | assert isinstance(reference, str)
195 | 
196 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified"
197 | 
198 | if rescale_with_baseline:
199 | assert lang is not None, "Need to specify Language when rescaling with baseline"
200 | 
201 | if model_type is None:
202 | lang = lang.lower()
203 | model_type = lang2model[lang]
204 | if num_layers is None:
205 | num_layers = model2layers[model_type]
206 | 
207 | if model_type.startswith("scibert"):
208 | tokenizer = AutoTokenizer.from_pretrained(cache_scibert(model_type), cache_dir=CACHE_DIR)
209 | else:
210 | tokenizer = AutoTokenizer.from_pretrained(model_type, cache_dir=CACHE_DIR)
211 | model = get_model(model_type, num_layers)
212 | device = "cuda" if torch.cuda.is_available() else "cpu"
213 | model.to(device)
214 | 
215 | idf_dict = defaultdict(lambda: 1.0)
216 | # set idf for [SEP] and [CLS] to 0
217 | idf_dict[tokenizer.sep_token_id] = 0
218 | idf_dict[tokenizer.cls_token_id] = 0
219 | 
220 | hyp_embedding, masks, padded_idf = get_bert_embedding(
221 | [candidate], model, tokenizer, idf_dict, device=device, all_layers=False
222 | )
223 | ref_embedding, masks, padded_idf = get_bert_embedding(
224 | [reference], model, tokenizer, idf_dict, device=device, all_layers=False
225 | )
226 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
227 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
228 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
229 | sim = sim.squeeze(0).cpu()
230 | 
231 | # remove [CLS] and [SEP] tokens
232 | r_tokens = [tokenizer.decode([i]) for i in sent_encode(tokenizer, reference)][1:-1]
233 | h_tokens = [tokenizer.decode([i]) for i in sent_encode(tokenizer, candidate)][1:-1]
234 | sim = sim[1:-1, 1:-1]
235 | 
236 | if rescale_with_baseline:
237 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{lang}/{model_type}.tsv")
238 | if os.path.isfile(baseline_path):
239 | baselines = torch.from_numpy(pd.read_csv(baseline_path).iloc[num_layers].to_numpy())[1:].float()
240 | sim = (sim - baselines[2].item()) / (1 - baselines[2].item())
241 | else:
242 | print(f"Warning: Baseline not Found for {model_type} on {lang} at {baseline_path}", file=sys.stderr)
243 | 
244 | fig, ax = plt.subplots(figsize=(len(r_tokens), len(h_tokens)))
245 | im = ax.imshow(sim, cmap="Blues", vmin=0, vmax=1)
246 | 
247 | # We want to show all ticks...
248 | ax.set_xticks(np.arange(len(r_tokens)))
249 | ax.set_yticks(np.arange(len(h_tokens)))
250 | # ... and label them with the respective list entries
251 | ax.set_xticklabels(r_tokens, fontsize=10)
252 | ax.set_yticklabels(h_tokens, fontsize=10)
253 | ax.grid(False)
254 | plt.xlabel("Reference (tokenized)", fontsize=14)
255 | plt.ylabel("Candidate (tokenized)", fontsize=14)
256 | title = "Similarity Matrix"
257 | if rescale_with_baseline:
258 | title += " (after Rescaling)"
259 | plt.title(title, fontsize=14)
260 | 
261 | divider = make_axes_locatable(ax)
262 | cax = divider.append_axes("right", size="2%", pad=0.2)
263 | fig.colorbar(im, cax=cax)
264 | 
265 | # Rotate the tick labels and set their alignment.
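# (With ha="right" and rotation_mode="anchor", each label pivots around its
# right-hand end, so it stays visually attached to its tick after the 45-degree turn.)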
266 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
267 | 
268 | # Loop over data dimensions and create text annotations.
269 | for i in range(len(h_tokens)):
270 | for j in range(len(r_tokens)):
271 | text = ax.text(
272 | j,
273 | i,
274 | "{:.3f}".format(sim[i, j].item()),
275 | ha="center",
276 | va="center",
277 | color="k" if sim[i, j].item() < 0.5 else "w",
278 | )
279 | 
280 | fig.tight_layout()
281 | if fname != "":
282 | plt.savefig(fname, dpi=100)
283 | print("Saved figure to file: ", fname)
284 | plt.show()
285 | --------------------------------------------------------------------------------
/bert_score/scorer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import pathlib
5 | import torch
6 | import matplotlib.pyplot as plt
7 | from mpl_toolkits.axes_grid1 import make_axes_locatable
8 | import numpy as np
9 | import pandas as pd
10 | import warnings
11 | 
12 | from collections import defaultdict
13 | from transformers import AutoTokenizer
14 | 
15 | from .utils import (
16 | get_model,
17 | get_idf_dict,
18 | bert_cos_score_idf,
19 | get_bert_embedding,
20 | lang2model,
21 | model2layers,
22 | get_hash,
23 | cache_scibert,
24 | sent_encode,
25 | )
26 | CACHE_DIR = "train/cache"
27 | 
28 | 
29 | class BERTScorer:
30 | """
31 | BERTScore Scorer Object.
32 | """
33 | 
34 | def __init__(
35 | self,
36 | model_type=None,
37 | num_layers=None,
38 | batch_size=64,
39 | nthreads=4,
40 | all_layers=False,
41 | idf=False,
42 | idf_sents=None,
43 | device=None,
44 | lang=None,
45 | rescale_with_baseline=False,
46 | ):
47 | """
48 | Args:
49 | - :param: `model_type` (str): contextual embedding model specification, default using the suggested
50 | model for the target language; has to specify at least one of
51 | `model_type` or `lang`
52 | - :param: `num_layers` (int): the layer of representation to use;
53 | default using the number of layers tuned on WMT16 correlation data
54 | - :param: `all_layers` (bool): if True, compute scores for all layers instead of the tuned layer
55 | - :param: `idf` (bool or dict): use idf weighting, can also be a precomputed idf_dict
56 | - :param: `idf_sents` (list of str): sentences used to compute the idf weights
57 | - :param: `device` (str): the device on which the contextual embedding model will be allocated.
58 | If this argument is None, the model lives on cuda:0 if cuda is available.
59 | - :param: `batch_size` (int): bert score processing batch size
60 | - :param: `nthreads` (int): number of threads
61 | - :param: `lang` (str): language of the sentences; has to specify
62 | at least one of `model_type` or `lang`. `lang` needs to be
63 | specified when `rescale_with_baseline` is True.
64 | - :param: `rescale_with_baseline` (bool): rescale bertscore with
65 | pre-computed baseline
66 | """
67 | 
68 | assert lang is not None or model_type is not None, "Either lang or model_type should be specified"
69 | 
70 | if rescale_with_baseline:
71 | assert lang is not None, "Need to specify Language when rescaling with baseline"
72 | 
73 | if device is None:
74 | self.device = "cuda" if torch.cuda.is_available() else "cpu"
75 | else:
76 | self.device = device
77 | 
78 | self._lang = lang
79 | self._rescale_with_baseline = rescale_with_baseline
80 | self._idf = idf
81 | self.batch_size = batch_size
82 | self.nthreads = nthreads
83 | self.all_layers = all_layers
84 | 
85 | if model_type is None:
86 | lang = lang.lower()
87 | self._model_type = lang2model[lang]
88 | else:
89 | self._model_type = model_type
90 | 
91 | if num_layers is None:
92 | self._num_layers = model2layers[self.model_type]
93 | else:
94 | self._num_layers = num_layers
95 | 
96 | # Building model and tokenizer
97 | 
98 | if self.model_type.startswith("scibert"):
99 | self._tokenizer = AutoTokenizer.from_pretrained(cache_scibert(self.model_type), cache_dir=CACHE_DIR)
100 | else:
101 | self._tokenizer = AutoTokenizer.from_pretrained(self.model_type, cache_dir=CACHE_DIR)
102 | 
103 | self._model = get_model(self.model_type, self.num_layers, self.all_layers)
104 | self._model.to(self.device)
105 | 
106 | self._idf_dict = None
107 | if idf_sents is not None:
108 | self.compute_idf(idf_sents)
109 | 
110 | @property
111 | def lang(self):
112 | return self._lang
113 | 
114 | @property
115 | def idf(self):
116 | return self._idf
117 | 
118 | @property
119 | def model_type(self):
120 | return self._model_type
121 | 
122 | @property
123 | def num_layers(self):
124 | return self._num_layers
125 | 
126 | @property
127 | def rescale_with_baseline(self):
128 | return self._rescale_with_baseline
129 | 
130 | @property
131 | def baseline_vals(self):
132 | baseline_path = os.path.join(os.path.dirname(__file__), f"rescale_baseline/{self.lang}/{self.model_type}.tsv")
133 | if os.path.isfile(baseline_path):
134 | if not self.all_layers:
135 | baseline_vals = torch.from_numpy(pd.read_csv(baseline_path).iloc[self.num_layers].to_numpy())[
136 | 1:
137 | ].float()
138 | else:
139 | baseline_vals = torch.from_numpy(pd.read_csv(baseline_path).to_numpy())[:, 1:].unsqueeze(1).float()
140 | else:
141 | raise ValueError(f"Baseline not Found for {self.model_type} on {self.lang} at {baseline_path}")
142 | 
143 | return baseline_vals
144 | 
145 | @property
146 | def hash(self):
147 | return get_hash(self.model_type, self.num_layers, self.idf, self.rescale_with_baseline)
148 | 
149 | def compute_idf(self, sents):
150 | """
151 | Args:
152 | - :param: `sents` (list of str): sentences used to compute the idf weights
153 | """
154 | if self._idf_dict is not None:
155 | warnings.warn("Overwriting the previous importance weights.")
156 | 
157 | self._idf_dict = get_idf_dict(sents, self._tokenizer, nthreads=self.nthreads)
158 | 
159 | def score(self, cands, refs, verbose=False, batch_size=64, return_hash=False):
160 | """
161 | Args:
162 | - :param: `cands` (list of str): candidate sentences
163 | - :param: `refs` (list of str or list of list of str): reference sentences
164 | 
165 | Return:
166 | - :param: `(P, R, F)`: each is of shape (N); N = number of input
167 | candidate reference pairs. If returning a hashcode, the
168 | output will be ((P, R, F), hashcode). If a candidate has
169 | multiple references, the returned score of this candidate is
170 | the *best* score among all references.
171 | """
172 | 
173 | ref_group_boundaries = None
174 | if not isinstance(refs[0], str):
175 | ref_group_boundaries = []
176 | ori_cands, ori_refs = cands, refs
177 | cands, refs = [], []
178 | count = 0
179 | for cand, ref_group in zip(ori_cands, ori_refs):
180 | cands += [cand] * len(ref_group)
181 | refs += ref_group
182 | ref_group_boundaries.append((count, count + len(ref_group)))
183 | count += len(ref_group)
184 | 
185 | if verbose:
186 | print("calculating scores...")
187 | start = time.perf_counter()
188 | 
189 | if self.idf:
190 | assert self._idf_dict, "IDF weights are not computed"
191 | idf_dict = self._idf_dict
192 | else:
193 | idf_dict = defaultdict(lambda: 1.0)
194 | idf_dict[self._tokenizer.sep_token_id] = 0
195 | idf_dict[self._tokenizer.cls_token_id] = 0
196 | 
197 | all_preds = bert_cos_score_idf(
198 | self._model,
199 | refs,
200 | cands,
201 | self._tokenizer,
202 | idf_dict,
203 | verbose=verbose,
204 | device=self.device,
205 | batch_size=batch_size,
206 | all_layers=self.all_layers,
207 | ).cpu()
208 | 
209 | if ref_group_boundaries is not None:
210 | max_preds = []
211 | for beg, end in ref_group_boundaries:  # don't shadow the `start` timer above
212 | max_preds.append(all_preds[beg:end].max(dim=0)[0])
213 | all_preds = torch.stack(max_preds, dim=0)
214 | 
215 | if self.rescale_with_baseline:
216 | all_preds = (all_preds - self.baseline_vals) / (1 - self.baseline_vals)
217 | 
218 | out = all_preds[..., 0], all_preds[..., 1], all_preds[..., 2] # P, R, F
219 | 
220 | if verbose:
221 | time_diff = time.perf_counter() - start
222 | print(f"done in {time_diff:.2f} seconds, {len(refs) / time_diff:.2f} sentences/sec")
223 | 
224 | if return_hash:
225 | out = tuple([out, self.hash])
226 | 
227 | return out
228 | 
229 | def plot_example(self, candidate, reference, fname=""):
230 | """
231 | Args:
232 | - :param: `candidate` (str): a candidate sentence
233 | - :param: `reference` (str): a reference sentence
234 | - :param: `fname` (str): path to save the output plot
235 | """
236 | 
237 | assert isinstance(candidate, str)
238 | assert isinstance(reference, str)
239 | 
240 | idf_dict = defaultdict(lambda: 1.0)
241 | idf_dict[self._tokenizer.sep_token_id] = 0
242 | idf_dict[self._tokenizer.cls_token_id] = 0
243 | 
244 | hyp_embedding, masks, padded_idf = get_bert_embedding(
245 | [candidate], self._model, self._tokenizer, idf_dict, device=self.device, all_layers=False
246 | )
247 | ref_embedding, masks, padded_idf = get_bert_embedding(
248 | [reference], self._model, self._tokenizer, idf_dict, device=self.device, all_layers=False
249 | )
250 | ref_embedding.div_(torch.norm(ref_embedding, dim=-1).unsqueeze(-1))
251 | hyp_embedding.div_(torch.norm(hyp_embedding, dim=-1).unsqueeze(-1))
252 | sim = torch.bmm(hyp_embedding, ref_embedding.transpose(1, 2))
253 | sim = sim.squeeze(0).cpu()
254 | 
255 | r_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, reference)][1:-1]
256 | h_tokens = [self._tokenizer.decode([i]) for i in sent_encode(self._tokenizer, candidate)][1:-1]
257 | sim = sim[1:-1, 1:-1]
258 | 
259 | if self.rescale_with_baseline:
260 | sim = (sim - self.baseline_vals[2].item()) / (1 - self.baseline_vals[2].item())
261 | 
262 | fig, ax = plt.subplots(figsize=(len(r_tokens), len(h_tokens)))
263 | im = ax.imshow(sim, cmap="Blues", vmin=0, vmax=1)
264 | 
265 | # We want to show all ticks...
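# (one tick per token, so every row and column of the token-by-token similarity matrix gets a label)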
266 | ax.set_xticks(np.arange(len(r_tokens)))
267 | ax.set_yticks(np.arange(len(h_tokens)))
268 | # ... and label them with the respective list entries
269 | ax.set_xticklabels(r_tokens, fontsize=10)
270 | ax.set_yticklabels(h_tokens, fontsize=10)
271 | ax.grid(False)
272 | plt.xlabel("Reference (tokenized)", fontsize=14)
273 | plt.ylabel("Candidate (tokenized)", fontsize=14)
274 | title = "Similarity Matrix"
275 | if self.rescale_with_baseline:
276 | title += " (after Rescaling)"
277 | plt.title(title, fontsize=14)
278 | 
279 | divider = make_axes_locatable(ax)
280 | cax = divider.append_axes("right", size="2%", pad=0.2)
281 | fig.colorbar(im, cax=cax)
282 | 
283 | # Rotate the tick labels and set their alignment.
284 | plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
285 | 
286 | # Loop over data dimensions and create text annotations.
287 | for i in range(len(h_tokens)):
288 | for j in range(len(r_tokens)):
289 | text = ax.text(
290 | j,
291 | i,
292 | "{:.3f}".format(sim[i, j].item()),
293 | ha="center",
294 | va="center",
295 | color="k" if sim[i, j].item() < 0.5 else "w",
296 | )
297 | 
298 | fig.tight_layout()
299 | if fname != "":
300 | plt.savefig(fname, dpi=100)
301 | print("Saved figure to file: ", fname)
302 | plt.show()
303 | 
304 | def __repr__(self):
305 | return f"{self.__class__.__name__}(hash={self.hash}, batch_size={self.batch_size}, nthreads={self.nthreads})"
306 | 
307 | def __str__(self):
308 | return self.__repr__()
309 | --------------------------------------------------------------------------------
/rouge/rouge_scorer_test.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | # Lint as: python2, python3
17 | """Tests for rouge scorer.
18 | 
19 | Tests for both correctness and for consistency with the official ROUGE-1.5.5
20 | implementation.
21 | 
22 | "Ground truth" scores are taken from manual runs of ROUGE-1.5.5.
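For orientation, a minimal usage sketch of the scorer under test (the inputs
and expected values mirror testRouge1 below):

  scorer = rouge_scorer.RougeScorer(["rouge1"])
  result = scorer.score("testing one two", "testing")
  # result["rouge1"]: precision=1.0, recall=1/3, fmeasure=0.5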
23 | """ 24 | 25 | from __future__ import absolute_import 26 | from __future__ import division 27 | from __future__ import print_function 28 | 29 | import os 30 | 31 | from absl.testing import absltest 32 | from absl.testing import parameterized 33 | from rouge import rouge_scorer 34 | from rouge import test_util 35 | 36 | 37 | class RougeScorerTest(parameterized.TestCase): 38 | 39 | def setUp(self): 40 | super(RougeScorerTest, self).setUp() 41 | with open(test_util.TARGETS_FILE) as f: 42 | self.targets = f.readlines() 43 | with open(test_util.PREDICTIONS_FILE) as f: 44 | self.predictions = f.readlines() 45 | 46 | @parameterized.parameters(["rougen", "rouge0", "rouge10"]) 47 | def testInvalidRougeTypes(self, rouge_type): 48 | with self.assertRaises(ValueError): 49 | scorer = rouge_scorer.RougeScorer([rouge_type]) 50 | scorer.score("testing one two", "testing") 51 | 52 | @parameterized.parameters(["rouge1", "rouge9", "rougeL", "rougeLsum"]) 53 | def testValidRougeTypes(self, rouge_type): 54 | scorer = rouge_scorer.RougeScorer([rouge_type]) 55 | result = scorer.score("testing one two", "testing") 56 | self.assertSameElements(list(result.keys()), [rouge_type]) 57 | 58 | def testRouge1(self): 59 | scorer = rouge_scorer.RougeScorer(["rouge1"]) 60 | result = scorer.score("testing one two", "testing") 61 | self.assertAlmostEqual(1, result["rouge1"].precision) 62 | self.assertAlmostEqual(1 / 3, result["rouge1"].recall) 63 | self.assertAlmostEqual(1 / 2, result["rouge1"].fmeasure) 64 | 65 | @parameterized.parameters(["rouge1", "rouge2", "rougeL", "rougeLsum"]) 66 | def testRougeEmpty(self, rouge_type): 67 | scorer = rouge_scorer.RougeScorer([rouge_type]) 68 | result = scorer.score("testing one two", "") 69 | self.assertAlmostEqual(0, result[rouge_type].precision) 70 | self.assertAlmostEqual(0, result[rouge_type].recall) 71 | self.assertAlmostEqual(0, result[rouge_type].fmeasure) 72 | 73 | def testRouge2(self): 74 | scorer = rouge_scorer.RougeScorer(["rouge2"]) 75 | result = scorer.score("testing one two", "testing one") 76 | self.assertAlmostEqual(1, result["rouge2"].precision) 77 | self.assertAlmostEqual(1 / 2, result["rouge2"].recall) 78 | self.assertAlmostEqual(2 / 3, result["rouge2"].fmeasure) 79 | 80 | def testRougeLConsecutive(self): 81 | scorer = rouge_scorer.RougeScorer(["rougeL"]) 82 | result = scorer.score("testing one two", "testing one") 83 | self.assertAlmostEqual(1, result["rougeL"].precision) 84 | self.assertAlmostEqual(2 / 3, result["rougeL"].recall) 85 | self.assertAlmostEqual(4 / 5, result["rougeL"].fmeasure) 86 | 87 | def testRougeLNonConsecutive(self): 88 | scorer = rouge_scorer.RougeScorer(["rougeL"]) 89 | result = scorer.score("testing one two", "testing two") 90 | self.assertAlmostEqual(1, result["rougeL"].precision) 91 | self.assertAlmostEqual(2 / 3, result["rougeL"].recall) 92 | self.assertAlmostEqual(4 / 5, result["rougeL"].fmeasure) 93 | 94 | def testMultipleRougeTypes(self): 95 | scorer = rouge_scorer.RougeScorer(["rouge1", "rougeL"]) 96 | result = scorer.score("testing one two", "testing one") 97 | self.assertSameElements(list(result.keys()), ["rouge1", "rougeL"]) 98 | self.assertAlmostEqual(1, result["rouge1"].precision) 99 | self.assertAlmostEqual(2 / 3, result["rouge1"].recall) 100 | self.assertAlmostEqual(4 / 5, result["rouge1"].fmeasure) 101 | self.assertAlmostEqual(1, result["rougeL"].precision) 102 | self.assertAlmostEqual(2 / 3, result["rougeL"].recall) 103 | self.assertAlmostEqual(4 / 5, result["rougeL"].fmeasure) 104 | 105 | def 
testRouge1AgainstRouge155(self): 106 | scorer = rouge_scorer.RougeScorer(["rouge1"]) 107 | result = scorer.score(self.targets[0], self.predictions[0]) 108 | self.assertAlmostEqual(0.40741, result["rouge1"].recall, 5) 109 | self.assertAlmostEqual(0.68750, result["rouge1"].precision, 5) 110 | self.assertAlmostEqual(0.51163, result["rouge1"].fmeasure, 5) 111 | result = scorer.score(self.targets[1], self.predictions[1]) 112 | self.assertAlmostEqual(0.40476, result["rouge1"].recall, 5) 113 | self.assertAlmostEqual(0.65385, result["rouge1"].precision, 5) 114 | self.assertAlmostEqual(0.50000, result["rouge1"].fmeasure, 5) 115 | 116 | def testRouge1AgainstRouge155WithStemming(self): 117 | scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True) 118 | result = scorer.score(self.targets[0], self.predictions[0]) 119 | self.assertAlmostEqual(0.40741, result["rouge1"].recall, 5) 120 | self.assertAlmostEqual(0.68750, result["rouge1"].precision, 5) 121 | self.assertAlmostEqual(0.51163, result["rouge1"].fmeasure, 5) 122 | result = scorer.score(self.targets[1], self.predictions[1]) 123 | self.assertAlmostEqual(0.42857, result["rouge1"].recall, 5) 124 | self.assertAlmostEqual(0.69231, result["rouge1"].precision, 5) 125 | self.assertAlmostEqual(0.52941, result["rouge1"].fmeasure, 5) 126 | 127 | def testRouge2AgainstRouge155(self): 128 | scorer = rouge_scorer.RougeScorer(["rouge2"]) 129 | result = scorer.score(self.targets[0], self.predictions[0]) 130 | self.assertAlmostEqual(0.30769, result["rouge2"].recall, 5) 131 | self.assertAlmostEqual(0.53333, result["rouge2"].precision, 5) 132 | self.assertAlmostEqual(0.39024, result["rouge2"].fmeasure, 5) 133 | result = scorer.score(self.targets[1], self.predictions[1]) 134 | self.assertAlmostEqual(0.29268, result["rouge2"].recall, 5) 135 | self.assertAlmostEqual(0.48000, result["rouge2"].precision, 5) 136 | self.assertAlmostEqual(0.36364, result["rouge2"].fmeasure, 5) 137 | 138 | def testRouge2AgainstRouge155WithStemming(self): 139 | scorer = rouge_scorer.RougeScorer(["rouge2"], use_stemmer=True) 140 | result = scorer.score(self.targets[0], self.predictions[0]) 141 | self.assertAlmostEqual(0.30769, result["rouge2"].recall, 5) 142 | self.assertAlmostEqual(0.53333, result["rouge2"].precision, 5) 143 | self.assertAlmostEqual(0.39024, result["rouge2"].fmeasure, 5) 144 | result = scorer.score(self.targets[1], self.predictions[1]) 145 | self.assertAlmostEqual(0.29268, result["rouge2"].recall, 5) 146 | self.assertAlmostEqual(0.48000, result["rouge2"].precision, 5) 147 | self.assertAlmostEqual(0.36364, result["rouge2"].fmeasure, 5) 148 | 149 | def testRougeLAgainstRouge155(self): 150 | scorer = rouge_scorer.RougeScorer(["rougeL"]) 151 | result = scorer.score(self.targets[0], self.predictions[0]) 152 | self.assertAlmostEqual(0.40741, result["rougeL"].recall, 5) 153 | self.assertAlmostEqual(0.68750, result["rougeL"].precision, 5) 154 | self.assertAlmostEqual(0.51163, result["rougeL"].fmeasure, 5) 155 | result = scorer.score(self.targets[1], self.predictions[1]) 156 | self.assertAlmostEqual(0.40476, result["rougeL"].recall, 5) 157 | self.assertAlmostEqual(0.65385, result["rougeL"].precision, 5) 158 | self.assertAlmostEqual(0.50000, result["rougeL"].fmeasure, 5) 159 | 160 | def testRougeLSumAgainstRouge155WithStemming(self): 161 | scorer = rouge_scorer.RougeScorer(["rougeLsum"], use_stemmer=True) 162 | 163 | target = test_util.get_text( 164 | os.path.join(test_util.PYROUGE_DIR, "target_multi.0.txt")) 165 | prediction = test_util.get_text( 166 | 
os.path.join(test_util.PYROUGE_DIR, "prediction_multi.0.txt")) 167 | result = scorer.score(target, prediction) 168 | 169 | self.assertAlmostEqual(0.36538, result["rougeLsum"].recall, places=5) 170 | self.assertAlmostEqual(0.66667, result["rougeLsum"].precision, places=5) 171 | self.assertAlmostEqual(0.47205, result["rougeLsum"].fmeasure, places=5) 172 | 173 | def testLcsTable(self): 174 | ref = [1, 2, 3, 4, 5] 175 | c1 = [2, 5, 3, 4] 176 | t = rouge_scorer._lcs_table(ref, c1) 177 | self.assertEqual(3, t[len(ref)][len(c1)]) 178 | def _read_lcs(t, ref, can): 179 | return rouge_scorer._backtrack_norec(t, ref, can) 180 | # Indices 181 | self.assertEqual([1, 2, 3], 182 | _read_lcs(t, ref, c1)) 183 | # Values 184 | self.assertEqual([2, 3, 4], 185 | [ref[i] for i in _read_lcs(t, ref, c1)]) 186 | 187 | # No common subsequence. 188 | c2 = [8, 9] 189 | t = rouge_scorer._lcs_table(ref, c2) 190 | self.assertEqual(0, t[len(ref)][len(c2)]) 191 | self.assertEqual([], 192 | _read_lcs(t, ref, c2)) 193 | 194 | def testUnionLcs(self): 195 | # Example in Section 3.2 of https://www.aclweb.org/anthology/W04-1013, 196 | # except using indices into ref. 197 | 198 | # First test helper. 199 | lcs1 = [0, 1] # lcs [1, 2] 200 | lcs2 = [0, 2, 4] 201 | self.assertEqual([0, 1, 2, 4], rouge_scorer._find_union([lcs1, lcs2])) 202 | self.assertEqual([0, 1, 2, 4], rouge_scorer._find_union([lcs2, lcs1])) 203 | 204 | ref = [1, 2, 3, 4, 5] 205 | c1 = [1, 2, 6, 7, 8] # lcs = [1, 2] 206 | c2 = [1, 3, 8, 9, 5] # lcs = [1, 3, 5] 207 | self.assertEqual([1, 2, 3, 5], 208 | rouge_scorer._union_lcs(ref, [c1, c2])) 209 | self.assertEqual([1, 2, 3, 5], 210 | rouge_scorer._union_lcs(ref, [c1, c2])) 211 | 212 | def testSummaryLevelLcs(self): 213 | refs = [ 214 | [1, 2, 3, 4, 5] 215 | ] 216 | cans = [ 217 | [1, 2, 6, 7, 8], # lcs = [1, 2] 218 | [1, 3, 8, 9, 5] # lcs = [1, 3, 5] 219 | ] 220 | score = rouge_scorer._summary_level_lcs(refs, cans) 221 | self.assertEqual(0.8, score.recall) # 4 / 5 222 | self.assertEqual(0.4, score.precision) # 4 / 10 223 | # 0.4*0.8 / (0.4 + 0.8) 224 | self.assertAlmostEqual(0.5333, score.fmeasure, places=3) 225 | 226 | # Tokenizer may drop all tokens, resulting in empty candidate list. 227 | score = rouge_scorer._summary_level_lcs([["reference"]], [[]]) 228 | self.assertEqual(0.0, score.recall) 229 | 230 | def testRougeLsum(self): 231 | scorer = rouge_scorer.RougeScorer(["rougeLsum"]) 232 | result = scorer.score("w1 w2 w3 w4 w5", "w1 w2 w6 w7 w8\nw1 w3 w8 w9 w5") 233 | self.assertAlmostEqual(0.8, result["rougeLsum"].recall) 234 | self.assertAlmostEqual(0.4, result["rougeLsum"].precision) 235 | self.assertAlmostEqual(0.5333, result["rougeLsum"].fmeasure, places=3) 236 | 237 | # Empty case 238 | result = scorer.score("w1 w2 w3 w4 w5", "") 239 | self.assertAlmostEqual(0.0, result["rougeLsum"].fmeasure, places=3) 240 | self.assertAlmostEqual(0.0, result["rougeLsum"].recall, places=3) 241 | self.assertAlmostEqual(0.0, result["rougeLsum"].precision, places=3) 242 | 243 | result = scorer.score("", "w1") 244 | self.assertAlmostEqual(0.0, result["rougeLsum"].fmeasure, places=3) 245 | self.assertAlmostEqual(0.0, result["rougeLsum"].recall, places=3) 246 | self.assertAlmostEqual(0.0, result["rougeLsum"].precision, places=3) 247 | 248 | # Case in which summary is all non-word characters. 
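# (The default tokenizer keeps only alphanumeric characters, so a prediction
# like "/" tokenizes to nothing and every score collapses to zero.)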
249 | result = scorer.score("w1 w2 w3 w4 w5", "/") 250 | self.assertAlmostEqual(0.0, result["rougeLsum"].fmeasure, places=3) 251 | self.assertAlmostEqual(0.0, result["rougeLsum"].recall, places=3) 252 | self.assertAlmostEqual(0.0, result["rougeLsum"].precision, places=3) 253 | 254 | def testRougeLsumLarge(self): 255 | with open(test_util.LARGE_PREDICTIONS_FILE) as f: 256 | prediction = f.read() 257 | with open(test_util.LARGE_TARGETS_FILE) as f: 258 | target = f.read() 259 | scorer = rouge_scorer.RougeScorer(["rougeLsum"]) 260 | result = scorer.score(target, prediction) 261 | self.assertAlmostEqual(0.533, result["rougeLsum"].fmeasure, places=3) 262 | 263 | 264 | if __name__ == "__main__": 265 | absltest.main() 266 | -------------------------------------------------------------------------------- /W3C/anonymize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from tqdm import tqdm 4 | import ujson as json 5 | from nltk import word_tokenize 6 | import nltk 7 | from nltk.corpus import wordnet 8 | import spacy 9 | nlp = spacy.load('en_core_web_sm') 10 | 11 | 12 | def remove_char(text): 13 | removelist = ' ' 14 | text = re.sub(r'[^\w' + removelist + ']', '', text.lower()) 15 | return text 16 | 17 | 18 | def remove_email(text): 19 | subtext = text.split(' ') 20 | sts = [] 21 | for i, st in enumerate(subtext): 22 | st = st.strip() 23 | st = re.sub(r"([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)", f'USERNAME@DOMAIN.COM', st, flags=re.MULTILINE) 24 | sts.append(st) 25 | return ' '.join(sts) 26 | 27 | 28 | def remove_url(text): 29 | subtext = text.split(' ') 30 | sts = [] 31 | for st in subtext: 32 | st = re.sub(r'https?:\/\/.*[\r\n]*', 'HTTP://LINK', st, flags=re.MULTILINE) 33 | st = re.sub(r'www.*[\r\n]*', 'HTTP://LINK', st, flags=re.MULTILINE) 34 | sts.append(st) 35 | return ' '.join(sts) 36 | 37 | 38 | def remove_ip_address(text): 39 | subtext = text.split(' ') 40 | sts = [] 41 | for st in subtext: 42 | st = re.sub(r'(?:[0-9]{1,3}\.){3}[0-9]{1,3}', 'IPADDRESS', st, flags=re.MULTILINE) 43 | st = re.sub(r'(?:[0-9]{1,3}\.){3}X', 'IPADDRESS', st, flags=re.MULTILINE) 44 | st = re.sub(r'([0-9A-Fa-f]{2}[:-]){5}([0-9A-Fa-f]{2})|([0-9a-fA-F]{4}\\.[0-9a-fA-F]{4}\\.[0-9a-fA-F]{4})', 'MACADDRESS', st, flags=re.MULTILINE) 45 | sts.append(st) 46 | return ' '.join(sts) 47 | 48 | 49 | def remove_phone(text): 50 | st = re.sub(r"([2-9]\d{2}-\d{3}-\d{4})", 'PHONENUMBER', text, flags=re.MULTILINE) 51 | st = re.sub(r"([2-9]\d{2}-\d{3}-\d{3})", 'PHONENUMBER', st, flags=re.MULTILINE) 52 | st = re.sub(r"([2-9]\d{2}- \d{3}-\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 53 | st = re.sub(r"([2-9]\d{2} \d{3} \d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 54 | st = re.sub(r"([2-9]\d{2} \d{3}-\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 55 | st = re.sub(r"(\([2-9]\d{2}-\d{3}-\d{4}\))", 'PHONENUMBER', st, flags=re.MULTILINE) 56 | st = re.sub(r"([2-9]\d{2}-\d{3}-\d{4},)", 'PHONENUMBER', st, flags=re.MULTILINE) 57 | st = re.sub(r"([2-9]\d{2}-\d{3}-\d{4}.)", 'PHONENUMBER', st, flags=re.MULTILINE) 58 | st = re.sub(r"([2-9]\d{2}/\d{3}-\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 59 | st = re.sub(r"([2-9]\d{2}\.\d{3}\.\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 60 | st = re.sub(r"(\([2-9]\d{2}\) \d{3}-\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 61 | st = re.sub(r"(\([2-9]\d{2}\)\d{3}-\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 62 | st = re.sub(r"(\([2-9]\d{2}\) \d{3} \d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 63 | st = 
re.sub(r"(\([2-9]\d{2}\) \d{3} \d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 64 | st = re.sub(r"(\d{2} \d{5} \d{5})", 'PHONENUMBER', st, flags=re.MULTILINE) 65 | st = re.sub(r"(\d{2}-\d{3}-\d{7})", 'PHONENUMBER', st, flags=re.MULTILINE) 66 | st = re.sub(r"(\d{2}\.\d{3}\.\d{7})", 'PHONENUMBER', st, flags=re.MULTILINE) 67 | st = re.sub(r"(\d{2}-\d{3}-\d{6})", 'PHONENUMBER', st, flags=re.MULTILINE) 68 | st = re.sub(r"(\d{2}-\d{1}-\d{4}-\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 69 | st = re.sub(r"(\d{2}-\d{1}-\d{8})", 'PHONENUMBER', st, flags=re.MULTILINE) 70 | st = re.sub(r"(\d{3}-\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 71 | st = re.sub(r"(\d{5} \d{5})", 'PHONENUMBER', st, flags=re.MULTILINE) 72 | st = re.sub(r"(\d{5} \d{6})", 'PHONENUMBER', st, flags=re.MULTILINE) 73 | st = re.sub(r"(\d{3}\.\d{4})", 'PHONENUMBER', st, flags=re.MULTILINE) 74 | for number in ["4085628026", "4084648998", "90-5334253539", "4088294796"]: 75 | if number in st: 76 | st = st.replace(number, 'PHONENUMBER') 77 | return st 78 | 79 | 80 | def remove_local_path(text, path_dict): 81 | subtext = text.split(' ') 82 | sts = [] 83 | for st in subtext: 84 | st = st.strip() 85 | if st.startswith("CONNECTION="): 86 | st = 'CONNECTION="DBPATH"' 87 | elif st.startswith("DB_URL="): 88 | st = 'DB_URL="DBPATH"' 89 | elif st in path_dict: 90 | st = path_dict[st] 91 | sts.append(st) 92 | return ' '.join(sts) 93 | 94 | 95 | def remove_last_names(text): 96 | tokens = nltk.tokenize.word_tokenize(text) 97 | pos = nltk.pos_tag(tokens) 98 | sentt = nltk.ne_chunk(pos, binary=False) 99 | person_list = [] 100 | person = [] 101 | name = "" 102 | for subtree in sentt.subtrees(filter=lambda t: t.label() == 'PERSON'): 103 | for leaf in subtree.leaves(): 104 | person.append(leaf[0]) 105 | if len(person) > 1: # avoid grabbing lone surnames 106 | for part in person: 107 | name += part + ' ' 108 | if name[:-1] not in person_list: 109 | person_list.append(name[:-1]) 110 | name = '' 111 | person = [] 112 | names = person_list.copy() 113 | for person in person_list: 114 | person_split = person.split(" ") 115 | if len(person_split) == 2: 116 | if "Nunn" in person_split or "Chan" in person_split or "Richard" in person_split: 117 | continue 118 | for name in person_split: 119 | if wordnet.synsets(name): 120 | if (name in person): 121 | names.remove(person) 122 | break 123 | for name in names: 124 | words = name.split(' ') 125 | text = text.replace(name, words[0].strip() + "'s" if "'s" in words[-1] else words[0].strip()) 126 | return text 127 | 128 | 129 | def remove_last_name_in_context(text, name_dict): 130 | error_words = ["To", "Update", "Hello", "Hi", "User", "ASAP", "Enterprise", "Screen", "DB", "Chart", 131 | "hai", "Hey", "Box", "Bug", "Time", "Greetings", "Bus", "R&D", "HP", "Doc", 132 | "Hai", "Apps", "Software", "and", "7th", "IE", "Schedule", "delete", "Dev", 133 | "Worksheet", "Name", "Girl", "Appliance", "BEFORE", "Encyclopedic", "Topics", "Configure", 134 | "Release", "office", "Alerts", "Installer", 135 | "Admin", "BBQ", "Mtg", "TBD", "Owner", "data", "app", "Interface", "Auto", 136 | "Adpater", "Assign", "Boxer", "Ask", "Objectives", "Discuss", "logins", "Code", "Alternative", 137 | "xml", "Comments"] 138 | doc = nlp(text) 139 | ents = doc.ents 140 | for ent in ents: 141 | if ent.label_ != 'PERSON': 142 | continue 143 | ent_text = ent.text.strip() 144 | ent_words = [word.strip() for word in ent_text.split(' ') if word.strip()] 145 | if len(ent_words) == 1: 146 | continue 147 | if set(ent_words) & set(error_words): 148 | 
continue 149 | if ent_text in name_dict: 150 | text = text.replace(ent_text, name_dict[ent_text]) 151 | return text 152 | 153 | 154 | def filter(): 155 | def _name_dict(): 156 | name_dict = {} 157 | with open(f"../Avocado/anonymize_files/person_names.txt", 'r') as f: 158 | lines = f.readlines() 159 | for line in lines: 160 | items = line.split(',') 161 | if len(items) > 1: 162 | if items[0].strip() != items[1].strip(): 163 | name_dict[items[0].strip()] = items[1].strip() 164 | return name_dict 165 | 166 | def _subject_filter(subject): 167 | model = re.compile(r'[a-z]') 168 | if subject[0] == "#" or "no subject" in subject or subject[0] == "$" or "***********" in subject \ 169 | or "re:" in subject or "autoreply" in subject or "confidential" in subject: 170 | return True 171 | if not model.search(subject): 172 | return True 173 | return False 174 | 175 | def _thread_filter(thread): 176 | end_words = ['cheers', 'regards', 'thanks', 'best', 'kind regards', 'best regards', 177 | "thanks and looking forward", "super thanks"] 178 | contents = [] 179 | total_words = [] 180 | froms = [] 181 | if not thread[0]['subject'] or "re:" in thread[0]['subject'].lower() or \ 182 | "fw:" in thread[0]['subject'].lower() or "fwd:" in thread[0]['subject'].lower() or \ 183 | "fw:" in thread[1]['subject'].lower() or "fwd:" in thread[1]['subject'].lower(): 184 | return [] 185 | for email in thread: 186 | next_signature = False 187 | if not email['name'].strip().split(' ')[0].split('@')[0].split('.')[0]: 188 | return [] 189 | names = email['name'].strip().lower().split() 190 | froms.append(f"{email['name'].strip().lower()}") 191 | full_names = [] 192 | for i in range(1, len(names) + 1): 193 | for j in range(len(names) - i + 1): 194 | full_names.append(' '.join(names[j: j + i])) 195 | lines = email["content"] 196 | new_lines = [] 197 | for i, line in enumerate(lines): 198 | line = line.strip() 199 | if line.startswith('>') or line.startswith("Message-ID:") or line.startswith("From:") \ 200 | or line.startswith('In-reply-to:') or "wrote:" in line or line.startswith("Subject:") \ 201 | or line.startswith('Date:'): # special for W3C 202 | continue 203 | line = line.replace("-->", ' ').strip() 204 | clean_line = remove_char(line) 205 | if not line: 206 | continue 207 | if "*****************" in line or "-------------------" in line or '_______________' in line or \ 208 | '~~~~~~~~~~~~~~~~~' in line or "============" in line or \ 209 | "-----Original Appointment-----" in line: 210 | break 211 | if line == "---" or line == "--" or line == "Manager" or line == "R&D": 212 | break 213 | # signature 214 | if clean_line in full_names: 215 | new_lines.append(line.split(' ')[0].strip()) 216 | break 217 | tokens = clean_line.split(' ') 218 | if len(tokens) == 3 and (tokens[0] in full_names or tokens[1] in full_names or tokens[2] in full_names): 219 | new_lines.append(line.split(' ')[0].strip()) 220 | break 221 | if line.startswith('Address:'): 222 | if line.split(':')[1].strip(): 223 | line = 'Address: ADDRESS' 224 | elif line.startswith("Contract #:"): 225 | if line.split(':')[1].strip(): 226 | line = 'Contract #: CONTRACTNUMBER' 227 | line = remove_email(line) 228 | line = remove_url(line) 229 | line = remove_ip_address(line) 230 | line = remove_phone(line) 231 | line = remove_last_names(line) 232 | line = remove_last_name_in_context(line, name_dict) 233 | line = remove_local_path(line, {}) 234 | if next_signature: 235 | new_lines.append(line.split(' ')[0]) 236 | break 237 | else: 238 | new_lines.append(line) 239 | if 
clean_line in end_words: 240 | next_signature = True 241 | new_content = '\n'.join(new_lines) 242 | if "password" in new_content or "passcode" in new_content or "pw:" in new_content or \ 243 | "Password" in new_content or "Passcode" in new_content or "PW:" in new_content or \ 244 | "passwd" in new_content or "Pass code" in new_content or "pass code" in new_content: 245 | return [] 246 | words = word_tokenize(new_content) 247 | total_words.extend(words) 248 | if len(words) > 200 or len(words) < 5: 249 | return [] 250 | contents.append(new_content) 251 | if len(set(contents)) == 1: 252 | return [] 253 | if len(set(froms)) == 1: 254 | return [] 255 | if len(total_words) > 1000 or len(total_words) < 30: 256 | return [] 257 | return contents 258 | 259 | def _distinguish_names(thread): 260 | names = [] 261 | first_names = {} 262 | for email in thread: 263 | from_name = email["name"].strip() 264 | to_names = [person.split('<')[0].strip() for person in email["to"]] 265 | for name in [from_name] + to_names: 266 | if name not in names: 267 | names.append(name) 268 | first_name = name.split(' ')[0].split('@')[0].split('.')[0] 269 | if first_name not in first_names: 270 | first_names[first_name] = 0 271 | else: 272 | first_names[first_name] += 1 273 | count = {} 274 | name_map = {} 275 | for name in names: 276 | first_name = name.split(' ')[0].split('@')[0].split('.')[0] 277 | if first_names[first_name] > 0: 278 | if first_name not in count: 279 | count[first_name] = 0 280 | else: 281 | count[first_name] += 1 282 | name_map[name] = first_name + f'-{count[first_name]}' 283 | else: 284 | name_map[name] = first_name 285 | return name_map 286 | 287 | with open("W3C.json", 'r') as f: 288 | data = json.load(f) 289 | name_dict = _name_dict() 290 | 291 | if not os.path.exists("W3C_threads"): 292 | os.mkdir("W3C_threads") 293 | 294 | data = sorted(data.items(), key=lambda x: x[0]) 295 | for threadno, (subject, thread) in enumerate(tqdm(data)): 296 | if _subject_filter(subject.lower()): 297 | continue 298 | content_lengths = [len(email["content"]) for email in thread] 299 | if len(content_lengths) > 10: 300 | continue 301 | directory = 7 if len(content_lengths) >= 7 else len(content_lengths) 302 | new_contents = _thread_filter(thread) 303 | if not new_contents: 304 | continue 305 | if not os.path.exists(f"W3C_threads/{directory}"): 306 | os.mkdir(f"W3C_threads/{directory}") 307 | if os.path.exists(f"W3C_threads/{directory}/{threadno}"): 308 | # clean existing file 309 | with open(f"W3C_threads/{directory}/{threadno}", 'w') as f: 310 | f.write("") 311 | name_map = _distinguish_names(thread) 312 | with open(f"W3C_threads/{directory}/{threadno}", 'a') as f: 313 | f.write(f"THREAD {threadno}\n\n" 314 | f"subject: {subject}\n" 315 | f"# of emails: {len(content_lengths)}\n\n") 316 | for j, email in enumerate(thread): 317 | form_first_name = name_map[email['name'].strip()] 318 | to_first_names = [name_map[person.split('<')[0].strip()] for person in email['to']] 319 | f.write(f"Email {j}\n" 320 | f"From: {form_first_name}\n" 321 | f"To: {', '.join(to_first_names)}\n\n" 322 | f"Content:\n {new_contents[j]}\n\n" 323 | f"==============================================================\n\n") 324 | 325 | 326 | if __name__ == '__main__': 327 | filter() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import Sampler 5 | from 
torch.optim.lr_scheduler import LambdaLR
6 | from typing import Optional
7 | 
8 | 
9 | class SortishSampler(Sampler):
10 | "Go through the text data by order of src length with a bit of randomness. From fastai repo."
11 | 
12 | def __init__(self, data, batch_size):
13 | self.data, self.bs = data, batch_size
14 | 
15 | def key(self, i):
16 | return len(self.data[i])
17 | 
18 | def __len__(self) -> int:
19 | return len(self.data)
20 | 
21 | def __iter__(self):
22 | idxs = np.random.permutation(len(self.data))
23 | sz = self.bs * 50
24 | ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
25 | sort_idx = np.concatenate([sorted(s, key=self.key, reverse=True) for s in ck_idx])
26 | sz = self.bs
27 | ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
28 | max_ck = np.argmax([self.key(ck[0]) for ck in ck_idx]) # find the chunk with the largest key,
29 | ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0] # then make sure it goes first.
30 | sort_idx = np.concatenate(np.random.permutation(ck_idx[1:])) if len(ck_idx) > 1 else np.array([], dtype=int)
31 | sort_idx = np.concatenate((ck_idx[0], sort_idx))
32 | return iter(sort_idx)
33 | 
34 | 
35 | def get_inverse_sqrt_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
36 | """ Create a schedule with a learning rate that decreases as 1/sqrt(current_step) after
37 | being constant during a warmup period. (`num_training_steps` is accepted but unused.)
38 | """
39 | 
40 | def lr_lambda(current_step):
41 | return max(
42 | 0.0, 1.0 / np.sqrt(float(max(current_step, float(max(1, num_warmup_steps)))))
43 | )
44 | 
45 | return LambdaLR(optimizer, lr_lambda, last_epoch)
46 | 
47 | 
48 | def get_range_vector(size: int, device: int) -> torch.Tensor:
49 | """
50 | Returns a range vector with the desired size, starting at 0. The CUDA implementation
51 | is meant to avoid copying data from CPU to GPU.
52 | """
53 | if device > -1:
54 | return torch.cuda.LongTensor(size, device=device).fill_(1).cumsum(0) - 1
55 | else:
56 | return torch.arange(0, size, dtype=torch.long)
57 | 
58 | 
59 | def get_device_of(tensor: torch.Tensor) -> int:
60 | """
61 | Returns the device of the tensor.
62 | """
63 | if not tensor.is_cuda:
64 | return -1
65 | else:
66 | return tensor.get_device()
67 | 
68 | 
69 | def flatten_and_batch_shift_indices(indices: torch.Tensor, sequence_length: int) -> torch.Tensor:
70 | """
71 | This is a subroutine for [`batched_index_select`](./util.md#batched_index_select).
72 | The given `indices` of size `(batch_size, d_1, ..., d_n)` indexes into dimension 2 of a
73 | target tensor, which has size `(batch_size, sequence_length, embedding_size)`. This
74 | function returns a vector that correctly indexes into the flattened target. The sequence
75 | length of the target must be provided to compute the appropriate offsets.
76 | ```python
77 | indices = torch.ones([2,3], dtype=torch.long)
78 | # Sequence length of the target tensor.
79 | sequence_length = 10
80 | shifted_indices = flatten_and_batch_shift_indices(indices, sequence_length)
81 | # Indices into the second element in the batch are correctly shifted
82 | # to take into account that the target tensor will be flattened before
83 | # the indices are applied.
84 | assert shifted_indices == [1, 1, 1, 11, 11, 11]
85 | ```
86 | # Parameters
87 | indices : `torch.LongTensor`, required.
88 | sequence_length : `int`, required.
89 | The length of the sequence the indices index into.
90 | This must be the second dimension of the tensor.
91 |     # Returns
92 |     offset_indices : `torch.LongTensor`
93 |     """
94 |     if torch.max(indices) >= sequence_length or torch.min(indices) < 0:
95 |         raise ValueError(f"All elements in indices should be in range (0, {sequence_length - 1})")
96 | 
97 |     # Shape: (batch_size)
98 |     offsets = get_range_vector(indices.size(0), get_device_of(indices)) * sequence_length
99 |     for _ in range(len(indices.size()) - 1):
100 |         offsets = offsets.unsqueeze(1)
101 | 
102 |     # Shape: (batch_size, d_1, ..., d_n)
103 |     offset_indices = indices + offsets
104 | 
105 |     # Shape: (batch_size * d_1 * ... * d_n)
106 |     offset_indices = offset_indices.view(-1)
107 |     return offset_indices
108 | 
109 | 
110 | def batched_index_select(
111 |     target: torch.Tensor,
112 |     indices: torch.LongTensor,
113 |     flattened_indices: Optional[torch.LongTensor] = None,
114 | ) -> torch.Tensor:
115 |     """
116 |     The given `indices` of size `(batch_size, d_1, ..., d_n)` indexes into the sequence
117 |     dimension (dimension 2) of the target, which has size `(batch_size, sequence_length,
118 |     embedding_size)`.
119 |     This function returns selected values in the target with respect to the provided indices, which
120 |     have size `(batch_size, d_1, ..., d_n, embedding_size)`. This can use the optionally
121 |     precomputed `flattened_indices` with size `(batch_size * d_1 * ... * d_n)` if given.
122 |     An example use case of this function is looking up the start and end indices of spans in a
123 |     sequence tensor. This is used in the
124 |     [CoreferenceResolver](../models/coreference_resolution/coref.md) model to select
125 |     contextual word representations corresponding to the start and end indices of mentions. The key
126 |     reason this can't be done with basic torch functions is that we want to be able to use look-up
127 |     tensors with an arbitrary number of dimensions (for example, in the coref model, we don't know
128 |     a priori how many spans we are looking up).
129 |     # Parameters
130 |     target : `torch.Tensor`, required.
131 |         A 3 dimensional tensor of shape (batch_size, sequence_length, embedding_size).
132 |         This is the tensor to be indexed.
133 |     indices : `torch.LongTensor`
134 |         A tensor of shape (batch_size, ...), where each element is an index into the
135 |         `sequence_length` dimension of the `target` tensor.
136 |     flattened_indices : `Optional[torch.Tensor]`, optional (default = `None`)
137 |         An optional tensor representing the result of calling `flatten_and_batch_shift_indices`
138 |         on `indices`. This is helpful in the case that the indices can be flattened once and
139 |         cached for many batch lookups.
140 |     # Returns
141 |     selected_targets : `torch.Tensor`
142 |         A tensor with shape [indices.size(), target.size(-1)] representing the embedded indices
143 |         extracted from the batch flattened target tensor.
144 |     """
145 |     if flattened_indices is None:
146 |         # Shape: (batch_size * d_1 * ... * d_n)
147 |         flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))
148 | 
149 |     # Shape: (batch_size * sequence_length, embedding_size)
150 |     flattened_target = target.view(-1, target.size(-1))
151 | 
152 |     # Shape: (batch_size * d_1 * ... * d_n, embedding_size)
153 |     flattened_selected = flattened_target.index_select(0, flattened_indices)
154 |     selected_shape = list(indices.size()) + [target.size(-1)]
155 |     # Shape: (batch_size, d_1, ..., d_n, embedding_size)
156 |     selected_targets = flattened_selected.view(*selected_shape)
157 |     return selected_targets
158 | 
159 | 
160 | class Adafactor(torch.optim.Optimizer):
161 |     """Implements Adafactor algorithm.
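
    A minimal usage sketch (the toy `torch.nn.Linear` model below is only an
    illustrative assumption; with the default `relative_step=True`, `lr` must
    be left as `None` so the optimizer computes its own schedule)::

        model = torch.nn.Linear(10, 2)
        optimizer = Adafactor(model.parameters(), lr=None,
                              scale_parameter=True, relative_step=True)
        loss = model(torch.randn(4, 10)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()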
162 | 
163 |     This implementation is based on:
164 |     `Adafactor: Adaptive Learning Rates with Sublinear Memory Cost`
165 |     (see https://arxiv.org/abs/1804.04235)
166 | 
167 |     Note that this optimizer internally adjusts the learning rate
168 |     depending on the *scale_parameter*, *relative_step* and
169 |     *warmup_init* options. To use a manual (external) learning rate
170 |     schedule you should set `scale_parameter=False` and
171 |     `relative_step=False`.
172 | 
173 |     Arguments:
174 |         params (iterable): iterable of parameters to optimize or dicts defining
175 |             parameter groups
176 |         lr (float, optional): external learning rate (default: None)
177 |         eps (tuple[float, float]): regularization constants for square gradient
178 |             and parameter scale respectively (default: (1e-30, 1e-3))
179 |         clip_threshold (float): threshold of root mean square of
180 |             final gradient update (default: 1.0)
181 |         decay_rate (float): coefficient used to compute running averages of square
182 |             gradient (default: -0.8)
183 |         beta1 (float): coefficient used for computing running averages of gradient
184 |             (default: None)
185 |         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
186 |         scale_parameter (bool): if True, learning rate is scaled by root mean square of
187 |             parameter (default: True)
188 |         relative_step (bool): if True, time-dependent learning rate is computed
189 |             instead of external learning rate (default: True)
190 |         warmup_init (bool): time-dependent learning rate computation depends on
191 |             whether warm-up initialization is being used (default: False)
192 |     """
193 | 
194 |     def __init__(self, params, lr=None, eps=(1e-30, 1e-3), clip_threshold=1.0,
195 |                  decay_rate=-0.8, beta1=None, weight_decay=0.0, scale_parameter=True,
196 |                  relative_step=True, warmup_init=False):
197 |         if lr is not None and relative_step:
198 |             raise ValueError('Cannot combine manual lr and relative_step options')
199 |         if warmup_init and not relative_step:
200 |             raise ValueError('warmup_init requires relative_step=True')
201 | 
202 |         defaults = dict(lr=lr, eps=eps, clip_threshold=clip_threshold, decay_rate=decay_rate,
203 |                         beta1=beta1, weight_decay=weight_decay, scale_parameter=scale_parameter,
204 |                         relative_step=relative_step, warmup_init=warmup_init)
205 |         super(Adafactor, self).__init__(params, defaults)
206 | 
207 |     @property
208 |     def supports_memory_efficient_fp16(self):
209 |         return True
210 | 
211 |     @property
212 |     def supports_flat_params(self):
213 |         return False
214 | 
215 |     def _get_lr(self, param_group, param_state):
216 |         rel_step_sz = param_group['lr']
217 |         if param_group['relative_step']:
218 |             min_step = 1e-6 * param_state['step'] if param_group['warmup_init'] else 1e-2
219 |             rel_step_sz = min(min_step, 1.0 / math.sqrt(param_state['step']))
220 |         param_scale = 1.0
221 |         if param_group['scale_parameter']:
222 |             param_scale = max(param_group['eps'][1], param_state['RMS'])
223 |         return param_scale * rel_step_sz
224 | 
225 |     def _get_options(self, param_group, param_shape):
226 |         factored = len(param_shape) >= 2
227 |         use_first_moment = param_group['beta1'] is not None
228 |         return factored, use_first_moment
229 | 
230 |     def _rms(self, tensor):
231 |         return tensor.norm(2) / (tensor.numel() ** 0.5)
232 | 
233 |     def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col):
234 |         r_factor = (
235 |             exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)
236 |         ).rsqrt_()
237 |         c_factor = exp_avg_sq_col.rsqrt()
238 |         return torch.mm(r_factor.unsqueeze(-1), c_factor.unsqueeze(0))
239 | 
240 |     def step(self, closure=None):
241 |         """Performs a single optimization step.
242 | 
243 |         Arguments:
244 |             closure (callable, optional): A closure that reevaluates the model
245 |                 and returns the loss.
246 |         """
247 |         loss = None
248 |         if closure is not None:
249 |             loss = closure()
250 | 
251 |         for group in self.param_groups:
252 |             for p in group['params']:
253 |                 if p.grad is None:
254 |                     continue
255 |                 grad = p.grad.data
256 |                 if grad.dtype in {torch.float16, torch.bfloat16}:
257 |                     grad = grad.float()
258 |                 if grad.is_sparse:
259 |                     raise RuntimeError('Adafactor does not support sparse gradients.')
260 | 
261 |                 state = self.state[p]
262 |                 grad_shape = grad.shape
263 | 
264 |                 factored, use_first_moment = self._get_options(group, grad_shape)
265 |                 # State Initialization
266 |                 if len(state) == 0:
267 |                     state['step'] = 0
268 | 
269 |                     if use_first_moment:
270 |                         # Exponential moving average of gradient values
271 |                         state['exp_avg'] = torch.zeros_like(grad)
272 |                     if factored:
273 |                         state['exp_avg_sq_row'] = torch.zeros(grad_shape[:-1]).to(grad)
274 |                         state['exp_avg_sq_col'] = torch.zeros(grad_shape[:-2] + grad_shape[-1:]).to(grad)
275 |                     else:
276 |                         state['exp_avg_sq'] = torch.zeros_like(grad)
277 | 
278 |                     state['RMS'] = 0
279 |                 else:
280 |                     if use_first_moment:
281 |                         state['exp_avg'] = state['exp_avg'].to(grad)
282 |                     if factored:
283 |                         state['exp_avg_sq_row'] = state['exp_avg_sq_row'].to(grad)
284 |                         state['exp_avg_sq_col'] = state['exp_avg_sq_col'].to(grad)
285 |                     else:
286 |                         state['exp_avg_sq'] = state['exp_avg_sq'].to(grad)
287 | 
288 |                 p_data_fp32 = p.data
289 |                 if p.data.dtype in {torch.float16, torch.bfloat16}:
290 |                     p_data_fp32 = p_data_fp32.float()
291 | 
292 |                 state['step'] += 1
293 |                 state['RMS'] = self._rms(p_data_fp32)
294 |                 group['lr'] = self._get_lr(group, state)
295 | 
296 |                 beta2t = 1.0 - math.pow(state['step'], group['decay_rate'])
297 |                 update = (grad**2) + group['eps'][0]
298 |                 if factored:
299 |                     exp_avg_sq_row = state['exp_avg_sq_row']
300 |                     exp_avg_sq_col = state['exp_avg_sq_col']
301 | 
302 |                     exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
303 |                     exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)
304 | 
305 |                     # Approximation of exponential moving average of square of gradient
306 |                     update = self._approx_sq_grad(exp_avg_sq_row, exp_avg_sq_col)
307 |                     update.mul_(grad)
308 |                 else:
309 |                     exp_avg_sq = state['exp_avg_sq']
310 | 
311 |                     exp_avg_sq.mul_(beta2t).add_(update, alpha=1.0 - beta2t)
312 |                     update = exp_avg_sq.rsqrt().mul_(grad)
313 | 
314 |                 update.div_(
315 |                     (self._rms(update) / group['clip_threshold']).clamp_(min=1.0)
316 |                 )
317 |                 update.mul_(group['lr'])
318 | 
319 |                 if use_first_moment:
320 |                     exp_avg = state['exp_avg']
321 |                     exp_avg.mul_(group['beta1']).add_(update, alpha=1 - group['beta1'])
322 |                     update = exp_avg
323 | 
324 |                 if group['weight_decay'] != 0:
325 |                     p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
326 | 
327 |                 p_data_fp32.add_(-update)
328 | 
329 |                 if p.data.dtype in {torch.float16, torch.bfloat16}:
330 |                     p.data.copy_(p_data_fp32)
331 | 
332 |         return loss
--------------------------------------------------------------------------------
/Avocado/anonymize_files/paths.txt:
--------------------------------------------------------------------------------
1 | \\UTILITYSRV1\home\Marketing\Technical
2 | ..\sources\com\aasaan\designer\Copy
3 | Specs\SFS_Offline_Biogen
4 | /Sources/com/avocadoit/adapter
5 | utilitysrv1/home/Publications
6 | where_your_applframe\apps\BiogenTest\repository\data.
7 | com/avocadoit/applframe/adapter/siebel/SiebelDi 8 | d:\myprojects 9 | \cm\Rel3.5ga\ipmp401\sources\AvocadoITAlertsLdapPanel.java 10 | \Development\Sources\com\avocadoit\offine\oca\syncml\SyncMLStatus.java 11 | online/javascript/Window.java 12 | \Development\Sources\com\avocadoit\offine\oca\syncml\SyncML.java 13 | ..\..\Sources\com\aasaan\designer 14 | ISAPI\lib\crimson.jar;e:\Program 15 | \\utilitysrv1\home\qa\4.0.1conversion\biogen 16 | \\cm\cm\build\3.0GA.x\03-30-01\047\latestinstall 17 | \\Utilitysrv1 18 | \Development\Sources\com\avocadoit\offine\oca\Repository.java 19 | online/javascript/scriptengineinterface.java 20 | Team\forRobertMoore\ecs_120600.pdf 21 | current/default/std.ini 22 | com/avocadoit/online/jspmanager/taglib/EPGetvarTag.java 23 | \\cm\cm\build\3.0GA.x\02-28-01\013\install 24 | applframe/project/BiogenTest/screen/data/xml 25 | \Development\Sources\com\avocadoit\offine\oca\CommThread.java 26 | \\Cm\Cm\Build\3.0GA_ServicePack.x\SP8\08-10-01\097a\patchinstaller 27 | \Sybase\createStatisticsSchema.sql 28 | \Development\Sources\com\avocadoit\offine\oca\datastore\EBOImpl.java 29 | \winnt\system32 30 | lib/common.jar; 31 | \Sybase\createtables.sql 32 | \Development\Sources\com\avocadoit\offine\oca\InsertNode.java 33 | Development\3rdparty\Classes\jsse.jar 34 | C:\weblogic. 35 | /Development/ini/declare_epvars.jsp 36 | lib/online.jar; 37 | \Development\Sources\com\avocadoit\offine\oca\UpdateNode.java 38 | \\Ws_mobility_1\d-drive\Common\jamesl\UseCases\index.html, 39 | d:\parser\test\wml\canon\wStepInp.html 40 | ISAPI\lib\servlet.jar;e:\Program 41 | \\coffee\home\Outlook\backups.doc 42 | /Development/qatests/bin/copyFiles_NS60.bat 43 | \\Utilitysrv1\home\BizDev\Accenture\AvocadoIT 44 | \\Cm\Cm\Build\3.0GA_JapaneseLanguage\x4 45 | D:\build\Classes 46 | \Development\Sources\com\avocadoit\offine\oca\datastore\IndexEBCImpl.java 47 | //cm/build/main/classes 48 | \\cm\cm\build\daily\2.5.991\12-13-00\release 49 | com/avocadoit/online/epservlet/JSPManaRequestImpl.java 50 | d:\parser\test\release2.6\features\persistence\pda\canon\gotoapp.html 51 | parser\device\TestServlet.java 52 | /Development/qatests/config/log4j_NS60.properties 53 | com/aasaan/parser/device/admin/, 54 | ..\sources\com\aasaan\parser\Admin\AdminCopyDbToDbHandler.java:167: 55 | PDA 0.-1723910365.61238638 IPADDRESS /ae/salesforce/pda/registration.aml 1966 56 | D:\sraghavan_4.0GA_Titanium_3_2\Development\Sources\com\avocadoit\on 57 | data/LastDownload_EBC.xml. 
58 | \\cm\cm\build\3.0.x\08-09-00\005\release\ 59 | \Development\Sources\com\avocadoit\offine\oca\syncml\SyncMLItem.java 60 | C:\AvocadoIT\applframe>configure 61 | \Development\Sources\com\avocadoit\offine\oca\datastore\OCAReaderWriter.java 62 | \Development\statistics 63 | \cm\build\deploy.xml 64 | \\cm\cm\daily\3.0.999\11-07-00\release 65 | /opt/AvocadoIT 66 | \\cm\cm\build\3.0GA_SP1\04-18-01\ 67 | /Development/qatests/config/datasource_NS.xml 68 | $/WebByPhone/ApplicationsEngineering/Demos/paytrust/docs/ 69 | \Sybase\droptables.sql 70 | "siebel://qasiebel/qaes1/SSVObjMgr/qaffa" 71 | com/avocadoit/online/jspmanager/UrlConnectionHandler.java 72 | 3rdparty\broadbeam\cdpd\include 73 | Eptoronto\applframe\build.xml 74 | \Sqlserver\readmesqlserver.txt 75 | \\Utilitysrv1\home\Project 76 | Z:\Development\Sources\com\avocadoit\online\parser\html\HtmlDocument 77 | $/WebByPhone/Development/ini 78 | C:\AvocadoIT\applframe> 79 | CACHE-DIRECTORY="d:\temp\cache\ 80 | com/avocadoit/online/adaptors/html/HtmlAnchorChunkingAdaptor.java 81 | com/avocadoit/online/IJSPManRequest.java 82 | /cvs/cvsrep/devel/applframe/lib/common.jar,v 83 | com/avocadoit/online/adaptors/BaseChunkingAdaptor.java 84 | \Development\Sources\com\avocadoit\offine\oca\Schema.java 85 | '..\bin\setEnv.bat' 86 | emds:originalCacheUrl="D:\myprojects\TestProperties\Identification\Page3.html" 87 | CACHE-DIRECTORY="D:\Program 88 | ISAPI\lib\jaxp.jar;e:\Program 89 | X:\Development\Sources\com\aasaan\designer. 90 | \Developement\Sources\com\avocadoit\servlets\ServletConfiguration.java 91 | /Development/qatests/setup/QAJdbcCase4/pda/input/filter_NS.txt 92 | \\Cm\CM\Build\4.0GA_palladium.x\01-10-2003\377\patchinstaller 93 | c:\winnt\system32 94 | com/avocadoit/applframe/adapter/siebel/SiebelSer 95 | \Sqlserver\dbsetup.bat 96 | $/WebByPhone/ApplicationsEngineering/BiogenDrop2/Applications/Biogen/Templates 97 | Development\3rdparty\Classes\jnet.jar 98 | \Sybase\readmeSybase.txt 99 | \\syogi\d\parser\alwayson\ 100 | paytrustdemoidc/scripts/paytrust/substringP 101 | \\utilitysrv1\public\techsupport\cases\ep2026.doc 102 | //coffee/d_drive/amitabh/r&d/licensingSchedule 103 | Specs/SFS_NokiaBrowser.doc 104 | \cm\Rel3.5ga\ipmp401\sources\AvocadoITFeaturePanel.java 105 | Development/offline/Applications/Biogen/Templates 106 | Z:\Development\Classes 107 | d:\parser)? 108 | udio\mobilestudio\model\Screen.java:33: 109 | Coffee\home\Publications 110 | ..\sources\com\aasaan\parser\interpreter\WmlStep.java:229: 111 | /applframe/bin/comadapter.dll. 112 | BROWSER_ERROR_DIRECTORY="d:\build\browsererror\" 113 | \Development\Sources\com\avocadoit\offine\resources\OCAMessageBundle.properties 114 | \\qawebserver\Server4\docs\Europe\ 115 | /Development/ini/WEB-INF/eptaglib_NS.tld 116 | $/WebByPhone/ApplicationsEngineering/Biogen/Test 117 | \\Utilitysrv1\home\BizDev\Accenture\AvocadoIT_Reston_Presentation_for_KX.ppt 118 | \cm\Rel3.5ga\ipmp41\Installproject3.5gaFullinstall\resource+_en.properties 119 | Coffee/home/Marketing/ProductPlanning/Product. 120 | directory:/ae/salesforce/hdml 121 | line\requestmanager\ConstantRequestManagerImpl.java:67: 122 | ..\wlserver6.1\lib, 123 | parser\browser\WinInet\WinInetClientConnection.java 124 | $WebByPhone/ApplicationEngineerings/etrade/pqa/README.txt? 
125 | /Sources/com/avocadoit/adapter/html 126 | \\utilitysrv1\home\projects\nike 127 | \\Omnamahsivaya\d\4.0GA.x\ 128 | D:\sraghavan_4.0GA_Titanium_3_2\Development\Sources\com\avocadoit\st 129 | cm\cm\build\3.0GASP1\04-18-01\ 130 | lib/OCA.jar; 131 | \Sybase\dbsetup.bat 132 | lib\build.xml 133 | com/aasaan/parser/device/nuance/, 134 | $/RnDwebsite/Rel 135 | <<\\Cm\Cm\Build\3.0GA_JapaneseLanguage\x4>> 136 | d:\weblogic\emas\ep\projects 137 | d:/oracle/oradata/nike603 138 | \Development\Sources\com\avocadoit\offine\oca\util\ConvertEBCBase.java 139 | //utilitysrv1/home/marketing/proposals 140 | QA\release2.6\features\softkeys. 141 | -Dbuild=d:\build\classes 142 | AESRVR1/Home/Applications 143 | utilityserv1/home/biz 144 | \\Vobserver\cmadm\cmmlog\vobserver\gnuplot\000724_lic_avail_ClearCase.gif 145 | Files: Development\statistics\createStatisticsSchemaSQLServer.sql 146 | parser\browser\WinInet\WinInetBrowser.java 147 | /opt/avocadoit 148 | /screenset/screenname? 149 | d:\parser\alwayson 150 | Util/ConfigTemplate.java 151 | uilitysvr1\home\SalesEngineering\Spec\AppReq\Ivillage_AppReq.com 152 | /com/aasaan/parser/system/Version.java 153 | d:\installbuild\sampleApp) 154 | \\Cm\Cm\Build\3.5GA.x\05-02-01\004\release 155 | ISAPI\lib\jndi.jar;e:\Program 156 | \parser\alwayson\3.5\oca\Update\Templates\output\UpdateDeleted.html 157 | MACRO_DIR="D:\Program 158 | sources\com\aasaan\parser\device\unified\AdminServlet.java 159 | QA\release2.5\features\HOLandStaticCheckbox\pda\canon\HOLandStaticCheckbox.html 160 | Utilityserv1/Home/Biz 161 | d:\temp\logfile1.txt 162 | \\epcanada\blaine\shared 163 | $\RnDwebsite\Rel4.0\Functional 164 | d:\installbuild\examples)--> 165 | Specs\SFS_CentralizedAdmin.doc 166 | com/aasaan/parser/device/pda/, 167 | \Ulititysrv1\home\Training\Course 168 | /cvs/cvsrep/devel/applframe/wizard/system/lib/build.xml,v 169 | \\cm\cm\build\maintenance2.5\11-29-00\048.2\release 170 | \Sqlserver\createStatisticsSchema.sql 171 | data/DownloadCriteria_EBC.xml 172 | applframe/project/BiogenTest/repository/data 173 | \ini\avocadoit_config.dtd 174 | parser\system\BrowserState.java 175 | $/WebByPhone/Documentation/ServerArchitecture.doc 176 | \\UTILITYSRV1\HOME\Sales\RFP. 177 | parser\browser\WinInet\HTTPClientCDLL\HTTPClient.cpp 178 | paytrustdemoidc/scripts/util/add 179 | \3rdparty\Classes\xerces.jar 180 | \\Cm\Cm\Build\3.0GA_ServicePack.x\SP8\08-08-01 181 | test/StepEditor.java 182 | \\utilitysrv1\Homedir\GermanaMartinez\ 183 | "vobserver:/vobstore/Development.vbs": 184 | d:\parser\test\release2.6\features\persistence\pda\canon\myyahoo.html 185 | \Development\Sources\com\avocadoit\offine\oca\syncml\SyncMLHeader.java 186 | \\cmwork\parser\SCFuncSummary.txt 187 | Parser/storage/oracle/Database.java 188 | Parser/Interpreter/DBConnectionPool.java 189 | com/aasaan/parser/device 190 | /Development/qatests/bin/testrunner_NS60.bat 191 | /Development/qatests/config/log4j_NS.properties 192 | Files\AvocadoIT\macros" 193 | \cm\Rel3.5GA\ipmp401\InstallProject3.5GAFullInstaller\AvocadoIT.xml 194 | paytrustdemoidc/scripts/paytrust/concat 195 | com/aasaan/parser/formtest/, 196 | classes\com\aasaan. 197 | \\\nora\d\weblogic\ep\projects 198 | sources\com\aasaan\parser\device\unified\EPServlet.java 199 | AvocadoIT.appl/applframe/runtime/mPharma/bin 200 | 3.5GA/Development/bin 201 | \\cm\cm\build\3.0GA.x\03-29-01\047\install 202 | \\Cm\Cm\Build\3.0GA_SP4\05-25-01\install 203 | d:\errorlogs\.... 
204 | \\Utilitysrv1\home\Projects\Nike\Nike_issues 205 | \cm\Rel3.5ga\ipmp401\sources\AvocadoITConfigurationSybase.java 206 | d:\build\ini\std.ini. 207 | /build/classes/script/util 208 | /u01/home/qa1/EMAS1; 209 | Development\3rdparty\Classes\jcert.jar 210 | \cm\Rel3.5ga\ipmp41\Installproject3.5gaFullinstall\resource.properties 211 | Sources\com\aasaan\parser\browser\WinInet\HTTPClient.java 212 | siebel://qatest2/es70/SFSObjMgr/sis70 213 | com/avocadoit/online/jspmanager/taglib/EPLoopTag.java 214 | \Development\Sources\com\avocadoit\offine\oca\datastore\DataStoreImpl.java 215 | \Build\Classes 216 | ..\sources\com\aasaan\parser\interpreter\HdmlStep.java:271: 217 | \\Ws_mobility_1\d-drive\Common\jamesl\UseCasesAmericanCentury.html 218 | \\Sraghavan\sraghavan_d_drive\sraghavan_4.0GA_Titanium\Development\ 219 | "d:\jdk122". 220 | \3rdparty\Classes\xalan.jar 221 | <\\utilitysrv1\home\SalesEngineering\Demos\DemoNotes\... 222 | com/avocadoit/applframe/adapter/siebel/SiebelServerException 223 | Documentation\2.5Refresh1Docs\ECS_Release_Notes.pdf 224 | parser\browser\WinInet\HTTPClientCDLL\com_aasaan_parser_browser_WinInet_HTTPClientC.h 225 | "d:\temp\cache" 226 | SiebelDataBean.login(siebel://qatest2/es70/SFSObjMgr/sis70,SADMIN,SADMIN,null 227 | \cm\Rel3.5ga\ipmp401\sources\AvocadoITRegistrationPanel.java 228 | Files\AvocadoIT\cache\" 229 | \\Ccasestg\Cm\applframe4.0\palladium_drop2 230 | /Sources/com/avocadoit/adapter/jdbc 231 | file:///d:/temp/test2.htm) 232 | /Development/qatests/setup/HOLGlobalVariable/pda/input/filter_NS.txt 233 | STATISTICS_FILE_NAME="/u01/home/qa1/appsEMAS1/statistics/statistics1" 234 | \\cm\cm\daily 235 | /Development/qatests/bin/copyFiles_NS.bat 236 | Templates/Download.act 237 | coffee\home\publications 238 | D:\build\Sources\com\avocadoit\studio\mobilestudio\view\DefaultScreenEditorAdapter.java:39: 239 | D:\sraghavan_4.0GA_Titanium_3_2\Develo 240 | \Sqlserver\createdatabase.sql 241 | /Development/qatests/config/filter_NS.txt 242 | \\blaine\shared 243 | com/aasaan/parser/device/unified/, 244 | ..\sources\com\aasaan\parser\interpreter\PdaStep.java:268: 245 | ISAPI\Examples\com\newatlanta\bean;;D:\jdk131\lib\tools.jar 246 | /com/aasaan/designer/designertool", 247 | Templates/DownloadAll.act 248 | Utilitysrv1:\home\marketing 249 | \\Cm\Sitraka\Build3.5applframe_ffa.zip). 250 | C:\AvocadoIT>cd 251 | com/aasaan/parser/device/testservergui/, 252 | Project/Application/Page, 253 | $/Rndwebsite/Rel3.5/Test 254 | temp/error/errorlog.txt. 255 | (/com/avocadoit/offline/avocadoit_install) 256 | Files\cache\81bac9b0-8b7f-11d4-80ee-0050da6894b6-.html" 257 | \\coffee\vss_coff\srcsafe.ini 258 | Utilitysrv1\home\Training\Course 259 | ini/ApplFields.xml 260 | com/avocadoit/online/adaptors/IWmlConstants.java 261 | \Sqlserver\droptables.sql 262 | ..\sources\com\aasaan\parser\Admin\AdminCopyDbToDbHandler.java:169: 263 | \Development\Sources\com\avocadoit\offine\oca\XmlTable.java 264 | \\Utilitysrv1\Engineering\Vivek\Alerts 265 | Development\examples\README.txt 266 | file:///d:/temp/test1.htm 267 | \\CM\CM\Build\Main\05-02-00\X222 268 | /u1/home/qa1/EMAS 269 | C:\AvocadoIT\applframe\projects\Biogen\screen\data\webapp\images 270 | com/avocadoit/online/adaptors/html/HtmlFormChunkingAdaptor.java 271 | \\utilitysrv1\public\techsupport\cases\EP2043_Saleshound.doc 272 | D:/avocadoit/.... 
273 | /cvs/cvsrep/devel/applframe/lib/online.jar,v 274 | com/aasaan/parser/system/WorkerThread 275 | d_drive\home\QA\defaultnn.htm 276 | \\qawebserver\Server4\docs\Sjis\ 277 | \Eptoronto\applframe\wizard\util\src\com\avocadoit\applframe\wizard\util\TagFormatter.java 278 | \\Utilitysrv1/Public/documents 279 | file:///d:\domtest.html 280 | \Development\dbsetup\ 281 | Development/ini/new_avocadoit_config.xml 282 | \Development\Sources\com\avocadoit\offine\oca\Transaction.java 283 | /u01/home/qa1/EMAS1/classes 284 | biogen/b10g3n 285 | \Development\Sources\com\avocadoit\offine\oca\syncml\SyncMLParser.java 286 | \\utilitysrv1\public\techsupport\cases\ep1095_salesforce.doc 287 | \\Utilitysrv1\home\Publications 288 | \Development\Sources\com\avocadoit\offine\oca\syncml\SyncMLCred.java 289 | $/WebByPhone/ApplicationsEngineering/symbol/OCA/Applications/Symbol_SB 290 | paytrustdemoidc/scripts/util/subtract 291 | cm/build/build.xml 292 | com/aasaan/parser/browser/microsoft-visualj/, 293 | /Development/declare_epvars_NS.jsp 294 | \\Ws_mobility_1\d-drive\Common\jamesl\UseCasesTemplate.html 295 | com/aasaan/parser/browser/microsoft/, 296 | \\UtilitySrvr1\home\Publications 297 | common/gui/dax/model/DaxTemplateParser.java, 298 | files\symantec 299 | \parser\alwayson\3.5\oca\Update\Templates\canon\UpdateDeleted.html 300 | com/aasaan/parser/javascript/test/, 301 | \\cm\cm\Build\3.5GA.x\07-06-01\039\release 302 | \\utilitysrv1\misc\outlook\backups.doc 303 | utilitysrv1\home\Project 304 | com/aasaan/parser/device/hdml/, 305 | application/vnd.ms-excel, 306 | 3.5GA\Development\bin 307 | com/avocadoit/online/epservlet/ControllingServlet.java 308 | d:\temp\debug1.txt. 309 | Files\AvocadoIT\cache". 310 | Development\bin\build.xml 311 | paytrustdemoidc/scripts/util/mod 312 | D:\weblogic61\wlserver6.1\config\EPOffline\applications\admin\WEB-INF\classes\EAOAdminConfig1.xml 313 | RnDWebSite/Rel4.6/RunningTests.doc 314 | "D:\Program 315 | \Sqlserver\createtables.sql 316 | /RnDwebsite/Palladium/Test 317 | WinInet/JSSE 318 | /bin/laden" 319 | \Development\Sources\com\avocadoit\offine\oca\datastore\EBCImpl.java 320 | ISAPI\classes;d:/build/classes;d:/build/classes/adapter63.jar;d:/build/connector/jdbc/jdbcconnector.jar;d:/build/classes/kxml.jar;d:/build/classes/osa.jar;d:/build/classes/tasks.jar;d:/build/classes/wizard.jar;d:/build/3rdparty/classes/activation.jar;d:/build/3rdparty/classes/classes12.zip;d:/build/3rdparty/classes/datacompression.jar;d:/build/3rdparty/classes/tools.jar;d:/build/3rdparty/classes/gnu-regexp-1.1.0.jar;d:/build/3rdparty/classes/j2ee.jar;d:/build/3rdparty/classes/jcert.jar;d:/build/3rdparty/classes/jdom.jar;d:/build/3rdparty/classes/jnet.jar;d:/build/3rdparty/classes/jsse.jar;d:/build/3rdparty/classes/xerces.jar;d:/build/3rdparty/classes/log4j.jar;d:/build/3rdparty/classes/mail.jar;d:/build/3rdparty/classes/SiebelDataBean.jar;d:/build/3rdparty/classes/SiebelTC_enu.jar;d:/build/3rdparty/classes/SiebelTcCommon.jar;d:/build/3rdparty/classes/SiebelTcOM.jar;d:/build/3rdparty/classes/Tidy.jar;d:/build/3rdparty/classes/xalan.jar;d:/build/3rdparty/classes/jaxp.jar;d:/build/3rdparty/classes/xml4j.jar;E:\Program 321 | C:\AvocadoIT\applframe>recompile 322 | C:\AvocadoIT\applframe\lib\build.xml:553: 323 | QA/release2/features 324 | localhost:8080/servlet/OSAServlet 325 | d:\parser\alwayson\AlwaysOnOCASummary.txt 326 | /parser/alwayson/3.5/oca/ReinitializeEBD_corrupted/Templates/noObject.act.' 327 | Specs\SFS_CentralizedAdmin.doc). 
328 | development\qatests\projects 329 | STATISTICS_FILE_NAME="/u01/home/qa1/statistics/statistics1" 330 | lib/OSA.jar; 331 | \rmanocha1\xenon. 332 | \Development\Sources\Offline\Ceinoke\ceinvoke.cpp 333 | /cvs/cvsrep/devel/applframe/lib/OCA.jar,v 334 | \\cm\cm\build\2.5.x\10-03-00\043\install\ 335 | ini/OCAAppConfig.xml 336 | ERROR_LOGFILE_NAME="/u01/home/qa1/appsEMAS1/logs/logfile" 337 | $/WebByPhone/ApplicationsEngineering/Biogen/AvocadoIT 338 | x:\development 339 | sources\EPServlet.java 340 | D:/AvocadoIT/offline/sdk/client/templates/ErrorTemplate.html 341 | 1000 342 | \\Cm\Cm\Build\2.5GA.x\01-25-01\063\LatestInstall) 343 | d:/build/errorlogs 344 | /cvs/cvsrep/devel/applframe/lib/OSA.jar,v 345 | Utilitysrv1/misc/it/software/visio 346 | /build/classes/script/sf. 347 | QA\release2.5\features\HOLandDynamicCheckbox\pda\canon\HOLandDynamicCheckbox.html 348 | \Development\Sources\com\avocadoit\offine\oca\Synchronization.java 349 | D:/build/bin/cabfiles 350 | sources\AdminServlet.java 351 | \\Cm\D$\CM\Build\3.5GA.x 352 | PDA 0.-1723910365.61238638 IPADDRESS /ae/salesforce/pda/start.aml 353 | C:\AvocadoIT\applframe\lib\build.xml:548: 354 | \\WebByPhone\Development\Sources\com\avocadoit\portal\utilities 355 | /Development/qatests/bin/testrunner_NS.bat 356 | 3.5GA\Developtment\bin 357 | /opt/java1.2 358 | /opt/java1.2/jre/lib 359 | Eptoronto\applframe\tasks\src\ChangeXMLFile.java 360 | \\utilitysrv1\home\projects\nike\requirements\Layout_of_Printable_Material.xls 361 | and \Development\dbsetup\Sqlserver\createStatisticsSchema.sql 362 | /ae/salesforce/hdml. 363 | $\RnDwebsite\Palladium\Functional 364 | \\simply\interoperable 365 | \Development\Sources\com\avocadoit\offine\oca\syncml\SyncMLItemBase64Encoded.java 366 | //coffee/home/SalesEngineering/Specs/AppScope/. 367 | Users\Dondi\Gaskill\Screens\imode. 368 | (/Utilitysrv1/home/Marketing/ac-demo/...). 369 | WebEdge\config_mdn\lang\common\Templates\calendar 370 | d:\build\logs\logfile1.txt)" 371 | (d:\installbuild\examples), 372 | d:/build/ini/EAOAdminConfig.xml 373 | 3rdparty\Classes 374 | /parser/alwayson/3.5/oca/Config_MissingName/ini/OCAAppConfig.xml. 375 | \\epjapan1\engineering\product_planning\WAP_2.0 376 | $/RnDwebsite/Rel4.0/Design 377 | com/aasaan/parser/device/wml/, 378 | \\utilitysrv1\public\techsupport\cases\ep1099_epj.doc 379 | ini/datasource.xml 380 | Documentation\2.5Refresh1Docs\ 381 | d:/errorlogs 382 | /opt/AvocadoIT/*.* 383 | Team\AVOCADOIT 384 | \\CM\CM\Build\Main\06-14-00\X242 385 | ISAPI\lib\ServletExec41.jar;e:\Program 386 | \com\aasaan\parser\interpreter\HdmlStep.java 387 | MobileFocus/PCExpo/Jupiter 388 | Persist/OracleDatabase.java 389 | paytrustdemoidc/pda 390 | d:\weblogic\emas\ep\projects. 391 | DEBUG_OUTPUT="/u01/home/qa1/appsEMAS1/logs/debug" 392 | \developement\projects\NTService\ServiceInterface.cpp 393 | \Sybase\createdatabase.sql 394 | "d:\Program 395 | 3.5\ReinitializeEBD_corrupted 396 | /Sources/com/avocadoit/adapter/domino 397 | com/aasaan/parser. 398 | \Phoneclient\document\PhoneClientConfig.doc] 399 | C:\AvocadoIT\applframe.pd_d2\wizard\system\lib\) 400 | Development\3rdparty\Classes\xerces.jar. 401 | d:\build\connector\jdbc 402 | \Development\Sources\com\avocadoit\offine\oca\datastore\DataStore.java 403 | /tmp/logfile) 404 | opeing/closing/applying 405 | \\Cm\Cm\Build\3.0GA.x\FinalBuild3.0.047 406 | \\CM\CM\Build\4.0GA_palladium.x\04-23-2002\161a\fullinstaller 407 | epcanada\machinename 408 | --------------------------------------------------------------------------------