├── .gitignore ├── Makefile ├── README.md ├── configs ├── examples │ ├── deepl_gen_eval.yaml │ ├── evaluate.yaml │ ├── gen_eval.yaml │ ├── generate.yaml │ ├── index.yaml │ ├── lm_harness.yaml │ ├── prepare.yaml │ ├── prepare_random.yaml │ ├── prepare_similarity.yaml │ └── vertex_ai_mt.yaml ├── reference_benchmark │ ├── 0_shot_wmt23.yaml │ └── standard_benchmarks.yaml └── tower_paper │ ├── 0_shot_openai.yaml │ ├── 5_shot_generic_models.yaml │ ├── 5_shot_openai.yaml │ ├── tower_instruct_0_shot.yaml │ └── tower_instruct_5_shot.yaml ├── poetry.lock ├── pyproject.toml ├── run_paper_benchmark.sh ├── run_reference_benchmark.sh └── tower_eval ├── __init__.py ├── cli.py ├── error_span_utils.py ├── fewshot_retrieval_utils.py ├── metrics ├── __init__.py ├── accuracy │ ├── __init__.py │ ├── metric.py │ └── result.py ├── base │ ├── comet.py │ ├── error_span_detection.py │ ├── metrics_handler.py │ ├── metricx.py │ ├── result_handler.py │ └── xml_metric.py ├── bleu │ ├── __init__.py │ ├── metric.py │ └── result.py ├── bleurt │ ├── __init__.py │ ├── metric.py │ └── result.py ├── chrf │ ├── __init__.py │ ├── metric.py │ └── result.py ├── comet │ ├── __init__.py │ └── metric.py ├── comet_kiwi │ ├── __init__.py │ └── metric.py ├── comet_kiwi_23_xxl │ ├── __init__.py │ └── metric.py ├── errant │ ├── __init__.py │ ├── metric.py │ └── result.py ├── error_span_detection_f1 │ ├── __init__.py │ └── metric.py ├── error_span_detection_precision │ ├── __init__.py │ └── metric.py ├── error_span_detection_recall │ ├── __init__.py │ └── metric.py ├── f1 │ ├── __init__.py │ ├── metric.py │ └── result.py ├── f1_sequence │ ├── __init__.py │ ├── conlleval.py │ ├── metric.py │ └── result.py ├── metricx │ ├── __init__.py │ └── metric.py ├── metricx_24 │ ├── __init__.py │ └── metric.py ├── metricx_large │ ├── __init__.py │ └── metric.py ├── metricx_qe │ ├── __init__.py │ └── metric.py ├── metricx_qe_large │ ├── __init__.py │ └── metric.py ├── metricx_qe_xxl │ ├── __init__.py │ └── metric.py ├── metricx_xxl │ ├── __init__.py │ └── metric.py ├── pearson │ ├── __init__.py │ ├── metric.py │ └── result.py ├── perplexity │ ├── __init__.py │ ├── metric.py │ ├── result.py │ └── vllm_subprocess.py ├── spearman │ ├── __init__.py │ ├── metric.py │ └── result.py ├── ter │ ├── __init__.py │ ├── metric.py │ └── result.py ├── xcomet_qe_xl │ ├── __init__.py │ └── metric.py ├── xcomet_qe_xxl │ ├── __init__.py │ └── metric.py ├── xcomet_xl │ ├── __init__.py │ └── metric.py ├── xcomet_xxl │ ├── __init__.py │ └── metric.py ├── xml_chrf │ ├── __init__.py │ └── metric.py └── xml_match │ ├── __init__.py │ ├── metric.py │ └── result.py ├── models ├── __init__.py ├── anthropic │ ├── __init__.py │ └── generator.py ├── cohere │ ├── __init__.py │ └── generator.py ├── deepl │ ├── __init__.py │ └── generator.py ├── exceptions.py ├── hf │ ├── __init__.py │ └── generator.py ├── inference_handler.py ├── openAI │ ├── __init__.py │ └── generator.py ├── seq2seq │ ├── __init__.py │ └── generator.py ├── vertexAI │ ├── __init__.py │ └── generator.py └── vllm │ ├── __init__.py │ └── generator.py ├── tasks ├── __init__.py ├── evaluate.py ├── generate.py ├── index.py └── prepare.py ├── tools ├── logging │ └── wandb_utils.py └── run_calame.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | tllm-env 2 | tower-eval-env 3 | __pycache__ 4 | .vscode 5 | wandb 6 | TowerEval-Data-v0.1* 7 | tower-eval-env* 8 | local_configs* -------------------------------------------------------------------------------- /Makefile: 
-------------------------------------------------------------------------------- 1 | .PHONY: setup install publish flake8-test black-test tests 2 | 3 | setup: 4 | pip install "poetry" 5 | poetry config virtualenvs.create false 6 | 7 | install: setup 8 | poetry install 9 | 10 | publish: install 11 | poetry publish --build 12 | -------------------------------------------------------------------------------- /configs/examples/deepl_gen_eval.yaml: -------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions_data/0_shot" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: "TowerEval-Data-v0.1/generations/0_shot" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/0_shot" 5 | 6 | tasks: 7 | - name: mt 8 | subtasks: 9 | flores.en-de: 10 | metrics: 11 | chrf: 12 | models: 13 | - name: deepl_next_gen 14 | type: deepl 15 | arguments: 16 | model: quality_optimized 17 | - name: deepl_classic 18 | type: deepl 19 | arguments: 20 | model: latency_optimized -------------------------------------------------------------------------------- /configs/examples/evaluate.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "TowerEval-Data-v0.1/data/raw_data" 2 | output_dir: "TowerEval-Data-v0.1/evaluations/5_shot" 3 | tasks: 4 | - name: mt 5 | subtasks: 6 | flores.en-pt: 7 | flores.en-zh: 8 | metrics: 9 | chrf: 10 | bleu: 11 | tokenizer: zh 12 | comet: 13 | batch_size: 16 14 | metrics: 15 | chrf: 16 | bleu: 17 | comet: 18 | batch_size: 16 19 | models: 20 | - name: TowerInstruct-7B-v0.1 21 | type: vllm 22 | - name: TowerBase-7B-v0.1 23 | type: vllm -------------------------------------------------------------------------------- /configs/examples/gen_eval.yaml: -------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions_data/5_shot" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: "TowerEval-Data-v0.1/generations/5_shot" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/5_shot" 5 | tasks: 6 | - name: mt 7 | subtasks: 8 | flores.en-pt: 9 | flores.en-zh: 10 | gen_args: 11 | eval_args: 12 | metrics: 13 | chrf: 14 | bleu: 15 | tokenizer: zh 16 | comet: 17 | batch_size: 16 18 | metrics: 19 | chrf: 20 | bleu: 21 | comet: 22 | batch_size: 16 23 | models: 24 | - name: TowerInstruct-7B-v0.1 25 | type: vllm 26 | arguments: 27 | model_dir: Unbabel/TowerInstruct-7B-v0.1 28 | n_gpus: 1 29 | max_tokens: 1024 30 | run_async: True 31 | batch_size: -1 32 | stop_sequences: [""] 33 | - name: TowerBase-7B-v0.1 34 | type: vllm 35 | arguments: 36 | model_dir: Unbabel/TowerBase-7B-v0.1 37 | n_gpus: 1 38 | max_tokens: 1024 39 | run_async: True 40 | batch_size: -1 -------------------------------------------------------------------------------- /configs/examples/generate.yaml: -------------------------------------------------------------------------------- 1 | data_dir: "TowerEval-Data-v0.1/data/instructions_data/5_shot" 2 | output_dir: "TowerEval-Data-v0.1/generations/5_shot" 3 | tasks: 4 | - name: mt 5 | subtasks: 6 | flores.en-pt: 7 | flores.en-zh: 8 | models: 9 | - name: TowerInstruct-7B-v0.1 10 | type: vllm 11 | arguments: 12 | model_dir: Unbabel/TowerInstruct-7B-v0.1 13 | n_gpus: 1 14 | max_tokens: 1024 15 | run_async: True 16 | batch_size: -1 17 | stop_sequences: [""] 18 | - name: TowerBase-7B-v0.1 19 | type: vllm 20 | arguments: 21 | model_dir: Unbabel/TowerBase-7B-v0.1 22 | 
n_gpus: 1 23 | max_tokens: 1024 24 | run_async: True 25 | batch_size: -1 26 | -------------------------------------------------------------------------------- /configs/examples/index.yaml: -------------------------------------------------------------------------------- 1 | seed: 52 2 | data_dir: "TowerEval-Data-v0.1/data/raw_data/" 3 | output_dir: "TowerEval-Data-v0.1/data/indexed_data/" 4 | tasks: 5 | - name: mt 6 | jsonl: True 7 | subtasks: 8 | flores.en-pt: 9 | flores.en-zh: -------------------------------------------------------------------------------- /configs/examples/lm_harness.yaml: -------------------------------------------------------------------------------- 1 | output_dir: 2 | harness_args: { 3 | "--batch_size": "auto", 4 | "--log_samples": null 5 | } 6 | devices: "3" 7 | tasks: 8 | - name: lm_harness 9 | subtasks: 10 | xstorycloze_en: 11 | xstorycloze_es: 12 | xstorycloze_ru: 13 | xcopa_it: 14 | xcopa_zh: 15 | belebele_eng_Latn: 16 | belebele_deu_Latn: 17 | belebele_fra_Latn: 18 | belebele_ita_Latn: 19 | belebele_nld_Latn: 20 | belebele_por_Latn: 21 | belebele_rus_Cyrl: 22 | belebele_spa_Latn: 23 | belebele_zho_Hans: 24 | belebele_kor_Hang: 25 | xwinograd_en: 26 | xwinograd_fr: 27 | xwinograd_pt: 28 | xwinograd_ru: 29 | xwinograd_zh: 30 | xnli_de: 31 | xnli_en: 32 | xnli_es: 33 | xnli_fr: 34 | xnli_ru: 35 | xnli_zh: 36 | arc_easy: 37 | arc_challenge: { 38 | "num_fewshot": "25" 39 | } 40 | arc_de: { 41 | "num_fewshot": "25" 42 | } 43 | arc_es: { 44 | "num_fewshot": "25" 45 | } 46 | arc_fr: { 47 | "num_fewshot": "25" 48 | } 49 | arc_it: { 50 | "num_fewshot": "25" 51 | } 52 | arc_nl: { 53 | "num_fewshot": "25" 54 | } 55 | arc_pt: { 56 | "num_fewshot": "25" 57 | } 58 | arc_ru: { 59 | "num_fewshot": "25" 60 | } 61 | arc_zh: { 62 | "num_fewshot": "25" 63 | } 64 | hellaswag: 65 | hellaswag_de: 66 | hellaswag_es: 67 | hellaswag_fr: 68 | hellaswag_it: 69 | hellaswag_nl: 70 | hellaswag_pt: 71 | hellaswag_ru: 72 | m_mmlu_de: { 73 | "num_fewshot": "25" 74 | } 75 | m_mmlu_en: { 76 | "num_fewshot": "25" 77 | } 78 | m_mmlu_es: { 79 | "num_fewshot": "25" 80 | } 81 | m_mmlu_fr: { 82 | "num_fewshot": "25" 83 | } 84 | m_mmlu_it: { 85 | "num_fewshot": "25" 86 | } 87 | m_mmlu_nl: { 88 | "num_fewshot": "25" 89 | } 90 | m_mmlu_pt: { 91 | "num_fewshot": "25" 92 | } 93 | m_mmlu_ru: { 94 | "num_fewshot": "25" 95 | } 96 | models: 97 | - name: TowerBase-7B-v0.1 98 | path: Unbabel/TowerBase-7B-v0.1 -------------------------------------------------------------------------------- /configs/examples/prepare.yaml: -------------------------------------------------------------------------------- 1 | seed: 52 2 | data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | output_dir: "TowerEval-Data-v0.1/data/test_instructions" 4 | tasks: 5 | - name: ner 6 | prompt_templates: 7 | - "<|im_start|>user\\nHighlight all the named entities in the following list of tokens \"{{ input }}\".\\nThese are the annotation guidelines for named entities you need to follow:\\n{% for annotation in annotations -%}- {{ annotation.tag }} - {{ annotation.description }}{% if not loop.last %}\\n{% endif %}{%- endfor %}.\\nBesides that, prepend B- as a prefix of the first token of any entity and I- to subsequent ones if they exist. 
If a token is not a part of a named entity, mark it as O.\\nAnswer: <|im_end|>\\n<|im_start|>assistant\\n" 8 | n_fewshots: 0 9 | subtasks: 10 | multiconer2023.en: 11 | prompt_args: 12 | annotations: 13 | - tag: Person 14 | description: Names of people 15 | - tag: Location 16 | description: Location or physical facilities 17 | - tag: Group 18 | description: Groups of people, organizations, corporations or other entities 19 | - tag: Product 20 | description: Consumer products such as food, drinks, clothing, and vehicles 21 | - tag: CreativeWorks 22 | description: Titles of creative works like movie, song, and book titles 23 | - tag: Medical 24 | description: Entities from the medical domain, including diseases, symptoms, and medications 25 | multiconer2023.de: 26 | prompt_args: 27 | annotations: 28 | - tag: Person 29 | description: Names of people 30 | - tag: Location 31 | description: Location or physical facilities 32 | - tag: Group 33 | description: Groups of people, organizations, corporations or other entities 34 | - tag: Product 35 | description: Consumer products such as food, drinks, clothing, and vehicles 36 | - tag: CreativeWorks 37 | description: Titles of creative works like movie, song, and book titles 38 | - tag: Medical 39 | description: Entities from the medical domain, including diseases, symptoms, and medications 40 | multiconer2023.fr: 41 | prompt_args: 42 | annotations: 43 | - tag: Person 44 | description: Names of people 45 | - tag: Location 46 | description: Location or physical facilities 47 | - tag: Group 48 | description: Groups of people, organizations, corporations or other entities 49 | - tag: Product 50 | description: Consumer products such as food, drinks, clothing, and vehicles 51 | - tag: CreativeWorks 52 | description: Titles of creative works like movie, song, and book titles 53 | - tag: Medical 54 | description: Entities from the medical domain, including diseases, symptoms, and medications 55 | multiconer2023.es: 56 | prompt_args: 57 | annotations: 58 | - tag: Person 59 | description: Names of people 60 | - tag: Location 61 | description: Location or physical facilities 62 | - tag: Group 63 | description: Groups of people, organizations, corporations or other entities 64 | - tag: Product 65 | description: Consumer products such as food, drinks, clothing, and vehicles 66 | - tag: CreativeWorks 67 | description: Titles of creative works like movie, song, and book titles 68 | - tag: Medical 69 | description: Entities from the medical domain, including diseases, symptoms, and medications 70 | multiconer2023.it: 71 | prompt_args: 72 | annotations: 73 | - tag: Person 74 | description: Names of people 75 | - tag: Location 76 | description: Location or physical facilities 77 | - tag: Group 78 | description: Groups of people, organizations, corporations or other entities 79 | - tag: Product 80 | description: Consumer products such as food, drinks, clothing, and vehicles 81 | - tag: CreativeWorks 82 | description: Titles of creative works like movie, song, and book titles 83 | - tag: Medical 84 | description: Entities from the medical domain, including diseases, symptoms, and medications 85 | multiconer2023.pt: 86 | prompt_args: 87 | annotations: 88 | - tag: Person 89 | description: Names of people 90 | - tag: Location 91 | description: Location or physical facilities 92 | - tag: Group 93 | description: Groups of people, organizations, corporations or other entities 94 | - tag: Product 95 | description: Consumer products such as food, drinks, clothing, and vehicles 
96 | - tag: CreativeWorks 97 | description: Titles of creative works like movie, song, and book titles 98 | - tag: Medical 99 | description: Entities from the medical domain, including diseases, symptoms, and medications 100 | multiconer2023.zh: 101 | prompt_args: 102 | annotations: 103 | - tag: Person 104 | description: Names of people 105 | - tag: Location 106 | description: Location or physical facilities 107 | - tag: Group 108 | description: Groups of people, organizations, corporations or other entities 109 | - tag: Product 110 | description: Consumer products such as food, drinks, clothing, and vehicles 111 | - tag: CreativeWorks 112 | description: Titles of creative works like movie, song, and book titles 113 | - tag: Medical 114 | description: Entities from the medical domain, including diseases, symptoms, and medications 115 | - name: gec 116 | prompt_templates: 117 | - "<|im_start|>user\\n{%- if examples | length > 1 -%} Here are some examples of {{ lang }} texts with errors and their corrections.{%- else -%} Here is an example of a {{ lang }} text with errors and its correction.\\n{%- endif -%}{%- for example in examples -%}Source: {{ example.src }}\\nCorrected: {{ example.ref }}\\n\\n{%- endfor -%}Correct the errors in the following {{ lang }} text.\\nSource: {{ src }}\\nCorrected: <|im_end|>\\n<|im_start|>assistant\\n" 118 | n_fewshots: 5 119 | fewshot_retrieval_method: force_label_balance 120 | fewshot_retrieval_args: 121 | n_positive: 2 122 | subtasks: 123 | conll14.en: 124 | prompt_args: 125 | lang: English 126 | fm.de: 127 | prompt_args: 128 | lang: German 129 | cowsl2h.es: 130 | prompt_args: 131 | lang: Spanish -------------------------------------------------------------------------------- /configs/examples/prepare_random.yaml: -------------------------------------------------------------------------------- 1 | seed: 52 2 | data_dir: "TowerEval-Data-v0.1/data/raw_data/" 3 | output_dir: "TowerEval-Data-v0.1/data/instructions_data/5_shot_tower_instruct_prompt_random/" 4 | tasks: 5 | - name: mt 6 | prompt_templates: 7 | - "<|im_start|>user\n{%- for example in examples -%}{{ lp0 }}: {{ example.src }}\\n{{ lp1 }}: {{ example.ref }}\\n{%- endfor -%}{{ lp0 }}: {{ src }}\\n{{ lp1 }}: <|im_end|>\\n<|im_start|>assistant\\n" 8 | jsonl: True 9 | n_fewshots: 5 10 | fewshot_retrieval_method: random 11 | subtasks: 12 | flores.en-pt: 13 | prompt_args: 14 | lp0: English 15 | lp1: Portuguese 16 | flores.en-zh: 17 | prompt_args: 18 | lp0: English 19 | lp1: Chinese -------------------------------------------------------------------------------- /configs/examples/prepare_similarity.yaml: -------------------------------------------------------------------------------- 1 | seed: 52 2 | data_dir: "TowerEval-Data-v0.1/data/raw_data/" 3 | index_dir: "TowerEval-Data-v0.1/data/indexed_data/" 4 | output_dir: "TowerEval-Data-v0.1/data/instructions_data/5_shot_tower_instruct_prompt_similarity/" 5 | tasks: 6 | - name: mt 7 | prompt_templates: 8 | - "<|im_start|>user\n{%- for example in examples -%}{{ lp0 }}: {{ example.src }}\\n{{ lp1 }}: {{ example.ref }}\\n{%- endfor -%}{{ lp0 }}: {{ src }}\\n{{ lp1 }}: <|im_end|>\\n<|im_start|>assistant\\n" 9 | jsonl: True 10 | n_fewshots: 5 11 | fewshot_retrieval_method: similarity 12 | subtasks: 13 | flores.en-pt: 14 | prompt_args: 15 | lp0: English 16 | lp1: Portuguese 17 | flores.en-zh: 18 | prompt_args: 19 | lp0: English 20 | lp1: Chinese -------------------------------------------------------------------------------- /configs/examples/vertex_ai_mt.yaml: 
-------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions_data/0_shot" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: "TowerEval-Data-v0.1/generations/0_shot" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/0_shot" 5 | tasks: 6 | - name: mt 7 | subtasks: 8 | flores.en-de: 9 | flores.en-fr: 10 | metrics: 11 | chrf: 12 | bleu: 13 | comet: 14 | batch_size: 16 15 | models: 16 | - name: gemini 17 | type: vertex-ai 18 | arguments: 19 | model: "gemini-pro" 20 | max_tokens: 1024 21 | debug: True 22 | retry_max_attempts: 200 23 | - name: palm2 24 | type: vertex-ai 25 | arguments: 26 | model: "text-bison" 27 | max_tokens: 1024 28 | debug: True 29 | retry_max_attempts: 200 -------------------------------------------------------------------------------- /configs/reference_benchmark/0_shot_wmt23.yaml: -------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions_data/0_shot_tower_instruct_prompt" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: "TowerEval-Data-v0.1/generations/0_shot_tower_instruct_prompt" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/0_shot_tower_instruct_prompt" 5 | tasks: 6 | - name: mt 7 | subtasks: 8 | wmt23.en-de: 9 | wmt23.en-ru: 10 | wmt23.en-zh: 11 | gen_args: 12 | eval_args: 13 | metrics: 14 | chrf: 15 | bleu: 16 | tokenizer: zh 17 | comet: 18 | batch_size: 16 19 | comet_kiwi: 20 | batch_size: 16 21 | bleurt: 22 | batch_size: 16 23 | wmt23.de-en: 24 | wmt23.ru-en: 25 | wmt23.zh-en: 26 | metrics: 27 | chrf: 28 | bleu: 29 | comet: 30 | batch_size: 16 31 | xcomet: 32 | batch_size: 16 33 | comet_kiwi: 34 | batch_size: 16 35 | bleurt: 36 | batch_size: 16 37 | comet: 38 | batch_size: 16 39 | models: 40 | - name: 41 | type: vllm 42 | arguments: 43 | model_dir: 44 | n_gpus: 1 45 | max_tokens: 1024 46 | run_async: True 47 | batch_size: -1 -------------------------------------------------------------------------------- /configs/reference_benchmark/standard_benchmarks.yaml: -------------------------------------------------------------------------------- 1 | output_dir: 2 | harness_args: { 3 | "--batch_size": "auto", 4 | "--log_samples": null 5 | } 6 | devices: "1" 7 | tasks: 8 | - name: lm_harness 9 | subtasks: 10 | xstorycloze_en: 11 | xstorycloze_es: 12 | xstorycloze_ru: 13 | xcopa_it: 14 | xcopa_zh: 15 | belebele_eng_Latn: 16 | belebele_deu_Latn: 17 | belebele_fra_Latn: 18 | belebele_ita_Latn: 19 | belebele_nld_Latn: 20 | belebele_por_Latn: 21 | belebele_rus_Cyrl: 22 | belebele_spa_Latn: 23 | belebele_zho_Hans: 24 | belebele_kor_Hang: 25 | xwinograd_en: 26 | xwinograd_fr: 27 | xwinograd_pt: 28 | xwinograd_ru: 29 | xwinograd_zh: 30 | xnli_ar: 31 | xnli_bg: 32 | xnli_de: 33 | xnli_el: 34 | xnli_en: 35 | xnli_es: 36 | xnli_fr: 37 | xnli_hi: 38 | xnli_ru: 39 | xnli_sw: 40 | xnli_th: 41 | xnli_tr: 42 | xnli_ur: 43 | xnli_vi: 44 | xnli_zh: 45 | arc_easy: 46 | arc_challenge: { 47 | "num_fewshot": "25" 48 | } 49 | arc_de: { 50 | "num_fewshot": "25" 51 | } 52 | arc_es: { 53 | "num_fewshot": "25" 54 | } 55 | arc_fr: { 56 | "num_fewshot": "25" 57 | } 58 | arc_it: { 59 | "num_fewshot": "25" 60 | } 61 | arc_nl: { 62 | "num_fewshot": "25" 63 | } 64 | arc_pt: { 65 | "num_fewshot": "25" 66 | } 67 | arc_ru: { 68 | "num_fewshot": "25" 69 | } 70 | arc_zh: { 71 | "num_fewshot": "25" 72 | } 73 | hellaswag: 74 | hellaswag_de: 75 | hellaswag_es: 76 | hellaswag_fr: 77 | hellaswag_it: 78 
| hellaswag_nl: 79 | hellaswag_pt: 80 | hellaswag_ru: 81 | m_mmlu_de: { 82 | "num_fewshot": "25" 83 | } 84 | m_mmlu_en: { 85 | "num_fewshot": "25" 86 | } 87 | m_mmlu_es: { 88 | "num_fewshot": "25" 89 | } 90 | m_mmlu_fr: { 91 | "num_fewshot": "25" 92 | } 93 | m_mmlu_it: { 94 | "num_fewshot": "25" 95 | } 96 | m_mmlu_nl: { 97 | "num_fewshot": "25" 98 | } 99 | m_mmlu_pt: { 100 | "num_fewshot": "25" 101 | } 102 | m_mmlu_ru: { 103 | "num_fewshot": "25" 104 | } 105 | models: 106 | - name: 107 | path: -------------------------------------------------------------------------------- /configs/tower_paper/0_shot_openai.yaml: -------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions/0_shot" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: "TowerEval-Data-v0.1/generations/0_shot" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/0_shot" 5 | tasks: 6 | - name: mt 7 | subtasks: 8 | flores.en-de: 9 | flores.en-fr: 10 | flores.en-pt: 11 | flores.en-nl: 12 | flores.en-es: 13 | flores.en-it: 14 | flores.en-zh: 15 | gen_args: 16 | eval_args: 17 | metrics: 18 | chrf: 19 | bleu: 20 | tokenizer: zh 21 | comet: 22 | batch_size: 16 23 | flores.en-ko: 24 | gen_args: 25 | eval_args: 26 | metrics: 27 | chrf: 28 | bleu: 29 | tokenizer: ko-mecab 30 | comet: 31 | batch_size: 16 32 | flores.en-ru: 33 | flores.de-en: 34 | flores.fr-en: 35 | flores.pt-en: 36 | flores.nl-en: 37 | flores.es-en: 38 | flores.it-en: 39 | flores.zh-en: 40 | flores.ko-en: 41 | flores.ru-en: 42 | wmt23.en-de: 43 | wmt23.en-ru: 44 | wmt23.en-zh: 45 | eval_args: 46 | metrics: 47 | chrf: 48 | bleu: 49 | tokenizer: zh 50 | comet: 51 | batch_size: 16 52 | wmt23.de-en: 53 | wmt23.ru-en: 54 | wmt23.zh-en: 55 | tico19.en-es: 56 | tico19.en-fr: 57 | tico19.en-pt: 58 | tico19.en-ru: 59 | tico19.en-zh: 60 | gen_args: 61 | eval_args: 62 | metrics: 63 | chrf: 64 | bleu: 65 | tokenizer: zh 66 | comet: 67 | batch_size: 16 68 | metrics: 69 | chrf: 70 | bleu: 71 | comet: 72 | batch_size: 16 73 | - name: ape 74 | subtasks: 75 | nllb_3b_wmt23.de-en: 76 | nllb_3b_wmt23.en-de: 77 | nllb_3b_wmt23.en-zh: 78 | eval_args: 79 | metrics: 80 | ter: 81 | asian_support: True 82 | chrf: 83 | bleu: 84 | tokenizer: zh 85 | comet: 86 | batch_size: 16 87 | comet_kiwi: 88 | batch_size: 16 89 | nllb_3b_wmt23.zh-en: 90 | nllb_3b_wmt23.ru-en: 91 | nllb_3b_wmt23.en-ru: 92 | metrics: 93 | ter: 94 | asian_support: True 95 | chrf: 96 | comet: 97 | batch_size: 16 98 | comet_kiwi: 99 | batch_size: 16 100 | bleu: 101 | - name: ner 102 | subtasks: 103 | multiconer2023.en: 104 | multiconer2023.de: 105 | multiconer2023.fr: 106 | multiconer2023.es: 107 | multiconer2023.it: 108 | multiconer2023.pt: 109 | multiconer2023.zh: 110 | metrics: 111 | f1sequence: 112 | hypothesis_format: "text-tuple-list" 113 | reference_format: "jsonl" 114 | tokenize_hypothesis: False 115 | default_noent_tag: "O" 116 | valid_ner_tags: ["Person", "Location", "Group", "Product", "CreativeWorks", "Medical"] 117 | models: 118 | - name: gpt-3.5-turbo 119 | type: open-ai 120 | arguments: 121 | model: "gpt-3.5-turbo" 122 | max_tokens: 1024 123 | debug: True 124 | retry_max_attempts: 200 125 | - name: gpt-4 126 | type: open-ai 127 | arguments: 128 | model: "gpt-4" 129 | max_tokens: 1024 130 | debug: True 131 | retry_max_attempts: 200 -------------------------------------------------------------------------------- /configs/tower_paper/5_shot_generic_models.yaml: 
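# The configs under configs/tower_paper/ below are the ones run_paper_benchmark.sh passes to
# `python -m tower_eval.cli gen-eval --config <config>`. Judging from the gen_eval example above,
# gen-eval generates with every model listed under `models` using the prompts in gen_data_dir, then
# scores the outputs written to gen_output_dir against eval_data_dir with each subtask's `metrics`,
# saving results under eval_output_dir (an inferred summary of the flow; the exact behaviour is
# defined by tower_eval/cli.py and tower_eval/tasks/generate.py and evaluate.py).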
-------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions/5_shot" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: "TowerEval-Data-v0.1/generations/5_shot" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/5_shot" 5 | tasks: 6 | - name: mt 7 | subtasks: 8 | flores.en-de: 9 | flores.en-fr: 10 | flores.en-pt: 11 | flores.en-nl: 12 | flores.en-es: 13 | flores.en-it: 14 | flores.en-zh: 15 | gen_args: 16 | eval_args: 17 | metrics: 18 | chrf: 19 | bleu: 20 | tokenizer: zh 21 | comet: 22 | batch_size: 16 23 | flores.en-ko: 24 | gen_args: 25 | eval_args: 26 | metrics: 27 | chrf: 28 | bleu: 29 | tokenizer: ko-mecab 30 | comet: 31 | batch_size: 16 32 | flores.en-ru: 33 | flores.de-en: 34 | flores.fr-en: 35 | flores.pt-en: 36 | flores.nl-en: 37 | flores.es-en: 38 | flores.it-en: 39 | flores.zh-en: 40 | flores.ko-en: 41 | flores.ru-en: 42 | wmt23.en-de: 43 | wmt23.en-ru: 44 | wmt23.en-zh: 45 | eval_args: 46 | metrics: 47 | chrf: 48 | bleu: 49 | tokenizer: zh 50 | comet: 51 | batch_size: 16 52 | wmt23.de-en: 53 | wmt23.ru-en: 54 | wmt23.zh-en: 55 | tico19.en-es: 56 | tico19.en-fr: 57 | tico19.en-pt: 58 | tico19.en-ru: 59 | tico19.en-zh: 60 | gen_args: 61 | eval_args: 62 | metrics: 63 | chrf: 64 | bleu: 65 | tokenizer: zh 66 | comet: 67 | batch_size: 16 68 | metrics: 69 | chrf: 70 | bleu: 71 | comet: 72 | batch_size: 16 73 | - name: ape 74 | subtasks: 75 | nllb_3b_wmt23.de-en: 76 | nllb_3b_wmt23.en-de: 77 | nllb_3b_wmt23.en-zh: 78 | eval_args: 79 | metrics: 80 | ter: 81 | asian_support: True 82 | chrf: 83 | bleu: 84 | tokenizer: zh 85 | comet: 86 | batch_size: 16 87 | comet_kiwi: 88 | batch_size: 16 89 | nllb_3b_wmt23.ru-en: 90 | metrics: 91 | ter: 92 | asian_support: True 93 | chrf: 94 | comet: 95 | batch_size: 16 96 | comet_kiwi: 97 | batch_size: 16 98 | bleu: 99 | - name: gec 100 | subtasks: 101 | conll14.en: 102 | fm.de: 103 | cowsl2h.es: 104 | metrics: 105 | errant: 106 | tokenize_hypothesis: True 107 | ter: 108 | models: 109 | - name: TowerBase-7B-v0.1 110 | type: vllm 111 | arguments: 112 | model_dir: Unbabel/TowerBase-7B-v0.1 113 | n_gpus: 1 114 | max_tokens: 1024 115 | run_async: True 116 | batch_size: -1 117 | - name: llama2-7b-hf 118 | type: vllm 119 | arguments: 120 | model_dir: meta-llama/Llama-2-7b-hf 121 | n_gpus: 1 122 | max_tokens: 1024 123 | run_async: True 124 | batch_size: -1 125 | - name: llama2-13b-hf 126 | type: vllm 127 | arguments: 128 | model_dir: meta-llama/Llama-2-13b-hf 129 | n_gpus: 1 130 | max_tokens: 1024 131 | run_async: True 132 | batch_size: -1 133 | - name: alma-pretrained-13b 134 | type: vllm 135 | arguments: 136 | model_dir: haoranxu/ALMA-13B 137 | n_gpus: 1 138 | max_tokens: 1024 139 | run_async: True 140 | batch_size: -1 141 | - name: llama2-70b-hf 142 | type: vllm 143 | arguments: 144 | model_dir: meta-llama/Llama-2-70b-hf 145 | n_gpus: 2 146 | max_tokens: 1024 147 | run_async: True 148 | batch_size: -1 149 | - name: mixtral-8x7B-v0.1 150 | type: vllm 151 | arguments: 152 | model_dir: mistralai/Mixtral-8x7B-v0.1 153 | n_gpus: 2 154 | max_tokens: 1024 155 | run_async: True 156 | batch_size: -1 -------------------------------------------------------------------------------- /configs/tower_paper/5_shot_openai.yaml: -------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions/5_shot" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: 
"TowerEval-Data-v0.1/generations/5_shot" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/5_shot" 5 | tasks: 6 | - name: gec 7 | subtasks: 8 | conll14.en: 9 | fm.de: 10 | cowsl2h.es: 11 | metrics: 12 | errant: 13 | tokenize_hypothesis: True 14 | ter: 15 | models: 16 | - name: gpt-3.5-turbo 17 | type: open-ai 18 | arguments: 19 | model: "gpt-3.5-turbo" 20 | max_tokens: 1024 21 | debug: True 22 | retry_max_attempts: 200 23 | - name: gpt-4 24 | type: open-ai 25 | arguments: 26 | model: "gpt-4" 27 | max_tokens: 1024 28 | debug: True 29 | retry_max_attempts: 200 -------------------------------------------------------------------------------- /configs/tower_paper/tower_instruct_0_shot.yaml: -------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions/0_shot_tower_instruct" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: "TowerEval-Data-v0.1/generations/0_shot_tower_instruct" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/0_shot_tower_instruct" 5 | tasks: 6 | - name: mt 7 | subtasks: 8 | flores.en-de: 9 | flores.en-fr: 10 | flores.en-pt: 11 | flores.en-nl: 12 | flores.en-es: 13 | flores.en-it: 14 | flores.en-zh: 15 | gen_args: 16 | eval_args: 17 | metrics: 18 | chrf: 19 | bleu: 20 | tokenizer: zh 21 | comet: 22 | batch_size: 16 23 | flores.en-ko: 24 | gen_args: 25 | eval_args: 26 | metrics: 27 | chrf: 28 | bleu: 29 | tokenizer: ko-mecab 30 | comet: 31 | batch_size: 16 32 | flores.en-ru: 33 | flores.de-en: 34 | flores.fr-en: 35 | flores.pt-en: 36 | flores.nl-en: 37 | flores.es-en: 38 | flores.it-en: 39 | flores.zh-en: 40 | flores.ko-en: 41 | flores.ru-en: 42 | wmt23.en-de: 43 | wmt23.en-ru: 44 | wmt23.en-zh: 45 | eval_args: 46 | metrics: 47 | chrf: 48 | bleu: 49 | tokenizer: zh 50 | comet: 51 | batch_size: 16 52 | wmt23.de-en: 53 | wmt23.ru-en: 54 | wmt23.zh-en: 55 | tico19.en-es: 56 | tico19.en-fr: 57 | tico19.en-pt: 58 | tico19.en-ru: 59 | tico19.en-zh: 60 | gen_args: 61 | eval_args: 62 | metrics: 63 | chrf: 64 | bleu: 65 | tokenizer: zh 66 | comet: 67 | batch_size: 16 68 | metrics: 69 | chrf: 70 | bleu: 71 | comet: 72 | batch_size: 16 73 | - name: ape 74 | subtasks: 75 | nllb_3b_wmt23.de-en: 76 | nllb_3b_wmt23.en-de: 77 | nllb_3b_wmt23.en-zh: 78 | eval_args: 79 | metrics: 80 | ter: 81 | asian_support: True 82 | chrf: 83 | bleu: 84 | tokenizer: zh 85 | comet: 86 | batch_size: 16 87 | comet_kiwi: 88 | batch_size: 16 89 | nllb_3b_wmt23.zh-en: 90 | nllb_3b_wmt23.ru-en: 91 | nllb_3b_wmt23.en-ru: 92 | metrics: 93 | ter: 94 | asian_support: True 95 | chrf: 96 | comet: 97 | batch_size: 16 98 | comet_kiwi: 99 | batch_size: 16 100 | bleu: 101 | - name: ner 102 | subtasks: 103 | multiconer2023.en: 104 | multiconer2023.de: 105 | multiconer2023.fr: 106 | multiconer2023.es: 107 | multiconer2023.it: 108 | multiconer2023.pt: 109 | multiconer2023.zh: 110 | metrics: 111 | f1sequence: 112 | hypothesis_format: "text-tuple-list" 113 | reference_format: "jsonl" 114 | tokenize_hypothesis: False 115 | default_noent_tag: "O" 116 | valid_ner_tags: ["Person", "Location", "Group", "Product", "CreativeWorks", "Medical"] 117 | models: 118 | - name: TowerInstruct-7B-v0.2 119 | type: vllm 120 | arguments: 121 | model_dir: Unbabel/TowerInstruct-7B-v0.2 122 | n_gpus: 1 123 | max_tokens: 1024 124 | run_async: True 125 | batch_size: -1 126 | strip: False 127 | - name: TowerInstruct-13B-v0.2 128 | type: vllm 129 | arguments: 130 | model_dir: Unbabel/TowerInstruct-7B-v0.2 131 | n_gpus: 1 132 | max_tokens: 1024 133 | 
run_async: True 134 | batch_size: -1 135 | strip: False 136 | -------------------------------------------------------------------------------- /configs/tower_paper/tower_instruct_5_shot.yaml: -------------------------------------------------------------------------------- 1 | gen_data_dir: "TowerEval-Data-v0.1/data/instructions/5_shot_tower_instruct" 2 | eval_data_dir: "TowerEval-Data-v0.1/data/raw_data" 3 | gen_output_dir: "TowerEval-Data-v0.1/generations/5_shot_tower_instruct" 4 | eval_output_dir: "TowerEval-Data-v0.1/evaluations/5_shot_tower_instruct" 5 | tasks: 6 | - name: gec 7 | subtasks: 8 | conll14.en: 9 | fm.de: 10 | cowsl2h.es: 11 | metrics: 12 | errant: 13 | tokenize_hypothesis: True 14 | ter: 15 | models: 16 | - name: TowerInstruct-7B-v0.2 17 | type: vllm 18 | arguments: 19 | model_dir: Unbabel/TowerInstruct-7B-v0.2 20 | n_gpus: 1 21 | max_tokens: 1024 22 | run_async: True 23 | batch_size: -1 24 | strip: False 25 | - name: TowerInstruct-13B-v0.2 26 | type: vllm 27 | arguments: 28 | model_dir: Unbabel/TowerInstruct-7B-v0.2 29 | n_gpus: 1 30 | max_tokens: 1024 31 | run_async: True 32 | batch_size: -1 33 | strip: False -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "tower-eval" 3 | version = "0.1.0" 4 | description = "LLM generation and evaluation repository for MT and related tasks (e.g., APE, NER, GEC)." 5 | authors = [ 6 | "Amin Farajian ", 7 | "José Pombal " 8 | ] 9 | maintainers = [ 10 | "Amin Farajian ", 11 | "José Pombal " 12 | ] 13 | readme = "README.md" 14 | 15 | [tool.poetry.dependencies] 16 | python = ">=3.9.0,<3.13" 17 | sacrebleu = {extras = ["ko", "ja"], version = "^2.3.1"} 18 | mecab-ko = "^1.0.0" 19 | errant = ">=2.3.3" 20 | unbabel-comet = {git = "https://github.com/Unbabel/COMET.git", branch = "master"} 21 | loguru = "^0.7" 22 | tenacity = "^8.2" 23 | jinja2 = "^3.1" 24 | jupyter = "^1.0.0" 25 | seaborn = "^0.13.0" 26 | spacy = "^3.7.2" 27 | nltk = "^3.8.1" 28 | mosestokenizer = "^1.2.1" 29 | sentence-transformers = "^2.2.2" 30 | faiss-cpu = "^1.7.4" 31 | bleurt-pytorch = {git = "https://github.com/lucadiliello/bleurt-pytorch.git"} 32 | google-cloud-aiplatform = "^1.40.0" 33 | metricx = {git = "https://github.com/ricardorei/metricx.git"} 34 | vllm = "^0.6.4" 35 | anthropic = "^0.40.0" 36 | cohere = "^5.13.3" 37 | deepl = "^1.21.1" 38 | 39 | [tool.poetry.dev-dependencies] 40 | mock = ">=3.0.5,<4.0.0" 41 | coverage = ">=5.5.0,<6.0.0" 42 | 43 | [build-system] 44 | requires = ["poetry-core>=1.0.0"] 45 | build-backend = "poetry.core.masonry.api" 46 | -------------------------------------------------------------------------------- /run_paper_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OPENAI=false 4 | 5 | while (( "$#" )); do 6 | case "$1" in 7 | --openai) 8 | OPENAI=true 9 | shift 10 | ;; 11 | --) # end argument parsing 12 | shift 13 | break 14 | ;; 15 | -*|--*=) # unsupported flags 16 | echo "Error: Unsupported flag $1" >&2 17 | exit 1 18 | ;; 19 | *) # preserve positional arguments 20 | PARAMS="$PARAMS $1" 21 | shift 22 | ;; 23 | esac 24 | done 25 | 26 | if [ "$OPENAI" = true ] ; then 27 | configs=(configs/tower_paper/0_shot_openai.yaml configs/tower_paper/5_shot_generic_models.yaml configs/tower_paper/5_shot_openai.yaml configs/tower_paper/tower_instruct_0_shot.yaml configs/tower_paper/tower_instruct_5_shot.yaml) 28 | echo "Running 
Tower paper benchmark including open-ai models." 29 | else 30 | configs=(configs/tower_paper/5_shot_generic_models.yaml configs/tower_paper/tower_instruct_0_shot.yaml configs/tower_paper/tower_instruct_5_shot.yaml) 31 | echo "Running Tower paper benchmark for open models only." 32 | fi 33 | 34 | for config in "${configs[@]}"; do 35 | echo "Running $config" 36 | python -m tower_eval.cli gen-eval --config $config 37 | done -------------------------------------------------------------------------------- /run_reference_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | echo "Running WMT23." 3 | python -m tower_eval.cli gen-eval --config configs/reference_benchmark/0_shot_wmt23.yaml 4 | 5 | echo "Running standard multilingual benchmarks." 6 | python -m tower_eval.cli lm_eval --config configs/reference_benchmark/standard_benchmarks.yaml -------------------------------------------------------------------------------- /tower_eval/__init__.py: -------------------------------------------------------------------------------- 1 | __path__ = __import__("pkgutil").extend_path(__path__, __name__) 2 | -------------------------------------------------------------------------------- /tower_eval/error_span_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | DET_REGEX = r"\"(?P<text>.+?)\" - (?P<severity>minor|major?)" 4 | TAG_REGEX = r"(?P<severity><minor>|<major>)(?P<text>.+?)(</minor>|</major>)" 5 | 6 | 7 | def tag_to_annotation( 8 | generation: str, 9 | mt: str, 10 | ) -> dict[str, str | list[dict[str, str | int]]]: 11 | """Converts text like: "This is an <minor>annotation<\\minor>" to a dictionary like: 12 | {"mt": "This is an annotation", "annotations": [{"start": 11, "end": 21, "severity": "minor"}]} 13 | """ 14 | seen_tags = 0 15 | matches = list(re.finditer(TAG_REGEX, generation)) 16 | annotations = [] 17 | for m in matches: 18 | text = m.group("text") 19 | # if text is not in mt, reject annotation 20 | if text in mt: 21 | severity = m.group("severity")[1:-1] # remove < and > 22 | # for every tag we have seen, there are 15 characters that we need to subtract from the start index 23 | start = m.start() - seen_tags * 15 24 | end = start + len(text) 25 | annotations.append( 26 | {"start": start, "end": end, "severity": severity, "text": text} 27 | ) 28 | seen_tags += 1 29 | 30 | return annotations 31 | 32 | 33 | def det_to_annotation( 34 | generation: str, mt: str 35 | ) -> dict[str, str | list[dict[str, str | int]]]: 36 | """Converts text like: "This is an <minor>annotation<\\minor>" to a dictionary like: 37 | {"mt": "This is an annotation", "annotations": [{"start": 11, "end": 21, "severity": "minor"}]} 38 | """ 39 | annotations = [] 40 | matches = list(re.finditer(DET_REGEX, generation)) 41 | for m in matches: 42 | text = m.group("text") 43 | # if text is not in mt, reject annotation 44 | if text in mt: 45 | severity = m.group("severity") 46 | # check if flagged text is already in annotations list 47 | # if not, add the first match; if present n times, add the nth + 1 match 48 | equal_matches = list(re.finditer(re.escape(text), mt)) 49 | i_to_append = 0 50 | for a in annotations: 51 | if a["text"] == text: 52 | i_to_append += 1 53 | # for some reason the model flagged the same thing twice; do not append annotation 54 | if i_to_append >= len(equal_matches): 55 | continue 56 | match_to_append = equal_matches[i_to_append] 57 | start, end = match_to_append.span() 58 | annotations.append( 59 | {"start": start, "end": end, "severity": severity, "text": text}
60 | ) 61 | 62 | return annotations 63 | -------------------------------------------------------------------------------- /tower_eval/fewshot_retrieval_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | from tqdm import tqdm 4 | import faiss 5 | import numpy as np 6 | import torch 7 | from sentence_transformers import SentenceTransformer 8 | 9 | from tower_eval.utils import load_data_to_records 10 | 11 | 12 | def random_retrieval( 13 | examples: list[dict[str, str]], 14 | n_shots: int, 15 | total_examples: int, 16 | **kwargs, 17 | ) -> list[list[dict[str, str]]]: 18 | examples_idxs = np.random.choice(len(examples), total_examples) 19 | examples = [examples[i] for i in examples_idxs] 20 | 21 | # split_idxs are the indices where we should split the examples 22 | # If n_shots = 2, then split_idxs = [2, 4, 6] 23 | # meaning the first 2 examples go for instance 1, the next 2 for instance 2, etc. 24 | examples_per_instance = [] 25 | for i in range(0, total_examples - n_shots + 1, n_shots): 26 | examples_per_instance.append(examples[i : i + n_shots]) 27 | 28 | return [list(examples) for examples in examples_per_instance] 29 | 30 | 31 | def ordered_retrieval( 32 | examples: list[dict[str, str]], 33 | n_shots: int, 34 | total_examples: int, 35 | **kwargs, 36 | ) -> list[list[dict[str, str]]]: 37 | examples_per_instance = [] 38 | for _ in range(0, total_examples - n_shots + 1, n_shots): 39 | examples_per_instance.append(examples[:n_shots]) 40 | 41 | return [list(examples) for examples in examples_per_instance] 42 | 43 | 44 | def get_similar_examples(n_examples, encoder, index, seed_sentence): 45 | # encode the seed sentence 46 | question_embedding = encoder.encode(seed_sentence, show_progress_bar=False) 47 | 48 | # search for similar examples 49 | dists, I = index.search( 50 | torch.FloatTensor(question_embedding).unsqueeze(0), n_examples 51 | ) 52 | idxs = I[0] 53 | 54 | return idxs, dists 55 | 56 | 57 | def similarity_retrieval( 58 | test_set: list[dict[str, str]], 59 | examples: list[dict[str, str]], 60 | n_shots: int, 61 | similarity_ordering: str = "descending", 62 | **kwargs, 63 | ) -> list[list[dict[str, str]]]: 64 | # determines whether most similar example will come out at the beginning of the prompt, at the end, or randomly 65 | assert n_shots > 0, "n_shots should be greater than 0." 66 | assert similarity_ordering in [ 67 | "ascending", 68 | "descending", 69 | "random", 70 | ], "similarity_ordering should be either 'ascending' or 'descending', or 'random'." 
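    # This retrieval step assumes a FAISS index (saved as "knn.index" under kwargs["index_path"],
    # e.g. produced by the `index` task / configs/examples/index.yaml) whose row order matches the
    # order of `examples`, since every retrieved id is mapped back via examples[idx] below.
    # A minimal, hypothetical sketch of building such an index from LaBSE embeddings (index type
    # and paths are illustrative, not the repository's exact procedure):
    #
    #     encoder = SentenceTransformer("sentence-transformers/LaBSE")
    #     embeddings = encoder.encode([e["src"] for e in examples])  # float32 array, shape (n, d)
    #     index = faiss.IndexFlatL2(embeddings.shape[1])
    #     index.add(embeddings)
    #     faiss.write_index(index, "knn.index")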
71 | encoder = SentenceTransformer( 72 | "sentence-transformers/LaBSE", 73 | # device="cpu" 74 | ) 75 | 76 | # load faiss index 77 | index_file = os.path.join(kwargs["index_path"], "knn.index") 78 | index = faiss.read_index(index_file) 79 | 80 | selected_examples = [] 81 | # i = 0 82 | for row in tqdm(test_set): 83 | selected_examples_per_instance = [] 84 | # get similar examples 85 | idxs, dists = get_similar_examples(n_shots, encoder, index, row["src"]) 86 | # by default, the most similar examples will be at the beginning of the prompt 87 | for idx in idxs: 88 | selected_examples_per_instance.append(examples[idx]) 89 | if similarity_ordering == "ascending": 90 | selected_examples_per_instance = selected_examples_per_instance[::-1] 91 | elif similarity_ordering == "random": 92 | np.random.shuffle(selected_examples_per_instance) 93 | selected_examples.append(selected_examples_per_instance) 94 | return selected_examples 95 | 96 | 97 | def force_label_balance_retrieval( 98 | examples: list[dict[str, str]], 99 | n_shots: int, 100 | total_examples: int, 101 | task: str, 102 | n_positive: int = 1, 103 | retrieval: str = "random", 104 | **kwargs, 105 | ) -> list[list[dict[str, str]]]: 106 | """ 107 | Gets few shot examples for APE such that, for each instance, a minimum of positive examples (for which no PE is required) are included in the prompt. 108 | Once example pools are separated, the remaining examples are sampled randomly or ordered, depending on method choice. 109 | """ 110 | assert n_positive < n_shots, "n_positive should be less than or equal to n_shots." 111 | n_negative = n_shots - n_positive 112 | total_positive_examples = n_positive * (total_examples // n_shots) 113 | total_negative_examples = n_negative * (total_examples // n_shots) 114 | # iterate over examples to split them between positive and negative ones 115 | positive_examples, negative_examples = get_positive_negative_examples_from_task( 116 | examples, task 117 | ) 118 | # sample positive examples 119 | _retrieval_func = get_fewshot_retrieval_method(retrieval) 120 | positive_examples = _retrieval_func( 121 | positive_examples, n_positive, total_positive_examples 122 | ) 123 | negative_examples = _retrieval_func( 124 | negative_examples, n_negative, total_negative_examples 125 | ) 126 | # combine positive and negative examples 127 | out_examples = [] 128 | for positive, negative in zip(positive_examples, negative_examples): 129 | joined_examples = positive + negative 130 | np.random.shuffle(joined_examples) 131 | # shuffle the joined list to avoid having all positive examples at the beginning 132 | out_examples.append(joined_examples) 133 | return out_examples 134 | 135 | 136 | def get_positive_negative_examples_from_task(examples: list[dict[str, str]], task: str): 137 | """ 138 | Gets positive and negative examples from a list of examples, given a task. 
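    An example is considered positive according to a task-specific condition: for "ape", when the
    translation needs no post-editing (e["mt"] == e["ref"]); for "paraphrase_identification" and
    "word_sense_disambiguation", when e["answer"] == "Yes"; and for "gec", when the source needs no
    correction (e["src"] == e["ref"]). Any other task raises a NotImplementedError.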
139 | """ 140 | positive_examples = [] 141 | negative_examples = [] 142 | # each "e" is an example inside examples, which corresponds to a row in the raw data's dataframe 143 | positive_label_condition: Callable[[dict[str, str]], bool] = None 144 | if task in ["ape"]: 145 | positive_label_condition = lambda e: e["mt"] == e["ref"] 146 | elif task in ["paraphrase_identification", "word_sense_disambiguation"]: 147 | positive_label_condition = lambda e: e["answer"] == "Yes" 148 | elif task in ["gec"]: 149 | positive_label_condition = lambda e: e["src"] == e["ref"] 150 | else: 151 | raise NotImplementedError( 152 | f"Retrieval with forced label balance is not implemented for task {task}." 153 | ) 154 | for e in examples: 155 | if positive_label_condition(e): 156 | positive_examples.append(e) 157 | else: 158 | negative_examples.append(e) 159 | return positive_examples, negative_examples 160 | 161 | 162 | def get_fewshot_retrieval_method(method: str) -> Callable: 163 | """Returns a few shot retrieval function, given a method name. Handles exception when method name is not implemented""" 164 | available_fewshot_retrieval_methods = { 165 | "random": random_retrieval, 166 | "ordered": ordered_retrieval, 167 | "force_label_balance": force_label_balance_retrieval, 168 | "similarity": similarity_retrieval, 169 | } 170 | if method is not None: 171 | try: 172 | fewshot_retrieval_method = available_fewshot_retrieval_methods[method] 173 | except KeyError as e: 174 | e( 175 | f"{method} fewshot retrieval method is not implemented. Please choose from {list(available_fewshot_retrieval_methods.keys())}." 176 | ) 177 | else: 178 | fewshot_retrieval_method = None 179 | return fewshot_retrieval_method 180 | 181 | 182 | def load_few_shot_data( 183 | test_set: list[dict[str, str]], 184 | datastore_data_path: str, 185 | n_fewshots: int, 186 | total_shots: int, 187 | fewshot_retrieval_method: str, 188 | task: str, 189 | datastore_index_path: str = None, 190 | fewshot_retrieval_args: dict = {}, 191 | ) -> list[list[dict[str, str]]]: 192 | """ 193 | Loads fewshot data from json or txt file and returns a list of fewshot examples, where each item is a list of examples 194 | pertaining to a single data instance. 
195 | """ 196 | # raw data file must be a jsonl 197 | datastore_data = load_data_to_records(datastore_data_path) 198 | # choose method of fewshot retrieval 199 | _fewshot_retrieval_func = get_fewshot_retrieval_method(fewshot_retrieval_method) 200 | fewshot_retrieval_args["task"] = task 201 | fewshot_retrieval_args["index_path"] = datastore_index_path 202 | fewshot_examples_list = _fewshot_retrieval_func( 203 | test_set=test_set, 204 | examples=datastore_data, 205 | n_shots=n_fewshots, 206 | total_examples=total_shots, 207 | **fewshot_retrieval_args, 208 | ) 209 | 210 | return fewshot_examples_list 211 | -------------------------------------------------------------------------------- /tower_eval/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from tower_eval.metrics.accuracy.metric import ACCURACY 2 | from tower_eval.metrics.bleu.metric import BLEU 3 | from tower_eval.metrics.bleurt.metric import BLEURT 4 | from tower_eval.metrics.chrf.metric import CHRF 5 | from tower_eval.metrics.comet.metric import COMET 6 | from tower_eval.metrics.comet_kiwi.metric import COMETKiwi 7 | from tower_eval.metrics.comet_kiwi_23_xxl.metric import COMETKiwi23XXL 8 | from tower_eval.metrics.errant.metric import ERRANT 9 | from tower_eval.metrics.error_span_detection_f1.metric import ErrorSpanDetectionF1 10 | from tower_eval.metrics.error_span_detection_precision.metric import ( 11 | ErrorSpanDetectionPrecision, 12 | ) 13 | from tower_eval.metrics.error_span_detection_recall.metric import ( 14 | ErrorSpanDetectionRecall, 15 | ) 16 | from tower_eval.metrics.f1.metric import F1 17 | from tower_eval.metrics.f1_sequence.metric import F1SEQUENCE 18 | from tower_eval.metrics.metricx.metric import MetricX 19 | from tower_eval.metrics.metricx_24.metric import ( 20 | MetricX_24_Large, 21 | MetricX_24_QE_Large, 22 | MetricX_24_QE_XL, 23 | MetricX_24_QE_XXL, 24 | MetricX_24_XL, 25 | MetricX_24_XXL, 26 | ) 27 | from tower_eval.metrics.metricx_large.metric import MetricXLarge 28 | from tower_eval.metrics.metricx_qe.metric import MetricXQE 29 | from tower_eval.metrics.metricx_qe_large.metric import MetricXQELarge 30 | from tower_eval.metrics.metricx_qe_xxl.metric import MetricXQEXXL 31 | from tower_eval.metrics.metricx_xxl.metric import MetricXXXL 32 | from tower_eval.metrics.pearson.metric import PEARSON 33 | from tower_eval.metrics.perplexity.metric import Perplexity 34 | from tower_eval.metrics.spearman.metric import SPEARMAN 35 | from tower_eval.metrics.ter.metric import TER 36 | from tower_eval.metrics.xcomet_qe_xl.metric import XCOMETQEXL 37 | from tower_eval.metrics.xcomet_qe_xxl.metric import XCOMETQEXXL 38 | from tower_eval.metrics.xcomet_xl.metric import XCOMETXL 39 | from tower_eval.metrics.xcomet_xxl.metric import XCOMETXXL 40 | from tower_eval.metrics.xml_chrf.metric import XML_CHRF 41 | from tower_eval.metrics.xml_match.metric import XML_MATCH 42 | 43 | __all__ = [ 44 | TER, 45 | BLEU, 46 | XCOMETXL, 47 | XCOMETQEXL, 48 | XCOMETXXL, 49 | XCOMETQEXXL, 50 | COMET, 51 | COMETKiwi, 52 | COMETKiwi23XXL, 53 | BLEURT, 54 | CHRF, 55 | ERRANT, 56 | F1, 57 | F1SEQUENCE, 58 | ACCURACY, 59 | PEARSON, 60 | SPEARMAN, 61 | ErrorSpanDetectionF1, 62 | ErrorSpanDetectionRecall, 63 | ErrorSpanDetectionPrecision, 64 | Perplexity, 65 | MetricXLarge, 66 | MetricXQELarge, 67 | MetricX, 68 | MetricXQE, 69 | MetricXQEXXL, 70 | MetricXXXL, 71 | MetricX_24_Large, 72 | MetricX_24_XL, 73 | MetricX_24_XXL, 74 | MetricX_24_QE_Large, 75 | MetricX_24_QE_XL, 76 | MetricX_24_QE_XXL, 77 | 
XML_CHRF, 78 | XML_MATCH, 79 | ] 80 | 81 | 82 | available_metrics = {metric.metric_name(): metric for metric in __all__} 83 | -------------------------------------------------------------------------------- /tower_eval/metrics/accuracy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/accuracy/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/accuracy/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | 4 | from loguru import logger 5 | from sklearn.metrics import accuracy_score 6 | 7 | from tower_eval.metrics.accuracy.result import AccuracyResult 8 | from tower_eval.metrics.base.metrics_handler import Metric 9 | from tower_eval.metrics.base.result_handler import MetricResult 10 | from tower_eval.utils import text_to_label 11 | 12 | 13 | class ACCURACY(Metric): 14 | def __init__( 15 | self, 16 | **kwargs, 17 | ) -> None: 18 | """Initializes an instance of the Accuracy metric. 19 | 20 | Args: 21 | source_type (str): The type of source data. Either "categorical" or "text". 22 | source_labels (List[str]): A list of labels for the source data. Required if source_type is "text". 23 | **kwargs: Additional keyword arguments. Must include "hypothesis" and "references". 24 | 25 | Raises: 26 | AssertionError: If multiple references are provided. 27 | 28 | Returns: 29 | None 30 | """ 31 | super().__init__(**kwargs) 32 | 33 | def run(self, hypothesis_path, gold_data_path) -> dict: 34 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 35 | reference_lines = gold_data["ref"] 36 | gold_labels = [] 37 | predicted_labels = [] 38 | is_random_count = 0 39 | # if hypothesis is already numbered 40 | for ref_line, hyp_line in zip(reference_lines, hypotheses): 41 | # reference is always assumed to be in categorical format; i.e., [0,1,2,3,...] 42 | gold_labels.append(text_to_label(ref_line, "categorical")) 43 | label, is_random = text_to_label( 44 | hyp_line, 45 | self.source_type, 46 | self.source_labels, 47 | return_is_random=True, 48 | ) 49 | is_random_count += 1 if is_random else 0 50 | predicted_labels.append(label) 51 | # warn user that some labels were randomly assigned 52 | if is_random_count > 0: 53 | pct_random = (is_random_count / len(gold_labels)) * 100 54 | logger.opt(colors=True).warning( 55 | f"{is_random_count} ({pct_random:.2f}% of total) labels did not correspond to any label in source_labels, so a random label was a assigned." 56 | ) 57 | 58 | result = self.evaluate( 59 | gold_labels=gold_labels, predicted_labels=predicted_labels 60 | ) 61 | result.print_result(self.metric_name()) 62 | return result.format_result(self.metric_name()) 63 | 64 | def evaluate(self, gold_labels, predicted_labels) -> AccuracyResult: 65 | """ 66 | Evaluate function receives the gold labels as well as the predicted ones and returns the Accuracy score of the predictions. 67 | The accuracy is calculate by calling the corresponding function in Scikit Learn library 68 | 69 | Args: 70 | gold_labels: The gold labels. 71 | predicted_labels: The predicted labels. 72 | 73 | Returns: 74 | AccuracyResult: The accuracy score. 
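        Example (illustrative): with gold_labels=[0, 1, 1, 2] and predicted_labels=[0, 1, 0, 2],
        three of the four labels match, so the returned score is 0.75.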
75 | """ 76 | score = accuracy_score(y_true=gold_labels, y_pred=predicted_labels) 77 | result = AccuracyResult(score) 78 | return result 79 | 80 | def process_result(self, result) -> MetricResult: 81 | pass 82 | 83 | @staticmethod 84 | def metric_name(): 85 | """Returns the name of the metric. 86 | 87 | Returns: 88 | str: The name of the metric. 89 | """ 90 | return "accuracy" 91 | -------------------------------------------------------------------------------- /tower_eval/metrics/accuracy/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class AccuracyResult(MetricResult): 6 | """ 7 | Accuracy Result Handler. 8 | """ 9 | 10 | def __init__( 11 | self, 12 | result: float, 13 | ) -> None: 14 | super().__init__(result) 15 | -------------------------------------------------------------------------------- /tower_eval/metrics/base/comet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from comet import download_model, load_from_checkpoint 3 | 4 | from tower_eval.metrics.base.metrics_handler import Metric 5 | from tower_eval.metrics.base.result_handler import MetricResult 6 | 7 | 8 | class BaseCOMETResult(MetricResult): 9 | """ 10 | COMET Result Handler. 11 | """ 12 | 13 | def __init__( 14 | self, 15 | result: float, 16 | ) -> None: 17 | super().__init__(result) 18 | 19 | 20 | class BaseCOMET(Metric): 21 | def __init__(self, model: str, **kwargs) -> None: 22 | super().__init__(**kwargs) 23 | model_path = download_model(model) 24 | self.model = load_from_checkpoint(model_path) 25 | self.model.eval() 26 | 27 | def load_gold_data(self, gold_data): 28 | pass 29 | 30 | def make_samples( 31 | self, sources: list[str], hypotheses: list[str], references: list[str] = None 32 | ): 33 | pass 34 | 35 | def run( 36 | self, 37 | hypothesis_path, 38 | gold_data_path, 39 | gpus: int = 1, 40 | batch_size: int = 16, 41 | **kwargs 42 | ) -> dict: 43 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 44 | references, sources = self.load_gold_data(gold_data) 45 | result = self.evaluate(hypotheses, references, sources, gpus, batch_size) 46 | result.print_result(self.metric_name()) 47 | return result.format_result(self.metric_name()) 48 | 49 | def evaluate( 50 | self, hypotheses: list, references: list, sources: list, gpus, batch_size 51 | ) -> BaseCOMETResult: 52 | """ 53 | Evaluate function receives the hypotheses and the references and returns a COMETResult object. 54 | 55 | :param hypotheses: List of the MT outputs (sentences). 56 | :param references: List of the reference sentences. 
57 | :param sources: List of source sentences 58 | """ 59 | samples = self.make_samples(sources, hypotheses, references) 60 | 61 | outputs = self.model.predict( 62 | samples=samples, 63 | batch_size=batch_size, 64 | gpus=gpus, 65 | accelerator="auto", 66 | ) 67 | system_score, segments_scores = outputs.system_score, outputs.scores 68 | 69 | comet_result = BaseCOMETResult( 70 | { 71 | "system_score": system_score, 72 | "segments_scores": segments_scores, 73 | } 74 | ) 75 | return comet_result 76 | 77 | def process_result(self, result) -> MetricResult: 78 | pass 79 | 80 | @staticmethod 81 | def metric_name(): 82 | pass 83 | 84 | 85 | class RefCOMET(BaseCOMET): 86 | def __init__(self, model: str, **kwargs) -> None: 87 | super().__init__(model, **kwargs) 88 | 89 | def load_gold_data(self, gold_data): 90 | references, sources = gold_data["ref"], gold_data["src"] 91 | return references, sources 92 | 93 | def make_samples( 94 | self, sources: list[str], hypotheses: list[str], references: list[str] 95 | ): 96 | samples = {"src": sources, "mt": hypotheses, "ref": references} 97 | samples = [dict(zip(samples, t)) for t in zip(*samples.values())] 98 | return samples 99 | 100 | 101 | class QECOMET(BaseCOMET): 102 | def __init__(self, model: str, **kwargs) -> None: 103 | super().__init__(model, **kwargs) 104 | 105 | def load_gold_data(self, gold_data): 106 | references, sources = None, gold_data["src"] 107 | return references, sources 108 | 109 | def make_samples( 110 | self, sources: list[str], hypotheses: list[str], references: list[str] = None 111 | ): 112 | samples = {"src": sources, "mt": hypotheses} 113 | samples = [dict(zip(samples, t)) for t in zip(*samples.values())] 114 | return samples 115 | -------------------------------------------------------------------------------- /tower_eval/metrics/base/error_span_detection.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pathlib import Path 3 | 4 | from tower_eval.error_span_utils import det_to_annotation, tag_to_annotation 5 | from tower_eval.metrics.base.metrics_handler import Metric 6 | from tower_eval.metrics.base.result_handler import MetricResult 7 | from tower_eval.utils import load_jsonl_file, read_lines 8 | 9 | 10 | class ErrorSpanDetectionMetric(Metric): 11 | def __init__(self, key, **kwargs) -> None: 12 | super().__init__(**kwargs) 13 | self.key = key 14 | 15 | def run( 16 | self, 17 | hypothesis_path, 18 | gold_data_path, 19 | severity_mismatch_penalty: float = 0.5, 20 | hyp_type: str = "jsonl", 21 | **kwargs, 22 | ) -> dict: 23 | hypotheses, gold_data = self._handle_inputs( 24 | hypothesis_path, gold_data_path, hyp_type=hyp_type 25 | ) 26 | reference_list = gold_data["ref"] 27 | result = ErrorSpanDetectionResult( 28 | self.evaluate( 29 | hypotheses, 30 | reference_list, 31 | severity_mismatch_penalty, 32 | )[self.key] 33 | ) 34 | result.print_result(self.metric_name()) 35 | 36 | return result.format_result(self.metric_name()) 37 | 38 | def evaluate( 39 | self, 40 | hypotheses: list, 41 | references: list, 42 | severity_mismatch_penalty: float = 0.5, 43 | ) -> dict: 44 | """ 45 | Computes the Error Span Detection metric. 
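        In broad strokes (see the implementation below): gold and predicted annotations are expanded
        into per-character major/minor counts over each hypothesis ("mt"); a flagged character counts
        as a true positive when both sides mark it, scaled by severity_mismatch_penalty (default 0.5)
        when the severities disagree; precision, recall and F1 are then computed over the totals of
        flagged characters. The returned dict has keys "f1", "precision" and "recall", and the
        metric's `key` attribute selects which one run() reports.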
46 | """ 47 | system_preds = self.load_annotations_from_list(hypotheses) 48 | gold_labels = self.load_annotations_from_list(references) 49 | 50 | tp = 0 51 | tn = 0 52 | fp = 0 53 | total_sys = 0 54 | total_gold = 0 55 | for segid in gold_labels: 56 | for ( 57 | character_gold_major, 58 | character_sys_major, 59 | character_gold_minor, 60 | character_sys_minor, 61 | ) in zip( 62 | gold_labels[segid]["major"], 63 | system_preds[segid]["major"], 64 | gold_labels[segid]["minor"], 65 | system_preds[segid]["minor"], 66 | ): 67 | if character_gold_major != 0 or character_gold_minor != 0: 68 | total_gold += 1 69 | if character_sys_major != 0 or character_sys_minor != 0: 70 | total_sys += 1 71 | if character_gold_major == 0 and character_gold_minor == 0: 72 | if character_sys_major == 0 and character_sys_minor == 0: 73 | tn += 1 74 | else: 75 | # fp+=(character_sys_major + character_sys_minor) 76 | fp += 1 77 | else: 78 | if character_gold_major > 0 and character_gold_minor == 0: 79 | if character_sys_major > 0: 80 | tp += 1 81 | elif character_sys_minor > 0: 82 | tp += severity_mismatch_penalty 83 | elif character_gold_minor > 0 and character_gold_major == 0: 84 | if character_sys_minor > 0: 85 | tp += 1 86 | elif character_sys_major > 0: 87 | tp += severity_mismatch_penalty 88 | elif character_gold_minor > 0 and character_gold_major > 0: 89 | if character_sys_minor > 0 or character_sys_major > 0: 90 | tp += 1 91 | 92 | precision = tp / (total_sys) 93 | recall = tp / (total_gold) 94 | f1 = 2 * precision * recall / (precision + recall) 95 | 96 | return { 97 | "f1": f1, 98 | "precision": precision, 99 | "recall": recall, 100 | } 101 | 102 | def load_annotations_from_list( 103 | self, 104 | segments: list, 105 | ) -> dict: 106 | """Converts list of annotation dictionaries into format that is amenable for evaluation fnctn. 107 | 108 | The file should contain lines with the following format: 109 | {"mt": , "annotations": [{"start": , "end": , "severity": 110 | 111 | Args: 112 | annotations: list of dictionaries with mt and respective annotations. 113 | Returns: 114 | a dictionary mapping document id's to a list of annotations. 115 | """ 116 | out_dict = {} 117 | for i, line in enumerate(segments): 118 | mt = line["mt"] 119 | 120 | seg_id = i 121 | out_dict[seg_id] = {} 122 | out_dict[seg_id]["major"] = [0] * (len(mt) + 1) 123 | out_dict[seg_id]["minor"] = [0] * (len(mt) + 1) 124 | for annotation in line["annotations"]: 125 | s = int(annotation["start"]) 126 | e = int(annotation["end"]) 127 | t = annotation["severity"] 128 | if e > len(mt): 129 | e = len(mt) 130 | 131 | if s != -1 and e != -1: 132 | if s == e: 133 | out_dict[seg_id][t][s] += 1 134 | else: 135 | i = s 136 | while i < e: 137 | out_dict[seg_id][t][i] += 1 138 | i += 1 139 | return out_dict 140 | 141 | def _handle_inputs( 142 | self, 143 | hypotheses: Path, 144 | references: Path, 145 | hyp_type: str, 146 | ) -> tuple: 147 | """ 148 | Function to handle input files. 149 | All inputs will be returned as list of strings. 150 | 151 | :param hypotheses: either the handler to the file storing the hypotheses. 152 | :param references: either the handler to the file storing the refereneces. 153 | :param sources: either the handler to the file storing the source sentences. 
154 | 155 | :return: 156 | - Tuple with hypotheses, references and kwargs 157 | """ 158 | hypotheses_list = [] 159 | references_list = load_jsonl_file(references) 160 | mts = [ref["mt"] for ref in references_list] 161 | # if handling existing jsonl file 162 | if hyp_type == "jsonl": 163 | hypotheses_list_no_mt = load_jsonl_file(hypotheses) 164 | for i, mt in enumerate(mts): 165 | hypotheses_list_no_mt[i].update({"mt": mt}) 166 | hypotheses_list.append(hypotheses_list_no_mt[i]) 167 | # if handling model generations 168 | elif hyp_type in ["tag", "det"]: 169 | generations = read_lines(hypotheses, unescape_newline=True) 170 | hypotheses_list = self.parse_generations(generations, mts, hyp_type) 171 | 172 | assert len(hypotheses) == len( 173 | references_list 174 | ), f"The number of hypotheses {len(hypotheses)} and rows in the gold data {len(references_list)} should be the same." 175 | 176 | return hypotheses_list, references_list 177 | 178 | def process_result(self): 179 | pass 180 | 181 | @classmethod 182 | def parse_generations( 183 | self, 184 | generations: list[str], 185 | mts: list[str], 186 | hyp_type: str, 187 | ) -> list[dict[str, list[dict[str, str | int]]]]: 188 | """Takes a list of model generations (strings) and converts them into a list of records, each containing 189 | information like an error span raw test set. 190 | """ 191 | if hyp_type == "tag": 192 | _parsing_func = tag_to_annotation 193 | elif hyp_type == "det": 194 | _parsing_func = det_to_annotation 195 | 196 | return [ 197 | {"mt": mt, "annotations": _parsing_func(g, mt)} 198 | for g, mt in zip(generations, mts) 199 | ] 200 | 201 | 202 | class ErrorSpanDetectionResult(MetricResult): 203 | """ 204 | Error Span Detection Recall Result Handler. 205 | """ 206 | 207 | def __init__( 208 | self, 209 | result: list, 210 | ) -> None: 211 | super().__init__(result) 212 | -------------------------------------------------------------------------------- /tower_eval/metrics/base/metrics_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC, abstractmethod 3 | from pathlib import Path 4 | 5 | from tower_eval.metrics.base.result_handler import MetricResult 6 | from tower_eval.utils import list_to_dict, load_jsonl_file, read_lines 7 | 8 | 9 | class Metric(ABC): 10 | """Abstract class defining a shared interface for all the Metrics""" 11 | 12 | def __init__(self, **kwargs) -> None: 13 | self.kwargs = kwargs 14 | 15 | def run(self, hypothesis_path: Path = None, references_path: Path = None, **kwargs): 16 | """ 17 | The runner function that performs all the necessary checks on the files, 18 | applies the needed preprocessing (like lowercasing, tokenization, etc), 19 | and calls the evaluate method of the class to get the scores. 20 | """ 21 | pass 22 | 23 | @abstractmethod 24 | def evaluate(self, hypotheses: list, references: list, **kwargs) -> MetricResult: 25 | """ 26 | Evaluate function receives the hypotheses and the reference files and returns a MetricResult object. 27 | 28 | :param hypotheses: the path to the hypothese file. 29 | :param references: the path to the reference file. 30 | :return: MetricResult object. 31 | """ 32 | pass 33 | 34 | @abstractmethod 35 | def process_result(self, result) -> MetricResult: 36 | """ 37 | Process the result to be ready and complient with the MetricResult format. 38 | 39 | :param result: the raw result produced by the metric 40 | :return: MetricResult object. 
41 | """ 42 | pass 43 | 44 | @staticmethod 45 | @abstractmethod 46 | def metric_name() -> None: 47 | """Metric name used to address the metric via cli.""" 48 | pass 49 | 50 | @staticmethod 51 | def _handle_inputs( 52 | hypotheses: Path, 53 | data_path: Path, 54 | ) -> tuple: 55 | """ 56 | Function to handle input files. 57 | All inputs will be returned as list of strings. 58 | 59 | :param hypotheses: either the handler to the file storing the hypotheses. 60 | :param references: either the handler to the file storing the refereneces. 61 | :param sources: either the handler to the file storing the source sentences. 62 | 63 | :return: 64 | - Tuple with hypotheses, references and kwargs 65 | - If sources not None, Tuple with hypotheses, references, sources and kwargs 66 | """ 67 | hypotheses = read_lines(hypotheses, unescape_newline=True) 68 | # gold data keys depend on the task; e.g., for MT, it will include "ref", for APE "pe" 69 | gold_data = load_jsonl_file(data_path) 70 | assert len(hypotheses) == len( 71 | gold_data 72 | ), f"The number of hypotheses ({len(hypotheses)}) and rows in the gold data ({len(gold_data)}) should be the same." 73 | gold_data = list_to_dict(gold_data) 74 | 75 | return hypotheses, gold_data 76 | -------------------------------------------------------------------------------- /tower_eval/metrics/base/metricx.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from datasets import Dataset 4 | from metricx23 import models 5 | from transformers import AutoTokenizer 6 | 7 | from tower_eval.metrics.base.metrics_handler import Metric 8 | from tower_eval.metrics.base.result_handler import MetricResult 9 | 10 | 11 | class BaseMetricXResult(MetricResult): 12 | """ 13 | MetricX Result Handler. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | result: float, 19 | ) -> None: 20 | super().__init__(result) 21 | 22 | 23 | class BaseMetricX(Metric): 24 | def __init__( 25 | self, tokenizer: str, model: str, max_input_length: int, **kwargs 26 | ) -> None: 27 | if torch.cuda.is_available(): 28 | # This refers to the first visible GPU 29 | self.device = torch.device("cuda") 30 | else: 31 | self.device = torch.device("cpu") 32 | super().__init__(**kwargs) 33 | self.max_input_length = max_input_length 34 | self.model = models.MT5ForRegression.from_pretrained(model) 35 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) 36 | self.model.to(self.device) 37 | self.model.eval() 38 | 39 | @staticmethod 40 | def load_gold_data(gold_data): 41 | pass 42 | 43 | @staticmethod 44 | def make_samples( 45 | sources: list[str], hypotheses: list[str], references: list[str] = None 46 | ): 47 | pass 48 | 49 | @staticmethod 50 | def _make_input(example): 51 | pass 52 | 53 | def run(self, hypothesis_path, gold_data_path, **kwargs) -> dict: 54 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 55 | references, sources = self.load_gold_data(gold_data) 56 | result = self.evaluate(sources, hypotheses, references) 57 | result.print_result(self.metric_name()) 58 | return result.format_result(self.metric_name()) 59 | 60 | def process_result(self, result) -> MetricResult: 61 | pass 62 | 63 | def evaluate( 64 | self, sources: list, hypotheses: list, references: list 65 | ) -> BaseMetricXResult: 66 | """ 67 | Evaluate function receives the hypotheses and the references and returns a COMETResult object. 68 | 69 | :param hypotheses: List of the MT outputs (sentences). 
70 | :param references: List of the reference sentences. 71 | """ 72 | 73 | def _tokenize(example): 74 | return self.tokenizer( 75 | example["input"], 76 | max_length=self.max_input_length, 77 | truncation=True, 78 | padding=False, 79 | ) 80 | 81 | def _remove_eos(example): 82 | example["input_ids"] = example["input_ids"][:-1] 83 | example["attention_mask"] = example["attention_mask"][:-1] 84 | return example 85 | 86 | samples = self.make_samples( 87 | sources=sources, hypotheses=hypotheses, references=references 88 | ) 89 | ds = Dataset.from_list(samples) 90 | ds = ds.map(self._make_input) 91 | ds = ds.map(_tokenize) 92 | ds = ds.map(_remove_eos) 93 | ds.set_format( 94 | type="torch", 95 | columns=["input_ids", "attention_mask"], 96 | device=self.device, 97 | output_all_columns=True, 98 | ) 99 | with torch.no_grad(): 100 | predictions = [ 101 | self.model( 102 | sample["input_ids"], sample["attention_mask"] 103 | ).predictions.item() 104 | for sample in ds.iter(batch_size=1) 105 | ] 106 | metricx_result = BaseMetricXResult( 107 | { 108 | "system_score": sum(predictions) / len(predictions), 109 | "segments_scores": predictions, 110 | } 111 | ) 112 | return metricx_result 113 | 114 | @staticmethod 115 | def metric_name(): 116 | pass 117 | 118 | 119 | class RefMetricX(BaseMetricX): 120 | def __init__(self, tokenizer: str, model: str, **kwargs) -> None: 121 | super().__init__( 122 | model=model, tokenizer=tokenizer, max_input_length=1024, **kwargs 123 | ) 124 | 125 | @staticmethod 126 | def load_gold_data(gold_data): 127 | references, sources = gold_data["ref"], None 128 | 129 | return references, sources 130 | 131 | @staticmethod 132 | def make_samples( 133 | hypotheses: list[str], references: list[str], sources: list[str] = None 134 | ): 135 | return [ 136 | {"hypothesis": h, "reference": r} for h, r in zip(hypotheses, references) 137 | ] 138 | 139 | @staticmethod 140 | def _make_input(example): 141 | example["input"] = ( 142 | "candidate: " 143 | + example["hypothesis"] 144 | + " reference: " 145 | + example["reference"] 146 | ) 147 | return example 148 | 149 | 150 | class RefMetricX_24(BaseMetricX): 151 | def __init__(self, tokenizer: str, model: str, **kwargs) -> None: 152 | super().__init__( 153 | model=model, tokenizer=tokenizer, max_input_length=1536, **kwargs 154 | ) 155 | 156 | @staticmethod 157 | def load_gold_data(gold_data): 158 | references, sources = gold_data["ref"], gold_data["src"] 159 | 160 | return references, sources 161 | 162 | @staticmethod 163 | def make_samples( 164 | hypotheses: list[str], references: list[str], sources: list[str] = None 165 | ): 166 | return [ 167 | {"hypothesis": h, "reference": r, "source": s} 168 | for h, r, s in zip(hypotheses, references, sources) 169 | ] 170 | 171 | @staticmethod 172 | def _make_input(example): 173 | example["input"] = ( 174 | "source: " 175 | + example["source"] 176 | + " candidate: " 177 | + example["hypothesis"] 178 | + " reference: " 179 | + example["reference"] 180 | ) 181 | return example 182 | 183 | 184 | class QEMetricX(BaseMetricX): 185 | def __init__(self, tokenizer: str, model: str, **kwargs) -> None: 186 | super().__init__( 187 | model=model, tokenizer=tokenizer, max_input_length=1024, **kwargs 188 | ) 189 | 190 | @staticmethod 191 | def load_gold_data(gold_data): 192 | references, sources = None, gold_data["src"] 193 | 194 | return references, sources 195 | 196 | @staticmethod 197 | def make_samples( 198 | sources: list[str], hypotheses: list[str], references: list[str] = None 199 | ): 200 | return [{"hypothesis": 
h, "source": s} for h, s in zip(hypotheses, sources)] 201 | 202 | @staticmethod 203 | def _make_input(example): 204 | example["input"] = ( 205 | "candidate: " + example["hypothesis"] + " source: " + example["source"] 206 | ) 207 | return example 208 | 209 | 210 | class QEMetricX_24(BaseMetricX): 211 | def __init__(self, tokenizer: str, model: str, **kwargs) -> None: 212 | super().__init__( 213 | model=model, tokenizer=tokenizer, max_input_length=1536, **kwargs 214 | ) 215 | 216 | @staticmethod 217 | def load_gold_data(gold_data): 218 | references, sources = None, gold_data["src"] 219 | 220 | return references, sources 221 | 222 | @staticmethod 223 | def make_samples( 224 | sources: list[str], hypotheses: list[str], references: list[str] = None 225 | ): 226 | return [{"hypothesis": h, "source": s} for h, s in zip(hypotheses, sources)] 227 | 228 | @staticmethod 229 | def _make_input(example): 230 | example["input"] = ( 231 | "source: " + example["source"] + " candidate: " + example["hypothesis"] 232 | ) 233 | return example 234 | -------------------------------------------------------------------------------- /tower_eval/metrics/base/result_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from abc import ABC 3 | 4 | 5 | class MetricResult(ABC): 6 | """ 7 | Abstract class defining a shared interface for all Metric Result Handlers. 8 | 9 | :param result: float value to be displayed. 10 | """ 11 | 12 | def __init__(self, result: float) -> None: 13 | self.result = result 14 | 15 | def print_result(self, metric_name: str, round_to_decimals: int = 4) -> None: 16 | """Function used to display a particular Metric result. 17 | :param metric_name: Metric name. 18 | :param round_to_decimals: decimals that we want to present. 19 | """ 20 | if type(self.result) == dict: 21 | print( 22 | f'{metric_name}: {round(self.result["system_score"], round_to_decimals)}' 23 | ) 24 | else: 25 | print(f"{metric_name}: {round(self.result, round_to_decimals)}") 26 | 27 | def format_result(self, metric_name: str, round_to_decimals: int = 4) -> dict: 28 | """Function used to format a particular Metric result. 29 | :param metric_name: Metric name. 30 | :param round_to_decimals: decimals that we want to present. 
31 | """ 32 | if type(self.result) == dict: 33 | out = { 34 | f"{metric_name}": round(self.result["system_score"], round_to_decimals), 35 | f"{metric_name}_segments": self.result["segments_scores"], 36 | } 37 | else: 38 | out = {f"{metric_name}": round(self.result, round_to_decimals)} 39 | return out 40 | -------------------------------------------------------------------------------- /tower_eval/metrics/base/xml_metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from pathlib import Path 3 | 4 | from loguru import logger 5 | 6 | from tower_eval.metrics.base.metrics_handler import Metric 7 | from tower_eval.metrics.base.result_handler import MetricResult 8 | from tower_eval.utils import prepare_xml_markup_pairs 9 | 10 | 11 | class XMLMetric(Metric): 12 | def __init__(self, metric, **kwargs) -> None: 13 | super().__init__(**kwargs) 14 | self.metric = metric() 15 | 16 | def run(self, hypothesis_path, gold_data_path, **kwargs) -> dict: 17 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 18 | references = gold_data["ref"] 19 | result = self.evaluate(hypotheses, references) 20 | result.print_result(self.metric_name()) 21 | return result.format_result(self.metric_name()) 22 | 23 | def process_result(self, result) -> MetricResult: 24 | pass 25 | 26 | def evaluate( 27 | self, 28 | hypotheses: list, 29 | references: list, 30 | ) -> dict: 31 | """ 32 | Evaluate function receives the hypotheses and the references and returns a MetricResult object. 33 | 34 | :param hypotheses: path to the hypotheses file. 35 | :param references: path to the references file. 36 | """ 37 | 38 | assert type(references[0]) == str, logger.error( 39 | "Mutli-reference is not supported for XMLMetrics" 40 | ) 41 | 42 | """ 43 | Based on the information provided in the original papers, xml-chrf (and xml-bleu) are calcualted as follows: 44 | We first use etree to extract the XML structure of the output and reference. 45 | The XML-Match is the percentage of outputs that have exactly the same XML structures as their references. 46 | If the XML structures of an output and its reference match, then the translation and reference are split by the XML tags 47 | and we evaluate the chrF score by comparing each split segment. 48 | If the structures do not match, the chrF score is counted as zero to penalize the irrelevant outputs. 49 | """ 50 | 51 | hypothesis_segmented, references_segmented, non_matching_indices = ( 52 | prepare_xml_markup_pairs(hypotheses, references) 53 | ) 54 | results = self.metric.evaluate(hypothesis_segmented, references_segmented) 55 | segment_scores = results.result["segments_scores"] 56 | # Now, add 0 scores for the instance that their markup structure doesn't match the one of their corresponding references. 57 | # This is to penalise the irrelevant outputs and is based on the information provided in the original paper: 58 | # How Effective is Synthetic Data and Instruction Fine-tuning for Translation with Markup using LLMs? 
59 | # https://aclanthology.org/2024.amta-research.8 60 | # Appendix B: Details of Evaluation Metrics 61 | for index in reversed(non_matching_indices): 62 | segment_scores.insert(index, 0.0) 63 | 64 | score = sum(segment_scores) / len(segment_scores) 65 | 66 | result = MetricResult( 67 | { 68 | "system_score": score, 69 | "segments_scores": segment_scores, 70 | } 71 | ) 72 | return result 73 | -------------------------------------------------------------------------------- /tower_eval/metrics/bleu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/bleu/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/bleu/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sacrebleu.metrics import BLEU as SacreBLEU 3 | 4 | from tower_eval.metrics.base.metrics_handler import Metric 5 | from tower_eval.metrics.base.result_handler import MetricResult 6 | from tower_eval.metrics.bleu.result import BLEUResult 7 | from tower_eval.utils import get_sacrebleu_segment_scores 8 | 9 | 10 | class BLEU(Metric): 11 | def __init__(self, **kwargs) -> None: 12 | super().__init__(**kwargs) 13 | 14 | def run( 15 | self, 16 | hypothesis_path, 17 | gold_data_path, 18 | lowercase: bool = False, 19 | tokenizer: str = None, 20 | **kwargs 21 | ) -> dict: 22 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 23 | references = gold_data["ref"] 24 | result = self.evaluate( 25 | hypotheses, references, lowercase=lowercase, tokenize=tokenizer 26 | ) 27 | result.print_result(self.metric_name()) 28 | return result.format_result(self.metric_name()) 29 | 30 | def evaluate( 31 | self, 32 | hypothesis: list, 33 | references: list, 34 | lowercase: bool = False, 35 | tokenize: str = None, 36 | ) -> BLEUResult: 37 | """ 38 | Evaluate function receives the hypotheses and the references and returns a BLEUResult object. 39 | The BLEU score is calculate by calling sacreBLEU 40 | 41 | :param hypotheses: path to the hypotheses file. 42 | :param references: path to the references file. 43 | """ 44 | sacrebleu = SacreBLEU(lowercase=lowercase, tokenize=tokenize) 45 | if type(references[0]) == str: 46 | segment_references = [[r] for r in references] 47 | references = [references] 48 | score = sacrebleu.corpus_score(hypothesis, references) 49 | sacrebleu = SacreBLEU( 50 | lowercase=lowercase, tokenize=tokenize, effective_order=True 51 | ) 52 | segment_scores = get_sacrebleu_segment_scores( 53 | hypothesis, segment_references, method=sacrebleu 54 | ) 55 | result = BLEUResult(score.score) 56 | result = BLEUResult( 57 | { 58 | "system_score": score.score, 59 | "segments_scores": segment_scores, 60 | } 61 | ) 62 | return result 63 | 64 | def process_result(self, result) -> MetricResult: 65 | pass 66 | 67 | @staticmethod 68 | def metric_name(): 69 | return "bleu" 70 | -------------------------------------------------------------------------------- /tower_eval/metrics/bleu/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class BLEUResult(MetricResult): 6 | """ 7 | BLEU Result Handler. 8 | TODO: Add the extra information (such as the brevity penalty, scores of different n-grams) to the output. 
9 | """ 10 | 11 | def __init__( 12 | self, 13 | result: float, 14 | ) -> None: 15 | super().__init__(result) 16 | -------------------------------------------------------------------------------- /tower_eval/metrics/bleurt/__init__.py: -------------------------------------------------------------------------------- 1 | DEFAULT_COMET_MODEL = "Unbabel/wmt22-comet-da" 2 | -------------------------------------------------------------------------------- /tower_eval/metrics/bleurt/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from bleurt_pytorch import ( 4 | BleurtConfig, 5 | BleurtForSequenceClassification, 6 | BleurtTokenizer, 7 | ) 8 | from tqdm import tqdm 9 | 10 | from tower_eval.metrics.base.metrics_handler import Metric 11 | from tower_eval.metrics.base.result_handler import MetricResult 12 | from tower_eval.metrics.bleurt.result import BLEURTResult 13 | 14 | 15 | class BLEURT(Metric): 16 | def __init__(self, **kwargs) -> None: 17 | super().__init__(**kwargs) 18 | self.config = BleurtConfig.from_pretrained("lucadiliello/BLEURT-20") 19 | self.model = BleurtForSequenceClassification.from_pretrained( 20 | "lucadiliello/BLEURT-20" 21 | ) 22 | self.tokenizer = BleurtTokenizer.from_pretrained("lucadiliello/BLEURT-20") 23 | self.model.eval() 24 | self.model = self.model.to("cuda") 25 | 26 | def run( 27 | self, hypothesis_path, gold_data_path, batch_size: int = 16, **kwargs 28 | ) -> dict: 29 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 30 | references = gold_data["ref"] 31 | result = self.evaluate(hypotheses, references, batch_size) 32 | result.print_result(self.metric_name()) 33 | return result.format_result(self.metric_name()) 34 | 35 | def evaluate( 36 | self, hypotheses: list, references: list, batch_size: int 37 | ) -> BLEURTResult: 38 | """ 39 | Evaluate function receives the hypotheses and the references and returns a COMETResult object. 40 | 41 | :param hypotheses: List of the MT outputs (sentences). 42 | :param references: List of the reference sentences. 43 | :param sources: List of source sentences 44 | """ 45 | segments_scores = [] 46 | for i in tqdm(range(0, len(references), batch_size)): 47 | with torch.no_grad(): 48 | batch_references = references[i : i + batch_size] 49 | batch_hypotheses = hypotheses[i : i + batch_size] 50 | inputs = self.tokenizer( 51 | batch_references, 52 | batch_hypotheses, 53 | padding="longest", 54 | return_tensors="pt", 55 | truncation=True, 56 | ).to("cuda") 57 | segments_scores.extend(self.model(**inputs).logits.flatten().tolist()) 58 | system_score = sum(segments_scores) / len(segments_scores) 59 | 60 | result = BLEURTResult( 61 | { 62 | "system_score": system_score, 63 | "segments_scores": segments_scores, 64 | } 65 | ) 66 | return result 67 | 68 | def process_result(self, result) -> MetricResult: 69 | pass 70 | 71 | @staticmethod 72 | def metric_name(): 73 | return "bleurt" 74 | -------------------------------------------------------------------------------- /tower_eval/metrics/bleurt/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class BLEURTResult(MetricResult): 6 | """ 7 | BLEURT Result Handler. 8 | TODO: Add the segment-level to the output as additional params. 
9 | """ 10 | 11 | def __init__( 12 | self, 13 | result: float, 14 | ) -> None: 15 | super().__init__(result) 16 | -------------------------------------------------------------------------------- /tower_eval/metrics/chrf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/chrf/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/chrf/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sacrebleu.metrics import CHRF as SacreCHRF 3 | 4 | from tower_eval.metrics.base.metrics_handler import Metric 5 | from tower_eval.metrics.base.result_handler import MetricResult 6 | from tower_eval.metrics.chrf.result import CHRFResult 7 | from tower_eval.utils import get_sacrebleu_segment_scores 8 | 9 | 10 | class CHRF(Metric): 11 | def __init__(self, **kwargs) -> None: 12 | super().__init__(**kwargs) 13 | 14 | def run( 15 | self, hypothesis_path, gold_data_path, lowercase: bool = False, **kwargs 16 | ) -> dict: 17 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 18 | references = gold_data["ref"] 19 | result = self.evaluate(hypotheses, references, lowercase=lowercase) 20 | result.print_result(self.metric_name()) 21 | return result.format_result(self.metric_name()) 22 | 23 | def evaluate( 24 | self, 25 | hypotheses: list, 26 | references: list, 27 | lowercase: bool = False, 28 | ) -> CHRFResult: 29 | """ 30 | Evaluate function receives the hypotheses and the references and returns a CHRFResult object. 31 | The chrF score is calculate by calling sacreBLEU 32 | 33 | :param hypotheses: path to the hypotheses file. 34 | :param references: path to the references file. 35 | """ 36 | chrf = SacreCHRF(lowercase=lowercase) 37 | if type(references[0]) == str: 38 | segment_references = [[r] for r in references] 39 | references = [references] 40 | score = chrf.corpus_score(hypotheses, references) 41 | segment_scores = get_sacrebleu_segment_scores( 42 | hypotheses, segment_references, method=chrf 43 | ) 44 | result = CHRFResult(score.score) 45 | result = CHRFResult( 46 | { 47 | "system_score": score.score, 48 | "segments_scores": segment_scores, 49 | } 50 | ) 51 | return result 52 | 53 | def process_result(self, result) -> MetricResult: 54 | pass 55 | 56 | @staticmethod 57 | def metric_name(): 58 | return "chrf" 59 | -------------------------------------------------------------------------------- /tower_eval/metrics/chrf/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class CHRFResult(MetricResult): 6 | """ 7 | chrF Result Handler. 8 | TODO: Add the extra information (such as the casing, version of the metric, etc) to the output. 
9 | """ 10 | 11 | def __init__( 12 | self, 13 | result: float, 14 | ) -> None: 15 | super().__init__(result) 16 | -------------------------------------------------------------------------------- /tower_eval/metrics/comet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/comet/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/comet/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.comet import RefCOMET 3 | 4 | 5 | class COMET(RefCOMET): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(model="Unbabel/wmt22-comet-da", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "comet" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/comet_kiwi/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/comet_kiwi/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/comet_kiwi/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.comet import QECOMET 3 | 4 | 5 | class COMETKiwi(QECOMET): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(model="Unbabel/wmt22-cometkiwi-da", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "comet_kiwi" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/comet_kiwi_23_xxl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/comet_kiwi_23_xxl/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/comet_kiwi_23_xxl/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.comet import QECOMET 3 | 4 | 5 | class COMETKiwi23XXL(QECOMET): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(model="Unbabel/wmt23-cometkiwi-da-xxl", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "comet_kiwi_23_xxl" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/errant/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/errant/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/errant/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import re 3 | import subprocess 4 | import tempfile 5 | 6 | from tower_eval.metrics.base.metrics_handler import Metric 7 | from tower_eval.metrics.base.result_handler import MetricResult 8 | from tower_eval.metrics.errant.result import ERRANTResult 9 | from tower_eval.utils import tokenize_spacy 10 | 
11 | 12 | class ERRANT(Metric): 13 | def __init__(self, **kwargs) -> None: 14 | super().__init__(**kwargs) 15 | 16 | def run( 17 | self, 18 | hypothesis_path, 19 | gold_data_path, 20 | tokenize_source: bool = False, 21 | tokenize_hypothesis: bool = False, 22 | **kwargs 23 | ) -> dict: 24 | language = kwargs["lp"]["src_lang"] 25 | references = kwargs["references_m2"] 26 | hypothesis_m2 = self.preprocess(hypothesis_path, gold_data_path, language, tokenize_source, tokenize_hypothesis) 27 | result = self.evaluate( 28 | hypothesis_m2, references 29 | ) 30 | result.print_result(self.metric_name()) 31 | return result.format_result(self.metric_name()) 32 | 33 | def evaluate( 34 | self, 35 | hypothesis_m2: str, 36 | references: str, 37 | ) -> ERRANTResult: 38 | """ 39 | Evaluate function receives the source, hypothesis as well as the reference and returns an ERRANTResult object. 40 | """ 41 | errant_score = subprocess.run( 42 | ["errant_compare", "-hyp", hypothesis_m2, "-ref", references], 43 | stderr=subprocess.PIPE, 44 | stdout=subprocess.PIPE, 45 | check=True, 46 | ) 47 | pattern = r"\d+\.\d+|\d+" 48 | _, _, score_values, _ = errant_score.stdout.decode("utf-8").strip().split("\n") 49 | matches = re.findall(pattern, score_values) 50 | # Assign the extracted values to the respective fields 51 | fields = ["TP", "FP", "FN", "Prec", "Rec", "F0.5"] 52 | values = [int(matches[i]) if i < 3 else float(matches[i]) for i in range(6)] 53 | # Create the output dictionary 54 | output_dict = {fields[i]: values[i] for i in range(6)} 55 | result = ERRANTResult(output_dict["F0.5"]) 56 | return result 57 | 58 | def preprocess( 59 | self, 60 | hypothesis_path, 61 | gold_data_path, 62 | language, 63 | tokenize_source, 64 | tokenize_hypothesis, 65 | ): 66 | hyp_lines, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 67 | src_lines = gold_data["src"] 68 | 69 | if tokenize_source: 70 | src_tokenized = tokenize_spacy(src_lines, language) 71 | else: 72 | # assumes gold data already has tokenized sources 73 | src_tokenized = gold_data["tok_src"] 74 | if tokenize_hypothesis: 75 | hyp_tokenized = tokenize_spacy(hyp_lines, language) 76 | else: 77 | hyp_tokenized = hyp_lines 78 | 79 | with tempfile.NamedTemporaryFile(mode="w", delete=False) as sfh_out: 80 | for line in src_tokenized: 81 | sfh_out.write(line + "\n") 82 | self.source = sfh_out.name 83 | with tempfile.NamedTemporaryFile(mode="w", delete=False) as hfh_out: 84 | for line in hyp_tokenized: 85 | hfh_out.write(line + "\n") 86 | self.hypothesis = hfh_out.name 87 | 88 | # Create the m2 version of the hypothesis file to be used by the evaluator. 
89 | # TODO: replace this part by a script that calls native python code of errant lib 90 | with tempfile.NamedTemporaryFile(mode="w", delete=False) as hyp_m2: 91 | subprocess.run( 92 | [ 93 | "errant_parallel", 94 | "-orig", 95 | self.source, 96 | "-cor", 97 | self.hypothesis, 98 | "-out", 99 | hyp_m2.name, 100 | ], 101 | check=True, 102 | ) 103 | hypothesis_m2 = hyp_m2.name 104 | return hypothesis_m2 105 | 106 | def process_result(self, result) -> MetricResult: 107 | pass 108 | 109 | @staticmethod 110 | def metric_name(): 111 | return "errant" 112 | -------------------------------------------------------------------------------- /tower_eval/metrics/errant/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class ERRANTResult(MetricResult): 6 | """ 7 | ERRANT Result Handler. 8 | """ 9 | 10 | def __init__( 11 | self, 12 | result: float, 13 | ) -> None: 14 | super().__init__(result) 15 | -------------------------------------------------------------------------------- /tower_eval/metrics/error_span_detection_f1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/error_span_detection_f1/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/error_span_detection_f1/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.error_span_detection import ErrorSpanDetectionMetric 3 | 4 | 5 | class ErrorSpanDetectionF1(ErrorSpanDetectionMetric): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(key="f1", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "error-span-detection-f1" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/error_span_detection_precision/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/error_span_detection_precision/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/error_span_detection_precision/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.error_span_detection import ErrorSpanDetectionMetric 3 | 4 | 5 | class ErrorSpanDetectionPrecision(ErrorSpanDetectionMetric): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(key="precision", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "error-span-detection-precision" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/error_span_detection_recall/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/error_span_detection_recall/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/error_span_detection_recall/metric.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.error_span_detection import ErrorSpanDetectionMetric 3 | 4 | 5 | class ErrorSpanDetectionRecall(ErrorSpanDetectionMetric): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(key="recall", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "error-span-detection-recall" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/f1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/f1/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/f1/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List 3 | 4 | from loguru import logger 5 | from sklearn.metrics import f1_score 6 | 7 | from tower_eval.metrics.base.metrics_handler import Metric 8 | from tower_eval.metrics.base.result_handler import MetricResult 9 | from tower_eval.metrics.f1.result import F1Result 10 | from tower_eval.utils import text_to_label 11 | 12 | 13 | class F1(Metric): 14 | def __init__( 15 | self, 16 | **kwargs, 17 | ) -> None: 18 | """Initializes an instance of the f1 metric. 19 | 20 | Args: 21 | source_type (str): The type of source data. Either "categorical" or "text". 22 | source_labels (List[str]): A list of labels for the source data. Required if source_type is "text". 23 | **kwargs: Additional keyword arguments. Must include "hypothesis" and "references". 24 | 25 | Raises: 26 | AssertionError: If multiple references are provided. 27 | 28 | Returns: 29 | None 30 | """ 31 | super().__init__(**kwargs) 32 | 33 | def run( 34 | self, 35 | hypothesis_path, 36 | gold_data_path, 37 | source_type: str = "categorical", 38 | source_labels: List[str] = None, 39 | **kwargs, 40 | ) -> dict: 41 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 42 | reference_lines = gold_data["ref"] 43 | gold_labels = [] 44 | predicted_labels = [] 45 | is_random_count = 0 46 | # if hypothesis is already numbered 47 | for ref_line, hyp_line in zip(reference_lines, hypotheses): 48 | # reference is always assumed to be in categorical format; i.e., [0,1,2,3,...] 49 | gold_labels.append(text_to_label(ref_line, "categorical")) 50 | label, is_random = text_to_label( 51 | hyp_line, 52 | source_type, 53 | source_labels, 54 | return_is_random=True, 55 | ) 56 | is_random_count += 1 if is_random else 0 57 | predicted_labels.append(label) 58 | # warn user that some labels were randomly assigned 59 | if is_random_count > 0: 60 | pct_random = (is_random_count / len(gold_labels)) * 100 61 | logger.opt(colors=True).warning( 62 | f"{is_random_count} ({pct_random:.2f}% of total) labels did not correspond to any label in source_labels, so a random label was a assigned." 63 | ) 64 | 65 | result = self.evaluate( 66 | gold_labels=gold_labels, predicted_labels=predicted_labels 67 | ) 68 | result.print_result(self.metric_name()) 69 | return result.format_result(self.metric_name()) 70 | 71 | def evaluate(self, gold_labels, predicted_labels) -> F1Result: 72 | """ 73 | Evaluate function receives the gold labels as well as the predicted ones and returns the F1 score of the predictions. 
74 | The F1 score is calculate by calling the corresponding function in Scikit Learn library 75 | """ 76 | score = f1_score(y_true=gold_labels, y_pred=predicted_labels) 77 | result = F1Result(score) 78 | return result 79 | 80 | def process_result(self, result) -> MetricResult: 81 | pass 82 | 83 | @staticmethod 84 | def metric_name(): 85 | return "f1" 86 | -------------------------------------------------------------------------------- /tower_eval/metrics/f1/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class F1Result(MetricResult): 6 | """ 7 | F1 Result Handler. 8 | """ 9 | 10 | def __init__( 11 | self, 12 | result: float, 13 | ) -> None: 14 | super().__init__(result) 15 | -------------------------------------------------------------------------------- /tower_eval/metrics/f1_sequence/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/f1_sequence/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/f1_sequence/conlleval.py: -------------------------------------------------------------------------------- 1 | # 2 | # The code is directly copied from: 3 | # https://github.com/sighsmile/conlleval/blob/master/conlleval.py 4 | # 5 | 6 | ############### 7 | """ 8 | This script applies to IOB2 or IOBES tagging scheme. 9 | If you are using a different scheme, please convert to IOB2 or IOBES. 10 | 11 | IOB2: 12 | - B = begin, 13 | - I = inside but not the first, 14 | - O = outside 15 | 16 | e.g. 17 | John lives in New York City . 18 | B-PER O O B-LOC I-LOC I-LOC O 19 | 20 | IOBES: 21 | - B = begin, 22 | - E = end, 23 | - S = singleton, 24 | - I = inside but not the first or the last, 25 | - O = outside 26 | 27 | e.g. 28 | John lives in New York City . 29 | S-PER O O B-LOC I-LOC E-LOC O 30 | 31 | prefix: IOBES 32 | chunk_type: PER, LOC, etc. 33 | """ 34 | from __future__ import division, print_function, unicode_literals 35 | 36 | import sys 37 | from collections import defaultdict 38 | 39 | 40 | def split_tag(chunk_tag): 41 | """ 42 | split chunk tag into IOBES prefix and chunk_type 43 | e.g. 44 | B-PER -> (B, PER) 45 | O -> (O, None) 46 | """ 47 | if chunk_tag == "O": 48 | return ("O", None) 49 | return chunk_tag.split("-", maxsplit=1) 50 | 51 | 52 | def is_chunk_end(prev_tag, tag): 53 | """ 54 | check if the previous chunk ended between the previous and current word 55 | e.g. 56 | (B-PER, I-PER) -> False 57 | (B-LOC, O) -> True 58 | 59 | Note: in case of contradicting tags, e.g. 
(B-PER, I-LOC) 60 | this is considered as (B-PER, B-LOC) 61 | """ 62 | prefix1, chunk_type1 = split_tag(prev_tag) 63 | prefix2, chunk_type2 = split_tag(tag) 64 | 65 | if prefix1 == "O": 66 | return False 67 | if prefix2 == "O": 68 | return prefix1 != "O" 69 | 70 | if chunk_type1 != chunk_type2: 71 | return True 72 | 73 | return prefix2 in ["B", "S"] or prefix1 in ["E", "S"] 74 | 75 | 76 | def is_chunk_start(prev_tag, tag): 77 | """ 78 | check if a new chunk started between the previous and current word 79 | """ 80 | prefix1, chunk_type1 = split_tag(prev_tag) 81 | prefix2, chunk_type2 = split_tag(tag) 82 | 83 | if prefix2 == "O": 84 | return False 85 | if prefix1 == "O": 86 | return prefix2 != "O" 87 | 88 | if chunk_type1 != chunk_type2: 89 | return True 90 | 91 | return prefix2 in ["B", "S"] or prefix1 in ["E", "S"] 92 | 93 | 94 | def calc_metrics(tp, p, t, percent=True): 95 | """ 96 | compute overall precision, recall and FB1 (default values are 0.0) 97 | if percent is True, return 100 * original decimal value 98 | """ 99 | precision = tp / p if p else 0 100 | recall = tp / t if t else 0 101 | fb1 = 2 * precision * recall / (precision + recall) if precision + recall else 0 102 | if percent: 103 | return 100 * precision, 100 * recall, 100 * fb1 104 | else: 105 | return precision, recall, fb1 106 | 107 | 108 | def count_chunks(true_seqs, pred_seqs): 109 | """ 110 | true_seqs: a list of true tags 111 | pred_seqs: a list of predicted tags 112 | 113 | return: 114 | correct_chunks: a dict (counter), 115 | key = chunk types, 116 | value = number of correctly identified chunks per type 117 | true_chunks: a dict, number of true chunks per type 118 | pred_chunks: a dict, number of identified chunks per type 119 | 120 | correct_counts, true_counts, pred_counts: similar to above, but for tags 121 | """ 122 | correct_chunks = defaultdict(int) 123 | true_chunks = defaultdict(int) 124 | pred_chunks = defaultdict(int) 125 | 126 | correct_counts = defaultdict(int) 127 | true_counts = defaultdict(int) 128 | pred_counts = defaultdict(int) 129 | 130 | prev_true_tag, prev_pred_tag = "O", "O" 131 | correct_chunk = None 132 | 133 | for true_tag, pred_tag in zip(true_seqs, pred_seqs): 134 | if true_tag == pred_tag: 135 | correct_counts[true_tag] += 1 136 | true_counts[true_tag] += 1 137 | pred_counts[pred_tag] += 1 138 | 139 | _, true_type = split_tag(true_tag) 140 | _, pred_type = split_tag(pred_tag) 141 | 142 | if correct_chunk is not None: 143 | true_end = is_chunk_end(prev_true_tag, true_tag) 144 | pred_end = is_chunk_end(prev_pred_tag, pred_tag) 145 | 146 | if pred_end and true_end: 147 | correct_chunks[correct_chunk] += 1 148 | correct_chunk = None 149 | elif pred_end != true_end or true_type != pred_type: 150 | correct_chunk = None 151 | 152 | true_start = is_chunk_start(prev_true_tag, true_tag) 153 | pred_start = is_chunk_start(prev_pred_tag, pred_tag) 154 | 155 | if true_start and pred_start and true_type == pred_type: 156 | correct_chunk = true_type 157 | if true_start: 158 | true_chunks[true_type] += 1 159 | if pred_start: 160 | pred_chunks[pred_type] += 1 161 | 162 | prev_true_tag, prev_pred_tag = true_tag, pred_tag 163 | if correct_chunk is not None: 164 | correct_chunks[correct_chunk] += 1 165 | 166 | return ( 167 | correct_chunks, 168 | true_chunks, 169 | pred_chunks, 170 | correct_counts, 171 | true_counts, 172 | pred_counts, 173 | ) 174 | 175 | 176 | def get_result( 177 | correct_chunks, 178 | true_chunks, 179 | pred_chunks, 180 | correct_counts, 181 | true_counts, 182 | pred_counts, 183 | 
verbose=False, 184 | ): 185 | """ 186 | if verbose, print overall performance, as well as preformance per chunk type; 187 | otherwise, simply return overall prec, rec, f1 scores 188 | """ 189 | # sum counts 190 | sum_correct_chunks = sum(correct_chunks.values()) 191 | sum_true_chunks = sum(true_chunks.values()) 192 | sum_pred_chunks = sum(pred_chunks.values()) 193 | 194 | sum_correct_counts = sum(correct_counts.values()) 195 | sum_true_counts = sum(true_counts.values()) 196 | 197 | nonO_correct_counts = sum(v for k, v in correct_counts.items() if k != "O") 198 | nonO_true_counts = sum(v for k, v in true_counts.items() if k != "O") 199 | 200 | chunk_types = sorted(list(set(list(true_chunks) + list(pred_chunks)))) 201 | 202 | # compute overall precision, recall and FB1 (default values are 0.0) 203 | prec, rec, f1 = calc_metrics(sum_correct_chunks, sum_pred_chunks, sum_true_chunks) 204 | res = (prec, rec, f1) 205 | if not verbose: 206 | return res 207 | 208 | # print overall performance, and performance per chunk type 209 | 210 | print( 211 | "processed %i tokens with %i phrases; " % (sum_true_counts, sum_true_chunks), 212 | end="", 213 | ) 214 | print( 215 | "found: %i phrases; correct: %i.\n" % (sum_pred_chunks, sum_correct_chunks), 216 | end="", 217 | ) 218 | 219 | print("accuracy: %6.2f%%; (non-O)" % (100 * nonO_correct_counts / nonO_true_counts)) 220 | print("accuracy: %6.2f%%; " % (100 * sum_correct_counts / sum_true_counts), end="") 221 | print("precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1)) 222 | 223 | # for each chunk type, compute precision, recall and FB1 (default values are 0.0) 224 | for t in chunk_types: 225 | prec, rec, f1 = calc_metrics(correct_chunks[t], pred_chunks[t], true_chunks[t]) 226 | print("%17s: " % t, end="") 227 | print( 228 | "precision: %6.2f%%; recall: %6.2f%%; FB1: %6.2f" % (prec, rec, f1), end="" 229 | ) 230 | print(" %d" % pred_chunks[t]) 231 | 232 | return res 233 | # you can generate LaTeX output for tables like in 234 | # http://cnts.uia.ac.be/conll2003/ner/example.tex 235 | # but I'm not implementing this 236 | 237 | 238 | def evaluate(true_seqs, pred_seqs, verbose=False): 239 | ( 240 | correct_chunks, 241 | true_chunks, 242 | pred_chunks, 243 | correct_counts, 244 | true_counts, 245 | pred_counts, 246 | ) = count_chunks(true_seqs, pred_seqs) 247 | result = get_result( 248 | correct_chunks, 249 | true_chunks, 250 | pred_chunks, 251 | correct_counts, 252 | true_counts, 253 | pred_counts, 254 | verbose=verbose, 255 | ) 256 | return result 257 | 258 | 259 | def evaluate_conll_file(fileIterator): 260 | true_seqs, pred_seqs = [], [] 261 | 262 | for line in fileIterator: 263 | cols = line.strip().split() 264 | # each non-empty line must contain >= 3 columns 265 | if not cols: 266 | true_seqs.append("O") 267 | pred_seqs.append("O") 268 | elif len(cols) < 3: 269 | raise IOError("conlleval: too few columns in line %s\n" % line) 270 | else: 271 | # extract tags from last 2 columns 272 | true_seqs.append(cols[-2]) 273 | pred_seqs.append(cols[-1]) 274 | return evaluate(true_seqs, pred_seqs) 275 | 276 | 277 | if __name__ == "__main__": 278 | """ 279 | usage: conlleval < file 280 | """ 281 | evaluate_conll_file(sys.stdin) 282 | -------------------------------------------------------------------------------- /tower_eval/metrics/f1_sequence/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Dict, List, Union 3 | 4 | from tower_eval.metrics.base.result_handler 
import MetricResult 5 | 6 | 7 | class F1SequenceResult(MetricResult): 8 | """ 9 | F1Sequence Result Handler. 10 | """ 11 | 12 | def __init__( 13 | self, 14 | result: float, 15 | tags_f1: Dict[str, float], 16 | valid_tags: List[str], 17 | ) -> None: 18 | super().__init__(result) 19 | self.tags_f1 = tags_f1 20 | self.valid_tags = valid_tags 21 | 22 | def print_result(self, metric_name: str, round_to_decimals: int = 4) -> None: 23 | """Function used to display a particular Metric result. 24 | :param metric_name: Metric name. 25 | :param round_to_decimals: decimals that we want to present. 26 | """ 27 | print(f"{metric_name}: {round(self.result, round_to_decimals)}") 28 | for tag in self.valid_tags: 29 | print(f"{tag}: {round(self.tags_f1[tag], round_to_decimals)}") 30 | 31 | def format_result( 32 | self, metric_name: str, round_to_decimals: int = 4 33 | ) -> Dict[str, Union[float, Dict[str, float]]]: 34 | """Function used to format a particular Metric result. 35 | :param metric_name: Metric name. 36 | :param round_to_decimals: decimals that we want to present. 37 | """ 38 | out = {} 39 | out[f"{metric_name}"] = round(self.result, round_to_decimals) 40 | out["valid_tags"] = self.valid_tags 41 | out[f"{metric_name}_by_tag"] = { 42 | tag: round(self.tags_f1[tag], round_to_decimals) for tag in self.valid_tags 43 | } 44 | return out 45 | -------------------------------------------------------------------------------- /tower_eval/metrics/metricx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/metricx/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/metricx/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.metricx import RefMetricX 3 | 4 | 5 | class MetricX(RefMetricX): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | model="google/metricx-23-xl-v2p0", tokenizer="google/mt5-xl", **kwargs 9 | ) 10 | 11 | @staticmethod 12 | def metric_name(): 13 | return "metricx" 14 | -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_24/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/metricx_24/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_24/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.metricx import QEMetricX_24, RefMetricX_24 3 | 4 | 5 | class MetricX_24_Large(RefMetricX_24): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | model="google/metricx-24-hybrid-large-v2p6", 9 | tokenizer="google/mt5-xl", 10 | **kwargs 11 | ) 12 | 13 | @staticmethod 14 | def metric_name(): 15 | return "metricx_24_large" 16 | 17 | 18 | class MetricX_24_XL(RefMetricX_24): 19 | def __init__(self, **kwargs) -> None: 20 | super().__init__( 21 | model="google/metricx-24-hybrid-xl-v2p6", 22 | tokenizer="google/mt5-xl", 23 | **kwargs 24 | ) 25 | 26 | @staticmethod 27 | def metric_name(): 28 | return "metricx_24_xl" 29 | 30 | 31 | class MetricX_24_XXL(RefMetricX_24): 32 | def __init__(self, 
**kwargs) -> None: 33 | super().__init__( 34 | model="google/metricx-24-hybrid-xxl-v2p6", 35 | tokenizer="google/mt5-xl", 36 | **kwargs 37 | ) 38 | 39 | @staticmethod 40 | def metric_name(): 41 | return "metricx_24_xxl" 42 | 43 | 44 | ### QE ### 45 | 46 | 47 | class MetricX_24_QE_Large(QEMetricX_24): 48 | def __init__(self, **kwargs) -> None: 49 | super().__init__( 50 | model="google/metricx-24-hybrid-large-v2p6", 51 | tokenizer="google/mt5-xl", 52 | **kwargs 53 | ) 54 | 55 | @staticmethod 56 | def metric_name(): 57 | return "metricx_24_qe_large" 58 | 59 | 60 | class MetricX_24_QE_XL(QEMetricX_24): 61 | def __init__(self, **kwargs) -> None: 62 | super().__init__( 63 | model="google/metricx-24-hybrid-xl-v2p6", 64 | tokenizer="google/mt5-xl", 65 | **kwargs 66 | ) 67 | 68 | @staticmethod 69 | def metric_name(): 70 | return "metricx_24_qe_xl" 71 | 72 | 73 | class MetricX_24_QE_XXL(QEMetricX_24): 74 | def __init__(self, **kwargs) -> None: 75 | super().__init__( 76 | model="google/metricx-24-hybrid-xxl-v2p6", 77 | tokenizer="google/mt5-xl", 78 | **kwargs 79 | ) 80 | 81 | @staticmethod 82 | def metric_name(): 83 | return "metricx_24_qe_xxl" 84 | -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_large/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/metricx_large/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_large/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.metricx import RefMetricX 3 | 4 | 5 | class MetricXLarge(RefMetricX): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | model="google/metricx-23-large-v2p0", tokenizer="google/mt5-large", **kwargs 9 | ) 10 | 11 | @staticmethod 12 | def metric_name(): 13 | return "metricx_large" 14 | -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_qe/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/metricx_qe/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_qe/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.metricx import QEMetricX 3 | 4 | 5 | class MetricXQE(QEMetricX): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | model="google/metricx-23-qe-xl-v2p0", tokenizer="google/mt5-xl", **kwargs 9 | ) 10 | 11 | @staticmethod 12 | def metric_name(): 13 | return "metricx_qe" 14 | -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_qe_large/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/metricx_qe_large/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_qe_large/metric.py: -------------------------------------------------------------------------------- 1 | # -*- 
coding: utf-8 -*- 2 | from tower_eval.metrics.base.metricx import QEMetricX 3 | 4 | 5 | class MetricXQELarge(QEMetricX): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | model="google/metricx-23-qe-large-v2p0", 9 | tokenizer="google/mt5-large", 10 | **kwargs 11 | ) 12 | 13 | @staticmethod 14 | def metric_name(): 15 | return "metricx_qe_large" 16 | -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_qe_xxl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/metricx_qe_xxl/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_qe_xxl/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.metricx import QEMetricX 3 | 4 | 5 | class MetricXQEXXL(QEMetricX): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | model="google/metricx-23-qe-xxl-v2p0", tokenizer="google/mt5-xxl", **kwargs 9 | ) 10 | 11 | @staticmethod 12 | def metric_name(): 13 | return "metricx_qe_xxl" 14 | -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_xxl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/metricx_xxl/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/metricx_xxl/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.metricx import RefMetricX 3 | 4 | 5 | class MetricXXXL(RefMetricX): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__( 8 | model="google/metricx-23-xxl-v2p0", tokenizer="google/mt5-xxl", **kwargs 9 | ) 10 | 11 | @staticmethod 12 | def metric_name(): 13 | return "metricx_xxl" 14 | -------------------------------------------------------------------------------- /tower_eval/metrics/pearson/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/pearson/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/pearson/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from scipy.stats import pearsonr 3 | 4 | from tower_eval.metrics.base.metrics_handler import Metric 5 | from tower_eval.metrics.base.result_handler import MetricResult 6 | from tower_eval.metrics.pearson.result import PearsonResult 7 | 8 | 9 | class PEARSON(Metric): 10 | def __init__(self, **kwargs) -> None: 11 | super().__init__(**kwargs) 12 | 13 | def run(self, hypothesis_path, gold_data_path, **kwargs) -> dict: 14 | predicted_scores, gold_data = self._handle_inputs( 15 | hypothesis_path, gold_data_path 16 | ) 17 | gold_scores = gold_data["score"] 18 | 19 | result = self.evaluate( 20 | gold_scores=gold_scores, predicted_scores=predicted_scores 21 | ) 22 | result.print_result(self.metric_name()) 23 | return result.format_result(self.metric_name()) 24 | 25 | def evaluate(self, 
gold_scores, predicted_scores) -> PearsonResult: 26 | """ 27 | Evaluate function receives the gold scores as well as the predicted ones and returns the Pearson correlation coefficient between them. 28 | The Pearson correlation coefficient is calculated by calling scipy.stats.pearsonr from the SciPy library. 29 | """ 30 | pearson_corr_coef = pearsonr(gold_scores, predicted_scores) 31 | statistic, pvalue = pearson_corr_coef.statistic, pearson_corr_coef.pvalue 32 | result = PearsonResult(statistic) 33 | return result 34 | 35 | def process_result(self, result) -> MetricResult: 36 | pass 37 | 38 | @staticmethod 39 | def metric_name(): 40 | return "pearson" 41 | -------------------------------------------------------------------------------- /tower_eval/metrics/pearson/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class PearsonResult(MetricResult): 6 | """ 7 | Pearson Correlation Result Handler. 8 | """ 9 | 10 | def __init__( 11 | self, 12 | result: float, 13 | ) -> None: 14 | super().__init__(result) 15 | -------------------------------------------------------------------------------- /tower_eval/metrics/perplexity/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/perplexity/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/perplexity/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import numpy as np 8 | 9 | from tower_eval.metrics.base.metrics_handler import Metric 10 | from tower_eval.metrics.base.result_handler import MetricResult 11 | from tower_eval.metrics.perplexity.result import PerplexityResult 12 | from tower_eval.utils import handle_subprocess, list_to_dict, load_jsonl_file 13 | 14 | 15 | class Perplexity(Metric): 16 | def __init__(self, model: str, max_model_context: int, **kwargs) -> None: 17 | """ 18 | Calculates perplexity over some corpus. Truncates the ending of each instance to fit the max_model_context.
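Perplexity is computed per segment as exp(-mean(log p(token_i | preceding tokens))), using the prompt log-probabilities returned by vLLM (the first token, which has no log-probability, is skipped); the corpus-level score is the mean of the segment perplexities, mirroring get_perplexity_from_vllm_output below.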
19 | """ 20 | super().__init__(**kwargs) 21 | self.model_id = model 22 | self.max_model_context = max_model_context 23 | self.vllm_args = kwargs.get( 24 | "vllm_args", 25 | { 26 | "gpu_memory_utilization": 0.9, 27 | "tensor_parallel_size": 1, 28 | "trust_remote_code": True, 29 | }, 30 | ) 31 | 32 | @staticmethod 33 | def _handle_inputs( 34 | data_path: Path, 35 | ) -> tuple: 36 | """ """ 37 | gold_data = load_jsonl_file(data_path) 38 | gold_data = list_to_dict(gold_data) 39 | 40 | return gold_data 41 | 42 | def run( 43 | self, 44 | gold_data_path, 45 | model_id: str, 46 | max_model_context: int, 47 | hypothesis_path=None, 48 | vllm_args: dict = { 49 | "gpu_memory_utilization": 0.9, 50 | "tensor_parallel_size": 1, 51 | "trust_remote_code": True, 52 | }, 53 | **kwargs, 54 | ) -> dict: 55 | result = self.evaluate(gold_data_path, model_id, max_model_context, vllm_args) 56 | result.print_result(self.metric_name()) 57 | return result.format_result(self.metric_name()) 58 | 59 | def evaluate( 60 | self, 61 | gold_data_path: Path, 62 | model_id: str, 63 | max_model_context: int, 64 | vllm_args: dict, 65 | ) -> PerplexityResult: 66 | """ 67 | Evaluate function receives the hypotheses and the references and returns a COMETResult object. 68 | 69 | :param hypotheses: List of the MT outputs (sentences). 70 | :param references: List of the reference sentences. 71 | :param sources: List of source sentences 72 | """ 73 | current_dir = os.getcwd() 74 | subprocess_args = [ 75 | f"python", 76 | f"{current_dir}/tower_eval/metrics/perplexity/vllm_subprocess.py", 77 | "--gold_data_path", 78 | f"{str(gold_data_path)}", 79 | "--model_id", 80 | f"{model_id}", 81 | "--max_model_context", 82 | f"{str(max_model_context)}", 83 | ] 84 | output = handle_subprocess(subprocess_args, check_output=True) 85 | perplexities, mean_perplexity = self.parse_subprocess_output(output) 86 | assert perplexities is not None and mean_perplexity is not None, ( 87 | f"Error when parsing the output of the perplexity subprocess, aborting." 
88 | f"Output: {output}" 89 | ) 90 | result = PerplexityResult( 91 | { 92 | "system_score": mean_perplexity, 93 | "segments_scores": perplexities, 94 | } 95 | ) 96 | return result 97 | 98 | @staticmethod 99 | def parse_subprocess_output(output: str): 100 | # Correcting the regular expression to accurately capture the perplexities and mean perplexity 101 | regex = r"PERPLEXITIES: \[([0-9.,\s]+)\]\n PERPLEXITY: ([0-9.]+)" 102 | 103 | match = re.search(regex, output) 104 | 105 | if match: 106 | perplexities = list(map(float, match.group(1).split(", "))) 107 | mean_perplexity = float(match.group(2)) 108 | else: 109 | perplexities, mean_perplexity = None, None 110 | 111 | return perplexities, mean_perplexity 112 | 113 | @staticmethod 114 | def truncate_prompts(prompts: List[str], max_model_context: int): 115 | return [prompt[:max_model_context] for prompt in prompts] 116 | 117 | @staticmethod 118 | def get_perplexity_from_vllm_output(vllm_output): 119 | perplexities = [] 120 | for output in vllm_output: 121 | l_probs = [] 122 | # ignore very first token, for which there is no logprob 123 | for l_prob_dict in output.prompt_logprobs[1:]: 124 | l_probs.append(list(l_prob_dict.values())[0].logprob) 125 | perplexities.append(np.exp(-(sum(l_probs) / len(l_probs)))) 126 | mean_perplexity = np.mean(perplexities).astype(float) 127 | 128 | return perplexities, mean_perplexity 129 | 130 | def process_result(self, result) -> MetricResult: 131 | pass 132 | 133 | @staticmethod 134 | def metric_name(): 135 | return "perplexity" 136 | -------------------------------------------------------------------------------- /tower_eval/metrics/perplexity/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class PerplexityResult(MetricResult): 6 | """ 7 | Perplexity result handler 8 | """ 9 | 10 | def __init__( 11 | self, 12 | result: float, 13 | ) -> None: 14 | super().__init__(result) 15 | -------------------------------------------------------------------------------- /tower_eval/metrics/perplexity/vllm_subprocess.py: -------------------------------------------------------------------------------- 1 | import vllm 2 | from jsonargparse import CLI 3 | from loguru import logger 4 | 5 | from tower_eval.metrics.perplexity.metric import Perplexity 6 | from tower_eval.utils import tokenize_text 7 | 8 | 9 | def main(gold_data_path: str, model_id: str, max_model_context: int): 10 | # load data 11 | gold_data = Perplexity._handle_inputs(gold_data_path) 12 | references = gold_data["text"] 13 | 14 | # vllm 15 | llm = vllm.LLM(model=model_id, enforce_eager=True, gpu_memory_utilization=0.9) 16 | tokenizer = llm.get_tokenizer() 17 | sampling_params = vllm.SamplingParams( 18 | max_tokens=1, temperature=0.0, prompt_logprobs=1 19 | ) 20 | tokenized_prompts = tokenize_text(references, tokenizer) 21 | truncated_prompts = Perplexity.truncate_prompts( 22 | tokenized_prompts, max_model_context 23 | ) 24 | logger.warning( 25 | f"Truncating prompts to fit model max context of {max_model_context}" 26 | ) 27 | model_output = llm.generate( 28 | prompt_token_ids=truncated_prompts, 29 | sampling_params=sampling_params, 30 | use_tqdm=True, 31 | ) 32 | perplexities, mean_perplexity = Perplexity.get_perplexity_from_vllm_output( 33 | model_output 34 | ) 35 | 36 | print(f"""PERPLEXITIES: {perplexities}\n PERPLEXITY: {mean_perplexity}""") 37 | 38 | 39 | if __name__ == "__main__": 40 | CLI([main], 
as_positional=False) 41 | -------------------------------------------------------------------------------- /tower_eval/metrics/spearman/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/spearman/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/spearman/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from scipy.stats import spearmanr 3 | 4 | from tower_eval.metrics.base.metrics_handler import Metric 5 | from tower_eval.metrics.base.result_handler import MetricResult 6 | from tower_eval.metrics.spearman.result import SpearmanResult 7 | 8 | 9 | class SPEARMAN(Metric): 10 | def __init__(self, **kwargs) -> None: 11 | super().__init__(**kwargs) 12 | 13 | def run(self, hypothesis_path, gold_data_path, **kwargs) -> dict: 14 | predicted_scores, gold_data = self._handle_inputs( 15 | hypothesis_path, gold_data_path 16 | ) 17 | gold_scores = gold_data["score"] 18 | 19 | result = self.evaluate( 20 | gold_scores=gold_scores, predicted_scores=predicted_scores 21 | ) 22 | result.print_result(self.metric_name()) 23 | return result.format_result(self.metric_name()) 24 | 25 | def evaluate(self, gold_scores, predicted_scores) -> SpearmanResult: 26 | """ 27 | Evaluate function receives the gold scores as well as the predicted ones and returns the Spearman correlation coefficient between them. 28 | The Spearman correlation coefficient is calculated by calling scipy.stats.spearmanr from the SciPy library. 29 | """ 30 | spearman_corr_coef = spearmanr(gold_scores, predicted_scores) 31 | statistic, pvalue = spearman_corr_coef.statistic, spearman_corr_coef.pvalue 32 | result = SpearmanResult(statistic) 33 | return result 34 | 35 | def process_result(self, result) -> MetricResult: 36 | pass 37 | 38 | @staticmethod 39 | def metric_name(): 40 | return "spearman" 41 | -------------------------------------------------------------------------------- /tower_eval/metrics/spearman/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class SpearmanResult(MetricResult): 6 | """ 7 | Spearman Correlation Result Handler.
8 | """ 9 | 10 | def __init__( 11 | self, 12 | result: float, 13 | ) -> None: 14 | super().__init__(result) 15 | -------------------------------------------------------------------------------- /tower_eval/metrics/ter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/ter/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/ter/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from sacrebleu.metrics import TER as SacreTER 3 | 4 | from tower_eval.metrics.base.metrics_handler import Metric 5 | from tower_eval.metrics.base.result_handler import MetricResult 6 | from tower_eval.metrics.ter.result import TERResult 7 | 8 | 9 | class TER(Metric): 10 | def __init__(self, **kwargs) -> None: 11 | super().__init__(**kwargs) 12 | 13 | def run( 14 | self, 15 | hypothesis_path, 16 | gold_data_path, 17 | normalized: bool = False, 18 | no_punct: bool = False, 19 | asian_support: bool = False, 20 | case_sensitive: bool = False, 21 | **kwargs 22 | ) -> dict: 23 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 24 | references = gold_data["ref"] 25 | result = self.evaluate( 26 | hypotheses, references, normalized, no_punct, asian_support, case_sensitive 27 | ) 28 | result.print_result(self.metric_name()) 29 | return result.format_result(self.metric_name()) 30 | 31 | def evaluate( 32 | self, 33 | hypotheses: list, 34 | references: list, 35 | normalized: bool = False, 36 | no_punct: bool = False, 37 | asian_support: bool = False, 38 | case_sensitive: bool = False, 39 | ) -> TERResult: 40 | """ 41 | Evaluate function receives the hypotheses and the references and returns a TERResult object. 42 | The TER score is calculate by calling sacreBLEU 43 | 44 | :param hypotheses: path to the hypotheses file. 45 | :param references: path to the references file. 46 | """ 47 | ter = SacreTER( 48 | normalized=normalized, 49 | no_punct=no_punct, 50 | asian_support=asian_support, 51 | case_sensitive=case_sensitive, 52 | ) 53 | if type(references[0]) == str: 54 | references = [references] 55 | score = ter.corpus_score(hypotheses, references) 56 | result = TERResult(score.score) 57 | return result 58 | 59 | def process_result(self, result) -> MetricResult: 60 | pass 61 | 62 | @staticmethod 63 | def metric_name(): 64 | return "ter" 65 | -------------------------------------------------------------------------------- /tower_eval/metrics/ter/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class TERResult(MetricResult): 6 | """ 7 | TER Result Handler. 8 | TODO: Add the extra information (such as the casing, version of the metric, etc) to the output. 
9 | """ 10 | 11 | def __init__( 12 | self, 13 | result: float, 14 | ) -> None: 15 | super().__init__(result) 16 | -------------------------------------------------------------------------------- /tower_eval/metrics/xcomet_qe_xl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/xcomet_qe_xl/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/xcomet_qe_xl/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.comet import QECOMET 3 | 4 | 5 | class XCOMETQEXL(QECOMET): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(model="Unbabel/XCOMET-XL", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "xcomet_qe_xl" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/xcomet_qe_xxl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/xcomet_qe_xxl/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/xcomet_qe_xxl/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.comet import QECOMET 3 | 4 | 5 | class XCOMETQEXXL(QECOMET): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(model="Unbabel/XCOMET-XXL", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "xcomet_qe_xxl" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/xcomet_xl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/xcomet_xl/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/xcomet_xl/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.comet import RefCOMET 3 | 4 | 5 | class XCOMETXL(RefCOMET): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(model="Unbabel/XCOMET-XL", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "xcomet_xl" 12 | -------------------------------------------------------------------------------- /tower_eval/metrics/xcomet_xxl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/xcomet_xxl/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/xcomet_xxl/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.comet import RefCOMET 3 | 4 | 5 | class XCOMETXXL(RefCOMET): 6 | def __init__(self, **kwargs) -> None: 7 | super().__init__(model="Unbabel/XCOMET-XXL", **kwargs) 8 | 9 | @staticmethod 10 | def metric_name(): 11 | return "xcomet_xxl" 12 | 
-------------------------------------------------------------------------------- /tower_eval/metrics/xml_chrf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/xml_chrf/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/xml_chrf/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from loguru import logger 3 | 4 | from tower_eval.metrics.base.xml_metric import XMLMetric 5 | from tower_eval.metrics.chrf.metric import CHRF 6 | 7 | 8 | class XML_CHRF(XMLMetric): 9 | def __init__(self, **kwargs) -> None: 10 | super().__init__(CHRF, **kwargs) 11 | 12 | @staticmethod 13 | def metric_name(): 14 | return "xml_chrf" 15 | -------------------------------------------------------------------------------- /tower_eval/metrics/xml_match/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/metrics/xml_match/__init__.py -------------------------------------------------------------------------------- /tower_eval/metrics/xml_match/metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from loguru import logger 3 | from lxml import etree 4 | 5 | from tower_eval.metrics.base.metrics_handler import Metric 6 | from tower_eval.metrics.base.result_handler import MetricResult 7 | from tower_eval.metrics.xml_match.result import XML_MatchResult 8 | from tower_eval.utils import match_xml 9 | 10 | 11 | class XML_MATCH(Metric): 12 | def __init__(self, **kwargs) -> None: 13 | super().__init__(**kwargs) 14 | 15 | def run( 16 | self, hypothesis_path, gold_data_path, lowercase: bool = False, **kwargs 17 | ) -> dict: 18 | hypotheses, gold_data = self._handle_inputs(hypothesis_path, gold_data_path) 19 | references = gold_data["ref"] 20 | 21 | result = self.evaluate(hypotheses, references) 22 | result.print_result(self.metric_name()) 23 | return result.format_result(self.metric_name()) 24 | 25 | def evaluate( 26 | self, 27 | hypotheses: list, 28 | references: list, 29 | ) -> XML_MatchResult: 30 | """ 31 | Evaluate function receives the hypotheses and the references and returns an XML_MatchResult object. 32 | 33 | :param hypotheses: list of hypothesis sentences. 34 | :param references: list of reference sentences. 35 | """ 36 | assert type(references[0]) == str, logger.error( 37 | "Multi-reference is not supported for XML_MATCH" 38 | ) 39 | 40 | """ 41 | Based on the information provided in the original papers, 42 | the XML-Match is the percentage of outputs that have exactly the same XML structures as their references.
43 | """ 44 | segment_scores = [0] * len(hypotheses) 45 | 46 | for id, (hyp, ref) in enumerate(zip(hypotheses, references)): 47 | try: 48 | hyp = etree.fromstring(f"{hyp}") 49 | ref = etree.fromstring(f"{ref}") 50 | if match_xml(hyp, ref): 51 | segment_scores[id] = 1 52 | except: 53 | pass 54 | 55 | score = sum(segment_scores) / len(segment_scores) 56 | result = XML_MatchResult( 57 | { 58 | "system_score": score, 59 | "segments_scores": segment_scores, 60 | } 61 | ) 62 | return result 63 | 64 | def process_result(self, result) -> MetricResult: 65 | pass 66 | 67 | @staticmethod 68 | def metric_name(): 69 | return "xml_match" 70 | -------------------------------------------------------------------------------- /tower_eval/metrics/xml_match/result.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from tower_eval.metrics.base.result_handler import MetricResult 3 | 4 | 5 | class XML_MatchResult(MetricResult): 6 | """ 7 | xml_match Result Handler. 8 | TODO: Add the extra information (such as the casing, version of the metric, etc) to the output. 9 | """ 10 | 11 | def __init__( 12 | self, 13 | result: float, 14 | ) -> None: 15 | super().__init__(result) 16 | -------------------------------------------------------------------------------- /tower_eval/models/__init__.py: -------------------------------------------------------------------------------- 1 | from tower_eval.models.anthropic.generator import Anthropic 2 | from tower_eval.models.cohere.generator import Cohere 3 | from tower_eval.models.hf.generator import HF 4 | from tower_eval.models.openAI.generator import OpenAI 5 | from tower_eval.models.seq2seq.generator import Seq2Seq 6 | from tower_eval.models.vertexAI.generator import VertexAI 7 | from tower_eval.models.vllm.generator import VLLM 8 | from tower_eval.models.deepl.generator import DeepL 9 | 10 | available_models = { 11 | model.model_name(): model 12 | for model in [OpenAI, VLLM, VertexAI, Anthropic, Cohere, Seq2Seq, HF, DeepL] 13 | } 14 | -------------------------------------------------------------------------------- /tower_eval/models/anthropic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/models/anthropic/__init__.py -------------------------------------------------------------------------------- /tower_eval/models/anthropic/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import anthropic 5 | 6 | from tower_eval.models.inference_handler import Generator 7 | from tower_eval.utils import generate_with_retries 8 | 9 | 10 | class Anthropic(Generator): 11 | """Anthropic API wrapper.""" 12 | 13 | def __init__( 14 | self, 15 | api_key: str = None, 16 | model: str = "", 17 | temperature: float = 0.0, 18 | top_p: float = 1.0, 19 | max_tokens: int = 1024, 20 | retry_max_attempts: int = 1, 21 | retry_max_interval: int = 10, 22 | retry_min_interval: int = 4, 23 | retry_multiplier: int = 1, 24 | stop_sequences: list[str] = [], 25 | system_prompt: str = None, 26 | **kwargs, 27 | ) -> None: 28 | super().__init__(**kwargs) 29 | self.run_async = False # only sync calls are supported 30 | # Set anthropic settings 31 | model = kwargs.get("model", model) 32 | temperature = kwargs.get("temperature", temperature) 33 | top_p = kwargs.get("top_p", top_p) 34 | max_tokens = 
kwargs.get("max_tokens", max_tokens) 35 | stop_sequences = kwargs.get("stop_sequences", stop_sequences) 36 | system_prompt = system_prompt 37 | self.model_args = { 38 | "model": model, 39 | "temperature": temperature, 40 | "top_p": top_p, 41 | "max_tokens": max_tokens, 42 | "stop_sequences": stop_sequences, 43 | } 44 | if system_prompt is not None: 45 | self.model_args["system"] = system_prompt 46 | 47 | # Generations object / retry args 48 | self.client = anthropic.Anthropic( 49 | # defaults to os.environ.get("ANTHROPIC_API_KEY") 50 | api_key=os.environ.get("ANTHROPIC_API_KEY", api_key), 51 | ) 52 | self.retry_max_attempts = kwargs.get("retry_max_attempts", retry_max_attempts) 53 | self.retry_max_interval = kwargs.get("retry_max_interval", retry_max_interval) 54 | self.retry_min_interval = kwargs.get("retry_min_interval", retry_min_interval) 55 | self.retry_multiplier = kwargs.get("retry_multiplier", retry_multiplier) 56 | 57 | def _generate(self, input_line: str) -> str: 58 | """It calls the Chat completion function of anthropic. 59 | 60 | Args: 61 | prompt (str): Prompt for the anthropic model 62 | 63 | 64 | Returns: 65 | str: Returns the response used. 66 | """ 67 | prompt = {"messages": [{"role": "user", "content": input_line}]} 68 | response = generate_with_retries( 69 | retry_function=self.client.messages.create, 70 | model_args=self.model_args | prompt, 71 | retry_max_attempts=self.retry_max_attempts, 72 | retry_multiplier=self.retry_multiplier, 73 | retry_min_interval=self.retry_min_interval, 74 | retry_max_interval=self.retry_max_interval, 75 | ) 76 | 77 | response = response.content[0].text 78 | return response 79 | 80 | @staticmethod 81 | def model_name(): 82 | return "anthropic" 83 | -------------------------------------------------------------------------------- /tower_eval/models/cohere/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/models/cohere/__init__.py -------------------------------------------------------------------------------- /tower_eval/models/cohere/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | import cohere 5 | from tower_eval.models.inference_handler import Generator 6 | from tower_eval.utils import generate_with_retries 7 | 8 | 9 | class Cohere(Generator): 10 | """Cohere API wrapper.""" 11 | 12 | def __init__( 13 | self, 14 | api_key: str = None, 15 | model: str = "command-r-plus", 16 | temperature: float = 0.0, 17 | top_p: float = 1.0, 18 | max_tokens: int = 1024, 19 | retry_max_attempts: int = 1, 20 | retry_max_interval: int = 10, 21 | retry_min_interval: int = 4, 22 | retry_multiplier: int = 1, 23 | stop_sequences: list[str] = [], 24 | system_prompt: str = None, 25 | **kwargs, 26 | ) -> None: 27 | super().__init__(**kwargs) 28 | self.run_async = False # only sync calls are supported 29 | # Set cohere settings 30 | model = kwargs.get("model", model) 31 | temperature = kwargs.get("temperature", temperature) 32 | top_p = kwargs.get("top_p", top_p) 33 | max_tokens = kwargs.get("max_tokens", max_tokens) 34 | stop_sequences = kwargs.get("stop_sequences", stop_sequences) 35 | system_prompt = system_prompt 36 | self.model_args = { 37 | "model": model, 38 | "temperature": temperature, 39 | "p": top_p, 40 | "max_tokens": max_tokens, 41 | "stop_sequences": stop_sequences, 42 | } 43 | if system_prompt is not None: 44 | 
self.model_args["preamble"] = system_prompt 45 | 46 | # Generations object / retry args 47 | self.client = cohere.Client( 48 | # defaults to os.environ.get("COHERE_API_KEY") 49 | os.environ.get("COHERE_API_KEY", api_key), 50 | ) 51 | self.retry_max_attempts = kwargs.get("retry_max_attempts", retry_max_attempts) 52 | self.retry_max_interval = kwargs.get("retry_max_interval", retry_max_interval) 53 | self.retry_min_interval = kwargs.get("retry_min_interval", retry_min_interval) 54 | self.retry_multiplier = kwargs.get("retry_multiplier", retry_multiplier) 55 | 56 | def _generate(self, input_line: str) -> str: 57 | """It calls the Chat completion function of cohere. 58 | 59 | Args: 60 | prompt (str): Prompt for the anthropic model 61 | 62 | 63 | Returns: 64 | str: Returns the response used. 65 | """ 66 | prompt = {"message": input_line} 67 | response = generate_with_retries( 68 | retry_function=self.client.chat, 69 | model_args=self.model_args | prompt, 70 | retry_max_attempts=self.retry_max_attempts, 71 | retry_multiplier=self.retry_multiplier, 72 | retry_min_interval=self.retry_min_interval, 73 | retry_max_interval=self.retry_max_interval, 74 | ) 75 | 76 | response = response.text 77 | return response 78 | 79 | @staticmethod 80 | def model_name(): 81 | return "cohere" 82 | -------------------------------------------------------------------------------- /tower_eval/models/deepl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/models/deepl/__init__.py -------------------------------------------------------------------------------- /tower_eval/models/deepl/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | 4 | from loguru import logger 5 | import deepl 6 | from deepl import DocumentTranslationException, DeepLException 7 | 8 | import logging 9 | logging.getLogger('deepl').setLevel(logging.WARNING) 10 | from tower_eval.models.inference_handler import Generator 11 | from tower_eval.utils import generate_with_retries 12 | 13 | 14 | class DeepL(Generator): 15 | """DeepL Wrapper. 16 | 17 | Args: 18 | api_key (str): DeepL API Key 19 | model (str): Specifies which DeepL model should be used for translation. 20 | - "latency_optimized": the classic translation model of DeepL with lower latency that support all language pairs; default value 21 | - "quality_optimized": uses higher latency, improved quality “next-gen” translation models, which support only a subset of language pairs; 22 | if a language pair that is not supported by next-gen models is included in the request, it will fail. 23 | - "prefer_quality_optimized": prioritizes use of higher latency, improved quality “next-gen” translation models, which support only a subset of DeepL languages; 24 | if a request includes a language pair not supported by next-gen models, the request will fall back to latency_optimized classic models) 25 | Check this link for more information: https://developers.deepl.com/docs/api-reference/translate#request-body-descriptions 26 | retry_max_attempts (int, optional): Maximum number of retries. Defaults to 1. 27 | retry_max_interval (int, optional): Maximum interval between retries. Defaults to 10. 28 | retry_min_interval (int, optional): Minimum interval between retries. Defaults to 4. 29 | retry_multiplier (int, optional): Multiplier for the retry interval. Defaults to 1. 
30 | """ 31 | 32 | def __init__( 33 | self, 34 | api_key: str = os.environ["DEEPL_API_KEY"], 35 | model: str = "latency_optimized", 36 | retry_max_attempts: int = 1, 37 | retry_max_interval: int = 10, 38 | retry_min_interval: int = 4, 39 | retry_multiplier: int = 1, 40 | **kwargs, 41 | ) -> None: 42 | super().__init__(**kwargs) 43 | 44 | self.run_async = False # only sync calls are supported 45 | # Set openai settings 46 | self.model_type = kwargs.get("model", model) 47 | self.retry_max_attempts = kwargs.get("retry_max_attempts", retry_max_attempts) 48 | self.retry_max_interval = kwargs.get("retry_max_interval", retry_max_interval) 49 | self.retry_min_interval = kwargs.get("retry_min_interval", retry_min_interval) 50 | self.retry_multiplier = kwargs.get("retry_multiplier", retry_multiplier) 51 | self.client = deepl.Translator(api_key) 52 | 53 | self.src_lang_map = { 54 | "es-latam": "es", 55 | "en-gb": "en", 56 | "en-us": "en", 57 | "en-uk": "en", 58 | "pt-br": "pt", 59 | "zh-tw": "zh", 60 | "zh-cn": "zh", 61 | } 62 | 63 | self.trg_lang_map = { 64 | "pt": "pt-pt", 65 | "en": "en-us", 66 | "es-latam": "es", 67 | "fr-fr": "fr", 68 | "it-it": "it", 69 | "ko-ko": "ko", 70 | "zh": "zh-hans", 71 | "zh-cn": "zh-hans", 72 | "zh-tw": "zh-hant", 73 | } 74 | 75 | def normalize_languages(self): 76 | # Valid languages per vendor 77 | valid_src_lang = self.src_lang_map.get(self.source_language.lower(), self.source_language) 78 | valid_trg_lang = self.trg_lang_map.get(self.target_language.lower(), self.target_language) 79 | 80 | if valid_src_lang != self.source_language: 81 | logger.warning( 82 | f"Source language ({self.source_language}) not supported by DeepL. " f"Using {valid_src_lang} instead" 83 | ) 84 | if valid_trg_lang != self.target_language: 85 | logger.warning( 86 | f"Target language ({self.target_language}) not supported by DeepL. " f"Using {valid_trg_lang} instead" 87 | ) 88 | 89 | self.source_language = valid_src_lang 90 | self.target_language = valid_trg_lang 91 | 92 | def _generate(self, input_line: str, context: str = None, formality: str = "default") -> str: 93 | """It calls the translate_text function of DeepL. 94 | 95 | Args: 96 | input_line (str): The text to be translated. 97 | context (str, optional): The context for the translation. Defaults to None. 98 | 99 | 100 | Returns: 101 | str: Returns the translated text. 102 | """ 103 | self.normalize_languages() 104 | 105 | try: 106 | response = generate_with_retries( 107 | retry_function=self.client.translate_text, 108 | model_args={"text": input_line, 109 | "context": context, 110 | "source_lang": self.source_language, 111 | "target_lang": self.target_language, 112 | "formality": formality, 113 | "model_type": self.model_type}, 114 | retry_max_attempts=self.retry_max_attempts, 115 | retry_multiplier=self.retry_multiplier, 116 | retry_min_interval=self.retry_min_interval, 117 | retry_max_interval=self.retry_max_interval, 118 | ) 119 | except DocumentTranslationException as e: 120 | # If an error occurs during document translation after the document was 121 | # already uploaded, a DocumentTranslationException is raised. The 122 | # document_handle property contains the document handle that may be used to 123 | # later retrieve the document from the server, or contact DeepL support. 
124 | doc_id = e.document_handle.id 125 | doc_key = e.document_handle.key 126 | logger.opt(colors=True).error(f"Error after uploading ${e}, id: ${doc_id} key: ${doc_key}") 127 | except DeepLException as e: 128 | # Errors during upload raise a DeepLException 129 | logger.opt(colors=True).error(f"{e}") 130 | 131 | response = response.text 132 | return response 133 | 134 | @staticmethod 135 | def model_name(): 136 | return "deepl" 137 | -------------------------------------------------------------------------------- /tower_eval/models/exceptions.py: -------------------------------------------------------------------------------- 1 | class GenerationException(Exception): 2 | pass 3 | -------------------------------------------------------------------------------- /tower_eval/models/hf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/models/hf/__init__.py -------------------------------------------------------------------------------- /tower_eval/models/hf/generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | from loguru import logger 4 | 5 | from transformers import ( 6 | AutoModelForCausalLM, 7 | AutoTokenizer, 8 | StoppingCriteriaList, 9 | StopStringCriteria, 10 | ) 11 | from typing import List 12 | from tower_eval.models.inference_handler import Generator 13 | logger.add( 14 | sys.stderr, 15 | colorize=True, 16 | format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", 17 | ) 18 | 19 | class HF(Generator): 20 | """HF Generate Wrapper.""" 21 | 22 | def __init__( 23 | self, 24 | max_tokens: int = 1024, 25 | stop_sequences: list = ["\n", "\\n", ""], 26 | seed: int = 42, 27 | run_async: bool = True, 28 | batch_size: int = 16, 29 | trust_remote_code: bool = True, 30 | temperature: float = 0.0, 31 | strip_output: bool = False, 32 | **kwargs 33 | ) -> None: 34 | super().__init__(**kwargs) 35 | self.max_tokens = max_tokens 36 | self.temperature = temperature 37 | self.seed = seed 38 | self.model_dir = kwargs.get("model_dir") 39 | self.run_async = run_async 40 | self.batch_size = batch_size 41 | self.trust_remote_code = trust_remote_code 42 | self.strip_output = strip_output 43 | self.do_sample = kwargs.get("do_sample", False) 44 | self.top_p = kwargs.get("top_p") 45 | self.top_k = kwargs.get("top_k") 46 | if not self.do_sample and self.temperature == 0.0: 47 | logger.opt(colors=True).warning("For greedy decoding you should only set do_sample to False " 48 | "and temperature to 0.0." 
49 | " I am going to set do_sample=False and temperature=None which will result in greedy generation.") 50 | self.do_sample = False 51 | self.temperature = None 52 | # load tokenizer 53 | self.tokenizer = AutoTokenizer.from_pretrained( 54 | self.model_dir, 55 | trust_remote_code=self.trust_remote_code, 56 | padding_side='left' 57 | ) 58 | if stop_sequences: 59 | self.stopping_criteria = StoppingCriteriaList( 60 | [StopStringCriteria(self.tokenizer, stop_sequences)] 61 | ) 62 | else: 63 | self.stopping_criteria = None 64 | 65 | # Set up device configuration 66 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 67 | 68 | # Load model 69 | self.model = AutoModelForCausalLM.from_pretrained( 70 | self.model_dir, 71 | trust_remote_code=self.trust_remote_code, 72 | device_map="auto", 73 | ) 74 | 75 | torch.manual_seed(self.seed) 76 | 77 | def _generate(self, input_line: str) -> str: 78 | """Generate text for a single input.""" 79 | 80 | # Tokenize input 81 | inputs = self.tokenizer( 82 | input_line, return_token_type_ids=False, return_tensors="pt" 83 | ).to(self.device) 84 | outputs = self.model.generate( 85 | **inputs, 86 | max_new_tokens=self.max_tokens, 87 | stopping_criteria=self.stopping_criteria, 88 | temperature=self.temperature, 89 | do_sample=self.do_sample, 90 | top_p=self.top_p, 91 | top_k=self.top_k, 92 | ) 93 | # decode only the generated part of the text 94 | generated_text = self.tokenizer.decode( 95 | outputs[0][inputs["input_ids"].shape[-1] : -1] # exclude stopping tokens 96 | ) 97 | 98 | if self.strip_output: 99 | generated_text = generated_text.strip() 100 | 101 | return generated_text 102 | 103 | def _batch_generate(self, input_lines: List[str]) -> List[str]: 104 | if self.tokenizer.pad_token is None: 105 | self.tokenizer.pad_token = self.tokenizer.eos_token 106 | inputs = self.tokenizer( 107 | input_lines, 108 | return_token_type_ids=False, 109 | padding=True, 110 | truncation=True, 111 | return_tensors="pt", 112 | return_attention_mask=True 113 | ).to(self.device) 114 | model_output = self.model.generate( 115 | **inputs, 116 | max_new_tokens=self.max_tokens, 117 | stopping_criteria=self.stopping_criteria, 118 | pad_token_id=self.tokenizer.pad_token_id, 119 | temperature=self.temperature, 120 | do_sample=self.do_sample, 121 | top_p=self.top_p, 122 | top_k=self.top_k, 123 | ) 124 | generations = [] 125 | for i, output in enumerate(model_output): 126 | # Get the length of the original input for this example 127 | input_length = len(inputs["input_ids"][i].nonzero()) 128 | # Decode only the newly generated tokens 129 | generated_text = self.tokenizer.decode( 130 | output[input_length:], 131 | skip_special_tokens=True, 132 | clean_up_tokenization_spaces=True 133 | ) 134 | if self.strip_output: 135 | generated_text = generated_text.strip() 136 | generations.append(generated_text) 137 | return generations 138 | 139 | def apply_chat_template(self, input_line: str) -> str: 140 | if self.system_prompt is not None: 141 | messages = [{"role": "system", "content": self.system_prompt}] 142 | else: 143 | messages = [] 144 | messages.append({"role": "user", "content": input_line}) 145 | input_line = self.tokenizer.apply_chat_template( 146 | messages, 147 | add_generation_prompt=True, 148 | tokenize=False, 149 | chat_template=( 150 | None 151 | if self.model_dir 152 | not in [ 153 | "openGPT-X/Teuken-7B-instruct-research-v0.4", 154 | "openGPT-X/Teuken-7B-instruct-commercial-v0.4", 155 | ] 156 | else "EN" 157 | ), 158 | ) 159 | return input_line 160 | 161 | @staticmethod 162 | def 
model_name(): 163 | return "hf" 164 | -------------------------------------------------------------------------------- /tower_eval/models/inference_handler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from abc import ABC, abstractmethod 4 | from typing import List, Tuple 5 | 6 | from loguru import logger 7 | from tqdm import tqdm 8 | 9 | from tower_eval.utils import ( 10 | get_num_processed_lines, 11 | load_json_file, 12 | log_response, 13 | read_lines, 14 | save_to_json, 15 | write_lines, 16 | ) 17 | 18 | 19 | class Generator(ABC): 20 | """Abstract class defining a shared interface for all the generators (OpenAI models as well as our internal LLMs)""" 21 | 22 | def __init__(self, **kwargs) -> None: 23 | self.batch_size = kwargs.get("batch_size", 1) 24 | self.strip = kwargs.get("strip", True) 25 | self.use_chat_template = kwargs.get("use_chat_template", False) 26 | self.system_prompt = kwargs.get("system_prompt", None) 27 | 28 | def generate(self, prompt: str, **kwargs): 29 | """ 30 | The function that given the prompt sends the inference requtest to the model and gets the output. 31 | """ 32 | pass 33 | 34 | @abstractmethod 35 | def _generate(self, prompt: str): 36 | """ """ 37 | pass 38 | 39 | def _batch_generate(self, batch: List[str]): 40 | """ """ 41 | return [self._generate(b) for b in batch] 42 | 43 | @staticmethod 44 | @abstractmethod 45 | def model_name() -> None: 46 | """Model name to be called for inference.""" 47 | pass 48 | 49 | def assess_progress( 50 | self, 51 | input_lines: List[str], 52 | output_file: str, 53 | metadata: dict, 54 | metadata_file: str, 55 | overwrite_generations: bool = False, 56 | ) -> Tuple[List[str], List[str], dict, int, int]: 57 | """ """ 58 | total_lines = len(input_lines) 59 | if os.path.exists(output_file) and not overwrite_generations: 60 | processed_lines = read_lines(output_file, unescape_newline=True) 61 | num_processed_lines = get_num_processed_lines(output_file) 62 | else: 63 | processed_lines = [] 64 | num_processed_lines = 0 65 | # We assume that if the metadata file exists it already contains the information of the generation times. 66 | # If it doesn't exist, then the metadata will be the config of the task and we will add the generation_time field to it 67 | if os.path.exists(metadata_file) and not overwrite_generations: 68 | metadata = load_json_file(metadata_file) 69 | else: 70 | metadata["generation_time"] = [] 71 | 72 | assert ( 73 | num_processed_lines <= total_lines 74 | ), f"MORE PROCESSED LINES ({num_processed_lines}) THAN INPUT LINES ({total_lines})!" 75 | # Skip the lines already processed 76 | input_lines = input_lines[num_processed_lines:] 77 | 78 | return input_lines, processed_lines, metadata, num_processed_lines, total_lines 79 | 80 | def apply_chat_template(self, input_line: str) -> str: 81 | return input_line 82 | 83 | def preprocess_lines(self, input_lines: List[str]) -> List[str]: 84 | """ """ 85 | if self.strip: 86 | input_lines = [input_line.strip() for input_line in input_lines] 87 | else: 88 | input_lines = [input_line for input_line in input_lines] 89 | if self.use_chat_template: 90 | if self.model_name() in ["hf", "vllm"]: 91 | logger.warning("Applying chat template to loaded instructions.") 92 | else: 93 | raise NotImplementedError( 94 | "Applying chat template on the fly is only supported by hf or vllm models; please set the use_chat_template flag to False." 
95 | ) 96 | input_lines = [ 97 | self.apply_chat_template(input_line) for input_line in input_lines 98 | ] 99 | if len(input_lines) > 0: 100 | logger.info(f"Example processed line: {input_lines[-1]}") 101 | return input_lines 102 | 103 | def generate_to_file( 104 | self, 105 | input_lines: List[str], 106 | processed_lines: List[str], 107 | num_processed_lines: int, 108 | total_lines: int, 109 | output_file: str, 110 | metadata: dict, 111 | metadata_file: str, 112 | ): 113 | input_lines = self.preprocess_lines(input_lines) 114 | inference_batch_size = self.batch_size 115 | # for vllm, handle the case where input lines is finished 116 | if self.batch_size == -1: 117 | inference_batch_size = max(len(input_lines), 1) 118 | with tqdm(initial=num_processed_lines, total=total_lines) as pbar: 119 | for batch_id in range(0, len(input_lines), inference_batch_size): 120 | batch = input_lines[batch_id : batch_id + inference_batch_size] 121 | start_time = time.time() 122 | responses = self._batch_generate(batch) 123 | end_time = time.time() 124 | metadata["generation_time"].append(end_time - start_time) 125 | 126 | for response_id, response in enumerate(responses): 127 | processed_lines.append(response.strip()) 128 | # Calculate the number of responses processed so far 129 | step = batch_id * inference_batch_size + response_id 130 | log_response(response, step=step, lim=10) 131 | write_lines(output_file, processed_lines, escape_newline=True) 132 | save_to_json(metadata_file, metadata) 133 | pbar.update(len(batch)) 134 | 135 | def generation_with_resume( 136 | self, output_file: str, input_file: str, metadata: dict, metadata_file: str, overwrite_generations: bool = False 137 | ): 138 | """ 139 | Writes generated output to file, resuming from last line generated. 140 | """ 141 | # Read all the input lines and store them in a list 142 | input_lines = read_lines(input_file, unescape_newline=True) 143 | # update input lines, given the already processed lines; store this information 144 | input_lines, processed_lines, metadata, num_processed_lines, total_lines = ( 145 | self.assess_progress(input_lines, output_file, metadata, metadata_file, overwrite_generations) 146 | ) 147 | # perform the generation to a file 148 | self.generate_to_file( 149 | input_lines, 150 | processed_lines, 151 | num_processed_lines, 152 | total_lines, 153 | output_file, 154 | metadata, 155 | metadata_file, 156 | ) 157 | -------------------------------------------------------------------------------- /tower_eval/models/openAI/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/models/openAI/__init__.py -------------------------------------------------------------------------------- /tower_eval/models/openAI/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import re 4 | 5 | import openai 6 | from loguru import logger 7 | from openai import BadRequestError 8 | 9 | from tower_eval.models.exceptions import GenerationException 10 | from tower_eval.models.inference_handler import Generator 11 | from tower_eval.utils import generate_with_retries 12 | 13 | 14 | class OpenAI(Generator): 15 | """OpenAI GPT completion Wrapper. 16 | 17 | Args: 18 | api_org (str): The Org ID for OpenAI org 19 | api_key (str): OpenAI API Key 20 | api_base (str, optional): OpenAI API Base URL. 
Defaults to "https://api.openai.com/v1". 21 | api_version (str, optional): OpenAI API Version. Defaults to None. 22 | api_type (str, optional): OpenAI API Type. Defaults to OpenAI. 23 | retry_max_attempts (int, optional): Maximum number of retries. Defaults to 1. 24 | retry_max_interval (int, optional): Maximum interval between retries. Defaults to 10. 25 | retry_min_interval (int, optional): Minimum interval between retries. Defaults to 4. 26 | retry_multiplier (int, optional): Multiplier for the retry interval. Defaults to 1. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | api_key: str = os.environ["OPENAI_API_KEY"], 32 | api_base: str = "https://api.openai.com/v1", 33 | model: str = "", 34 | temperature: float = 0.0, 35 | top_p: float = 1.0, 36 | max_tokens: int = 1024, 37 | frequency_penalty: float = 0.0, 38 | presence_penalty: float = 0.0, 39 | retry_max_attempts: int = 1, 40 | retry_max_interval: int = 10, 41 | retry_min_interval: int = 4, 42 | retry_multiplier: int = 1, 43 | stop_sequences: list[str] = [], 44 | run_async: bool = False, 45 | system_prompt: str = None, 46 | **kwargs, 47 | ) -> None: 48 | super().__init__(**kwargs) 49 | self.run_async = False # only sync calls are supported 50 | # Set openai settings 51 | model = kwargs.get("model", model) 52 | temperature = kwargs.get("temperature", temperature) 53 | top_p = kwargs.get("top_p", top_p) 54 | max_tokens = kwargs.get("max_tokens", max_tokens) 55 | frequency_penalty = kwargs.get("frequency_penalty", frequency_penalty) 56 | presence_penalty = kwargs.get("presence_penalty", presence_penalty) 57 | stop_sequences = kwargs.get("stop_sequences", stop_sequences) 58 | self.system_prompt = system_prompt 59 | self.run_async = run_async 60 | self.openai_args = { 61 | "model": model, 62 | "temperature": temperature, 63 | "top_p": top_p, 64 | "max_completion_tokens": max_tokens, 65 | "frequency_penalty": frequency_penalty, 66 | "presence_penalty": presence_penalty, 67 | "stop": stop_sequences, 68 | } 69 | 70 | self.retry_max_attempts = kwargs.get("retry_max_attempts", retry_max_attempts) 71 | self.retry_max_interval = kwargs.get("retry_max_interval", retry_max_interval) 72 | self.retry_min_interval = kwargs.get("retry_min_interval", retry_min_interval) 73 | self.retry_multiplier = kwargs.get("retry_multiplier", retry_multiplier) 74 | self.model_max_tokens = { 75 | "gpt-3.5-turbo": 4096, 76 | "gpt-4": 8192, 77 | "gpt-4o": 128000, 78 | "gpt-4o-mini": 128000, 79 | }.get(model, 32000) 80 | self.client = openai.Client(api_key=api_key, base_url=api_base) 81 | 82 | def _generate(self, input_line: str) -> str: 83 | """It calls the Chat completion function of OpenAI. 84 | 85 | Args: 86 | prompt (str): Prompt for the OpenAI model 87 | 88 | 89 | Returns: 90 | str: Returns the response used. 
91 | """ 92 | try: 93 | if self.system_prompt is not None: 94 | prompt = { 95 | "messages": [ 96 | {"role": "system", "content": self.system_prompt}, 97 | {"role": "user", "content": input_line}, 98 | ] 99 | } 100 | else: 101 | prompt = {"messages": [{"role": "user", "content": input_line}]} 102 | response = generate_with_retries( 103 | retry_function=self.client.chat.completions.create, 104 | model_args=self.openai_args | prompt, 105 | retry_max_attempts=self.retry_max_attempts, 106 | retry_multiplier=self.retry_multiplier, 107 | retry_min_interval=self.retry_min_interval, 108 | retry_max_interval=self.retry_max_interval, 109 | ) 110 | except Exception as e: 111 | if type(e) == BadRequestError: 112 | response = self._handle_excessive_tokens_error(e, prompt) 113 | else: 114 | raise GenerationException(str(e)) 115 | 116 | response = response.choices[0].message.content 117 | return response 118 | 119 | # def _handle_excessive_tokens_error(self, e: InvalidRequestError, prompt: dict): 120 | def _handle_excessive_tokens_error(self, e: BadRequestError, prompt: dict): 121 | logger.error( 122 | f'Handling Open AI excessive tokens requested error by decreasing max tokens for this request. ("{str(e)}")' 123 | ) 124 | requested_tokens = int(re.findall(r"you requested (\d+) tokens", str(e))[0]) 125 | excessive_tokens = requested_tokens - self.model_max_tokens 126 | old_max_tokens = self.openai_args["max_completion_tokens"] 127 | new_max_tokens = old_max_tokens - excessive_tokens 128 | self.openai_args["max_completion_tokens"] = new_max_tokens 129 | logger.warning( 130 | f"Decreased max tokens from {old_max_tokens} to {new_max_tokens}." 131 | ) 132 | response = generate_with_retries( 133 | retry_function=self.client.chat.completions.create, 134 | model_args=self.openai_args | prompt, 135 | retry_max_attempts=self.retry_max_attempts, 136 | retry_multiplier=self.retry_multiplier, 137 | retry_min_interval=self.retry_min_interval, 138 | retry_max_interval=self.retry_max_interval, 139 | ) 140 | logger.warning(f"Restoring max tokens to {old_max_tokens}.") 141 | self.openai_args["max_completion_tokens"] = old_max_tokens 142 | 143 | return response 144 | 145 | @staticmethod 146 | def model_name(): 147 | return "open-ai" 148 | -------------------------------------------------------------------------------- /tower_eval/models/seq2seq/__init__.py: -------------------------------------------------------------------------------- 1 | NLLB_LANGUAGE_CODES = { 2 | "ach": "ace_Arab", 3 | "af": "afr_Latn", 4 | "ak": "aka_Latn", 5 | "am": "amh_Ethi", 6 | "ar": "arb_Arab", 7 | "as": "asm_Beng", 8 | "ast": "ast_Latn", 9 | "az": "azj_Latn", 10 | "ban": "ban_Latn", 11 | "bal": "bak_Cyrl", 12 | "be": "bel_Cyrl", 13 | "bg": "bul_Cyrl", 14 | "bn": "ben_Beng", 15 | "br": "bre_Latn", 16 | "bs": "bos_Latn", 17 | "ca": "cat_Latn", 18 | "ceb": "ceb_Latn", 19 | "cs": "ces_Latn", 20 | "cy": "cym_Latn", 21 | "da": "dan_Latn", 22 | "da-da": "dan_Latn", 23 | "de": "deu_Latn", 24 | "dz": "dzo_Tibt", 25 | "el": "ell_Grek", 26 | "en": "eng_Latn", 27 | "eo": "epo_Latn", 28 | "es": "spa_Latn", 29 | "es-latam": "spa_Latn", 30 | "et": "est_Latn", 31 | "eu": "eus_Latn", 32 | "fa": "pes_Arab", 33 | "fi": "fin_Latn", 34 | "fi-fi": "fin_Latn", 35 | "tgl": "tgl_Latn", 36 | "fo": "fao_Latn", 37 | "fr": "fra_Latn", 38 | "fr-fr": "fra_Latn", 39 | "fr-ca": "fra_Latn", 40 | "ga": "gle_Latn", 41 | "gl": "glg_Latn", 42 | "gu": "guj_Gujr", 43 | "he": "heb_Hebr", 44 | "hi": "hin_Deva", 45 | "hr": "hrv_Latn", 46 | "ht": "hat_Latn", 47 | "hu": "hun_Latn", 48 | 
"hu-hu": "hun_Latn", 49 | "hy": "hye_Armn", 50 | "id": "ind_Latn", 51 | "is": "isl_Latn", 52 | "it": "ita_Latn", 53 | "it-it": "ita_Latn", 54 | "ja": "jpn_Jpan", 55 | "jv": "jav_Latn", 56 | "ka": "kat_Geor", 57 | "kk": "kaz_Cyrl", 58 | "km": "khm_Khmr", 59 | "kn": "kan_Knda", 60 | "ko": "kor_Hang", 61 | "ko-ko": "kor_Hang", 62 | "ku": "kur_Latn", 63 | "ky": "kir_Cyrl", 64 | "la": "lat_Latn", 65 | "lb": "ltz_Latn", 66 | "lo": "lao_Laoo", 67 | "lt": "lit_Latn", 68 | "lv": "lvs_Latn", 69 | "mk": "mkd_Cyrl", 70 | "ml": "mal_Mlym", 71 | "mn": "mon_Cyrl", 72 | "mr": "mar_Deva", 73 | "ms": "msa_Latn", 74 | "mt": "mlt_Latn", 75 | "ne": "npi_Deva", 76 | "nl": "nld_Latn", 77 | "nl-nl": "nld_Latn", 78 | "nn": "nno_Latn", 79 | "no": "nob_Latn", 80 | "no-no": "nob_Latn", 81 | "nso": "nso_Latn", 82 | "ny": "nya_Latn", 83 | "pa": "pan_Guru", 84 | "pl": "pol_Latn", 85 | "pl-pl": "pol_Latn", 86 | "pt": "por_Latn", 87 | "pt-pt": "por_Latn", 88 | "pt-br": "por_Latn", 89 | "ro": "ron_Latn", 90 | "ro-ro": "ron_Latn", 91 | "ru": "rus_Cyrl", 92 | "si": "sin_Sinh", 93 | "sk": "slk_Latn", 94 | "sl": "slv_Latn", 95 | "sr": "srp_Cyrl", 96 | "sv": "swe_Latn", 97 | "sv-se": "swe_Latn", 98 | "sw": "swh_Latn", 99 | "ta": "tam_Taml", 100 | "te": "tel_Telu", 101 | "th": "tha_Thai", 102 | "tl": "tgl_Latn", 103 | "tr": "tur_Latn", 104 | "ug": "uig_Arab", 105 | "uk": "ukr_Cyrl", 106 | "ur": "urd_Arab", 107 | "vi": "vie_Latn", 108 | "xh": "xho_Latn", 109 | "yi": "ydd_Hebr", 110 | "zh": "zho_Hans", 111 | "zh-cn": "zho_Hans", 112 | "zh-tw": "zho_Hans", 113 | "zu": "zul_Latn", 114 | } 115 | -------------------------------------------------------------------------------- /tower_eval/models/seq2seq/generator.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 4 | 5 | from tower_eval.models.inference_handler import Generator 6 | from tower_eval.models.seq2seq import NLLB_LANGUAGE_CODES 7 | 8 | 9 | class Seq2Seq(Generator): 10 | """Seq2Seq Models Generation code. 11 | 12 | Args: 13 | model (str, required): The model name or path to the model. 14 | batch_size (int, optional): The batch size for the model. Defaults to 16. 15 | model_family (str, optional): The model family. Defaults to "nllb". 16 | max_tokens (int, optional): The maximum number of tokens to generate. Defaults to 1024. 17 | do_sample (bool, optional): Whether to use sampling or not. Defaults to False. 18 | gpu (int, optional): The GPU device to use. Defaults to 0. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | model_dir: str = None, 24 | batch_size: int = 16, 25 | model_family: str = None, 26 | max_tokens: int = 1024, 27 | do_sample: bool = False, 28 | hf_generate_kwargs: dict = {}, 29 | **kwargs, 30 | ) -> None: 31 | super().__init__(**kwargs) 32 | self.batch_size = kwargs.get("batch_size", batch_size) 33 | self.max_tokens = kwargs.get("max_tokens", max_tokens) 34 | self.do_sample = kwargs.get("do_sample", do_sample) 35 | self.model_dir = kwargs.get("model_dir", model_dir) 36 | self.model_family = kwargs.get("model_family", model_family) 37 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir) 38 | self.model = AutoModelForSeq2SeqLM.from_pretrained( 39 | self.model_dir, device_map="auto" 40 | ) 41 | self.hf_generate_kwargs = hf_generate_kwargs 42 | 43 | def _generate(self) -> str: 44 | pass 45 | 46 | def _batch_generate(self, input_lines: List[str]) -> List[str]: 47 | """It calls the model to generate the sequences. 
48 | 49 | Args: 50 | input_lines (List[str]): The input lines for the model 51 | 52 | 53 | Returns: 54 | List[str]: The generated sequences, one per input line. 55 | """ 56 | # NLLB requires that the source language be set on the tokenizer and the target language be forced as the BOS token at generation time 57 | if self.model_family == "nllb": 58 | self.tokenizer.src_lang = NLLB_LANGUAGE_CODES[self.source_language] 59 | # Tokenize the batch once and move it to the GPU 60 | inputs = self.tokenizer(input_lines, return_tensors="pt", padding=True).to( 61 | "cuda" 62 | ) 63 | if self.model_family == "nllb": 64 | generated_tokens = self.model.generate( 65 | **inputs, 66 | forced_bos_token_id=self.tokenizer.convert_tokens_to_ids( 67 | NLLB_LANGUAGE_CODES[self.target_language] 68 | ), 69 | do_sample=self.do_sample, 70 | max_new_tokens=self.max_tokens, 71 | **self.hf_generate_kwargs, 72 | ) 73 | else: 74 | generated_tokens = self.model.generate( 75 | **inputs, 76 | do_sample=self.do_sample, 77 | max_new_tokens=self.max_tokens, 78 | **self.hf_generate_kwargs, 79 | ) 80 | sequences = self.tokenizer.batch_decode( 81 | generated_tokens, skip_special_tokens=True 82 | ) 83 | sequences = [seq.strip() for seq in sequences] 84 | 85 | return sequences 86 | 87 | @staticmethod 88 | def model_name(): 89 | return "seq2seq" 90 | -------------------------------------------------------------------------------- /tower_eval/models/vertexAI/__init__.py: -------------------------------------------------------------------------------- 1 | API_TYPE = { 2 | "text-bison": "palm", 3 | "text-unicorn": "palm", 4 | "text-bison-32k": "palm", 5 | "chat-bison": "palm", 6 | "chat-bison-32k": "palm", 7 | "gemini-pro": "gemini", 8 | "gemini-1.0-pro-002": "gemini", # Information available in: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/gemini-pro?project=unbabel-translate-neural 9 | "gemini-1.5-flash-001": "gemini-1.5", # Information available in: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/gemini-1.5-flash-001?project=unbabel-translate-neural 10 | "gemini-1.5-pro-001": "gemini-1.5", # Information available in: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/gemini-1.5-pro-001?project=unbabel-translate-neural 11 | } -------------------------------------------------------------------------------- /tower_eval/models/vertexAI/generator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import vertexai 3 | from vertexai.language_models import TextGenerationModel 4 | from vertexai.preview.generative_models import GenerativeModel 5 | 6 | import base64 7 | 8 | import vertexai.preview.generative_models as generative_models 9 | 10 | from tower_eval.models.exceptions import GenerationException 11 | from tower_eval.models.inference_handler import Generator 12 | from tower_eval.utils import generate_with_retries 13 | from tower_eval.models.vertexAI import API_TYPE 14 | from loguru import logger 15 | 16 | 17 | class VertexAI(Generator): 18 | """Google's Vertex AI Wrapper. 19 | 20 | Args: 21 | model: the name of the model to use for the inference (default: gemini-pro) 22 | temperature: the sampling temperature. 23 | top_p: Defines the cumulative probability cutoff for token selection. 24 | max_tokens: determines the maximum number of tokens the model is supposed to generate. 25 | candidate_count (int, optional): Number of response candidates to return (PaLM models only). Defaults to 1. 26 | retry_max_attempts (int, optional): Maximum number of retries. Defaults to 1. 27 | retry_max_interval (int, optional): Maximum interval between retries. Defaults to 10.
28 | retry_min_interval (int, optional): Minimum interval between retries. Defaults to 4. 29 | retry_multiplier (int, optional): Multiplier for the retry interval. Defaults to 1. 30 | """ 31 | 32 | def __init__( 33 | self, 34 | model: str = "gemini-pro", 35 | temperature: float = 0.0, 36 | top_p: float = 1.0, 37 | max_tokens: int = 1024, 38 | candidate_count: int = 1, 39 | retry_max_attempts: int = 1, 40 | retry_max_interval: int = 10, 41 | retry_min_interval: int = 4, 42 | retry_multiplier: int = 1, 43 | run_async: bool = False, 44 | system_prompt: str = None, 45 | **kwargs, 46 | ) -> None: 47 | super().__init__(**kwargs) 48 | self.run_async = run_async 49 | self.retry_max_attempts = retry_max_attempts 50 | self.retry_max_interval = retry_max_interval 51 | self.retry_min_interval = retry_min_interval 52 | self.retry_multiplier = retry_multiplier 53 | # Gemini and PaLM models need to be called through different model APIs with some small differences. 54 | # But, to avoid having two different inference endpoints in Tower-Eval, we handle both of them here. 55 | self.model_type = API_TYPE.get(model) 56 | if self.model_type == "gemini-1.5": 57 | if "project" not in kwargs or "location" not in kwargs: 58 | logger.opt(colors=True).error("For Gemini-1.5 models you need to provide \"project\" and \"location\" in your config file.") 59 | exit(1) 60 | project = kwargs["project"] 61 | location = kwargs["location"] 62 | vertexai.init(project=project, location=location) 63 | from vertexai.generative_models import GenerativeModel 64 | logger.info(f"Using the following system prompt: {system_prompt}") 65 | self.model = GenerativeModel(model, system_instruction=system_prompt) 66 | self.inference_function = self.model.generate_content 67 | self.model_args = { 68 | "generation_config": { 69 | "max_output_tokens": max_tokens, 70 | "temperature": temperature, 71 | "top_p": top_p, 72 | }, 73 | } 74 | elif self.model_type == "gemini": 75 | from vertexai.preview.generative_models import GenerativeModel 76 | self.model = GenerativeModel(model, system_instruction=system_prompt) 77 | self.inference_function = self.model.generate_content 78 | self.model_args = { 79 | "generation_config": { 80 | "max_output_tokens": max_tokens, 81 | "temperature": temperature, 82 | "top_p": top_p, 83 | }, 84 | } 85 | elif self.model_type == "palm": 86 | logger.info(f"Model \"{model}\" does not support a system prompt, so the inference is run with the user prompt only.") 87 | self.model = TextGenerationModel.from_pretrained(model) 88 | self.inference_function = self.model.predict 89 | self.model_args = { 90 | "candidate_count": candidate_count, 91 | "max_output_tokens": max_tokens, 92 | "temperature": temperature, 93 | "top_p": top_p, 94 | } 95 | else: 96 | logger.opt(colors=True).error( 97 | f"Model {model} is not supported by Vertex AI." 98 | ) 99 | exit(1) 100 | 101 | def _generate(self, input_line: str) -> str: 102 | """Calls the underlying Vertex AI inference function (generate_content for Gemini, predict for PaLM). 103 | 104 | Args: 105 | input_line (str): Prompt for the model 106 | 107 | Returns: 108 | str: The text of the generated response.
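Example (illustrative only; the input text and payload values below are hypothetical, merged exactly as in the call to generate_with_retries):
    Gemini models: self.model_args | {"contents": "Translate to German: Hello"}
    PaLM models:   self.model_args | {"prompt": "Translate to German: Hello"}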
109 | """ 110 | try: 111 | if self.model_type in ["gemini", "gemini-1.5"]: 112 | prompt = {"contents": input_line} 113 | elif self.model_type == "palm": 114 | prompt = {"prompt": input_line} 115 | 116 | responses = generate_with_retries( 117 | retry_function=self.inference_function, 118 | model_args=self.model_args | prompt, 119 | retry_max_attempts=self.retry_max_attempts, 120 | retry_multiplier=self.retry_multiplier, 121 | retry_min_interval=self.retry_min_interval, 122 | retry_max_interval=self.retry_max_interval, 123 | ) 124 | except Exception as e: 125 | raise GenerationException(str(e)) 126 | 127 | return responses.text 128 | 129 | @staticmethod 130 | def model_name(): 131 | return "vertex-ai" 132 | -------------------------------------------------------------------------------- /tower_eval/models/vllm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/models/vllm/__init__.py -------------------------------------------------------------------------------- /tower_eval/models/vllm/generator.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from vllm import LLM, SamplingParams 4 | 5 | from tower_eval.models.inference_handler import Generator 6 | 7 | 8 | class VLLM(Generator): 9 | """VLLM Generate Wrapper.""" 10 | 11 | def __init__( 12 | self, 13 | max_tokens: int = 1024, 14 | stop_sequences: list = ["\n", "\\n", ""], 15 | seed: int = 42, 16 | n_gpus: int = 1, 17 | run_async: bool = True, 18 | batch_size: int = 16, 19 | quantization: str = None, # "awq", "gptq" or "squeezellm" 20 | trust_remote_code: bool = True, 21 | gpu_memory_utilization: float = 0.9, 22 | temperature: float = 0.0, # greedy by default 23 | vllm_sampling_params: dict = {}, # see vllm SamplingParams and pass desired kwargs 24 | vllm_engine_args: dict = {}, # see vllm LLM and pass desired kwargs 25 | **kwargs 26 | ) -> None: 27 | super().__init__(**kwargs) 28 | self.max_tokens = max_tokens # actually max new tokens 29 | if len(stop_sequences) == 0: 30 | self.stop_sequences = None 31 | else: 32 | self.stop_sequences = stop_sequences 33 | self.temperature = kwargs.get("temperature", temperature) 34 | self.seed = seed 35 | self.model_dir = kwargs.get("model_dir") 36 | self.run_async = run_async 37 | self.batch_size = batch_size 38 | self.quantization = quantization 39 | self.n_gpus = n_gpus 40 | self.trust_remote_code = trust_remote_code 41 | self.gpu_memory_utilization = gpu_memory_utilization 42 | self.sampling_params = SamplingParams( 43 | stop=self.stop_sequences, 44 | max_tokens=self.max_tokens, 45 | temperature=self.temperature, 46 | **vllm_sampling_params, 47 | ) 48 | self.model = LLM( 49 | model=self.model_dir, 50 | quantization=self.quantization, 51 | seed=self.seed, 52 | trust_remote_code=self.trust_remote_code, 53 | tensor_parallel_size=self.n_gpus, 54 | gpu_memory_utilization=gpu_memory_utilization, 55 | **vllm_engine_args, 56 | ) 57 | 58 | def _generate(self, input_line: str) -> str: 59 | pass 60 | 61 | def _batch_generate(self, input_lines: List[str]) -> List[str]: 62 | model_output = self.model.generate( 63 | input_lines, self.sampling_params, use_tqdm=True 64 | ) 65 | generations = [output.outputs[0].text for output in model_output] 66 | return generations 67 | 68 | def apply_chat_template(self, input_line: str) -> str: 69 | tokenizer = self.model.get_tokenizer() 70 | if self.system_prompt 
is not None: 71 | messages = [{"role": "system", "content": self.system_prompt}] 72 | else: 73 | messages = [] 74 | messages.append({"role": "user", "content": input_line}) 75 | input_line = tokenizer.apply_chat_template( 76 | messages, 77 | add_generation_prompt=True, 78 | tokenize=False, 79 | chat_template=( 80 | None 81 | if self.model_dir 82 | not in [ 83 | "openGPT-X/Teuken-7B-instruct-research-v0.4", 84 | "openGPT-X/Teuken-7B-instruct-commercial-v0.4", 85 | ] 86 | else "EN" 87 | ), 88 | ) 89 | return input_line 90 | 91 | @staticmethod 92 | def model_name(): 93 | return "vllm" 94 | -------------------------------------------------------------------------------- /tower_eval/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from tower_eval.metrics.bleu.metric import BLEU 2 | from tower_eval.metrics.comet.metric import COMET 3 | 4 | __all__ = [BLEU, COMET] 5 | 6 | 7 | available_metrics = {metric.metric_name(): metric for metric in __all__} 8 | -------------------------------------------------------------------------------- /tower_eval/tasks/evaluate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deep-spin/tower-eval/28265202cfe218d2b7a8ad3915c62cb70de2202f/tower_eval/tasks/evaluate.py -------------------------------------------------------------------------------- /tower_eval/tasks/generate.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from pathlib import Path 3 | 4 | from jsonargparse import CLI 5 | from loguru import logger 6 | from tower_eval.models import available_models 7 | from tower_eval.utils import ( 8 | add_average_generation_time, 9 | get_langs, 10 | make_dir_if_not_exists, 11 | parse_yaml_config, 12 | ) 13 | def generate(i: int, config_path: str, available_models: dict=available_models) -> None: 14 | configs = parse_yaml_config(config_path) 15 | logger.remove() 16 | logger.add( 17 | sys.stderr, 18 | colorize=True, 19 | format="{time:YYYY-MM-DD HH:mm:ss} | {level: <8} | {message}", 20 | ) 21 | gen_data_dir = Path(configs.get("gen_data_dir")) 22 | gen_output_dir = Path(configs.get("gen_output_dir")) 23 | overwrite_generations = configs.get("overwrite_generations", False) 24 | average_time_metric = configs.get("average_time_metric", "lps") 25 | tasks = configs.get("tasks") 26 | model = configs.get("models")[i] 27 | model_name = model.get("name") 28 | model_type = model.get("type") 29 | model_args = model.get("arguments") 30 | model_args = {} if not model_args else model_args 31 | model = available_models[model_type](**(model_args)) 32 | for task in tasks: 33 | task_name = task.get("name") 34 | subtasks = task.get("subtasks") 35 | for subtask, _ in subtasks.items(): 36 | input_file = gen_data_dir / task_name / subtask / "instructions.txt" 37 | output_path = gen_output_dir / task_name / subtask / model_type / model_name 38 | output_file = output_path / "generation.txt" 39 | make_dir_if_not_exists(output_file) 40 | metadata_file = output_path / "metadata.json" 41 | logger.opt(colors=True).info( 42 | f"Running inference for task: {task_name} , subtask: {subtask} with model: {model_type}/{model_name} saving to: {output_file} " 43 | ) 44 | 45 | lp = subtask.split(".")[-1] 46 | src_lang, tgt_lang = get_langs(lp) 47 | 48 | model.source_language = src_lang 49 | model.target_language = tgt_lang 50 | model.generation_with_resume( 51 | input_file=input_file, 52 | output_file=output_file, 53 | 
metadata=configs, 54 | metadata_file=metadata_file, 55 | overwrite_generations=overwrite_generations, 56 | ) 57 | add_average_generation_time( 58 | output_file, metadata_file, language=tgt_lang, mode=average_time_metric 59 | ) 60 | 61 | 62 | def simple_generate( 63 | input_paths: list[str], 64 | output_paths: list[str], 65 | model_path: str, 66 | model_type: str, 67 | model_args: dict, 68 | available_models: dict, 69 | overwrite_generations: bool = False, 70 | ): 71 | model_path_key = "model_dir" if model_type == "vllm" else "model" 72 | model_args[model_path_key] = model_path 73 | model = available_models[model_type](**(model_args)) 74 | metadata = { 75 | "model_type": model_type, 76 | "model_args": model_args, 77 | } 78 | for input_path, output_path in zip( 79 | input_paths, output_paths 80 | ): 81 | output_dir = Path(output_path).parent 82 | metadata_file_path = output_dir / "metadata.json" 83 | model.generation_with_resume( 84 | input_file=input_path, 85 | output_file=output_path, 86 | metadata=metadata, 87 | metadata_file=metadata_file_path, 88 | overwrite_generations=overwrite_generations 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | CLI([generate], as_positional=False) 94 | -------------------------------------------------------------------------------- /tower_eval/tasks/index.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import faiss 5 | import numpy as np 6 | import pandas as pd 7 | from loguru import logger 8 | from sentence_transformers import SentenceTransformer 9 | 10 | def index_data( 11 | datastore_filename: str, 12 | datastore_indexed_path: Path, 13 | task_name: str = "task", 14 | subtask_name: str = "subtask", 15 | jsonl: bool =False, 16 | batch_size: int = 1000, 17 | ) -> None: 18 | """ 19 | :param datastore_filename: The input csv of json file name to be encoded and used as the datastore 20 | :param datastore_indexed_path: The output index file that will be used to retrieve similar samples from. 21 | :param index_columns: the columns to be indexed. 
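Note (illustrative): the datastore file is expected to expose a "src" field, e.g. a JSONL line such as {"src": "Como estás?"} or a CSV with a "src" column; only that column is embedded with LaBSE and written to the FAISS index as knn.index.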
22 | """ 23 | logger.opt(colors=True).info(f"========================================") 24 | logger.opt(colors=True).info( 25 | f"Indexing the data of task: {task_name} , subtask: {subtask_name} {datastore_filename} " 26 | ) 27 | 28 | if not os.path.exists(datastore_indexed_path): 29 | os.makedirs(datastore_indexed_path) 30 | 31 | if jsonl: 32 | df = pd.read_json(datastore_filename, lines=True) 33 | else: 34 | df = pd.read_csv(datastore_filename, encoding="utf8") 35 | source_sentences = df['src'].to_list() 36 | 37 | # initialize LaBSE encoder 38 | encoder = SentenceTransformer("sentence-transformers/LaBSE") 39 | embeddings = encoder.encode(source_sentences, batch_size=batch_size) 40 | embeddings = np.asarray(embeddings).astype('float32') 41 | dimension = embeddings.shape[1] 42 | index = faiss.IndexFlatL2(dimension) 43 | index.add(embeddings) 44 | faiss.write_index(index, os.path.join(datastore_indexed_path, "knn.index")) -------------------------------------------------------------------------------- /tower_eval/tasks/prepare.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | import jinja2 5 | from loguru import logger 6 | 7 | from tower_eval.fewshot_retrieval_utils import load_few_shot_data 8 | from tower_eval.utils import load_data_to_records, sample_strings_from_list, write_lines 9 | 10 | 11 | def apply_prompt_templates( 12 | prompt_templates: list[str], 13 | prompt_args: dict = {}, 14 | data: dict[list[str]] = {}, 15 | fewshot_examples_list: list[dict[list[str]]] = {}, 16 | ) -> list[str]: 17 | """ """ 18 | # Sample prompt templates from list of templates 19 | prompt_templates = sample_strings_from_list(prompt_templates, len(data)) 20 | # prompt_args and data record should not have matching keys. 
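# Illustrative shapes (hypothetical field names), assuming an MT-style task:
#   data        -> [{"src": "Olá, mundo", "ref": "Hello, world"}, ...]
#   prompt_args -> {"source_language": "Portuguese", "target_language": "English"}
# After the update below, a Jinja template such as
#   "Translate from {{ source_language }} to {{ target_language }}: {{ src }}"
# can be rendered once per record.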
21 | if data: 22 | for record in data: 23 | record.update(prompt_args) 24 | else: 25 | data = [prompt_args] * len(prompt_templates) 26 | # add few shot examples data, if exists and create formatted data 27 | if fewshot_examples_list: 28 | for data_record, fewshot_examples in zip(data, fewshot_examples_list): 29 | data_record["examples"] = fewshot_examples 30 | # compile templates and render with arguments 31 | env = jinja2.Environment() 32 | compiled_templates = [env.from_string(prompt) for prompt in prompt_templates] 33 | formatted_data = [t.render(**record) for record, t in zip(data, compiled_templates)] 34 | 35 | return formatted_data 36 | 37 | 38 | def prepare_data( 39 | prompt_templates: list[str], 40 | prompt_args: dict = {}, 41 | test_data_path: str = "", 42 | datastore_data_path: str = None, 43 | n_fewshots: int = 0, 44 | fewshot_retrieval_method: str = None, 45 | fewshot_retrieval_args: dict = {}, 46 | task_name: str = "task", 47 | subtask_name: str = "subtask", 48 | datastore_index_path: str = None, 49 | output_dir: Path = "tests/data", 50 | ) -> None: 51 | """ """ 52 | logger.opt(ansi=True).info(f"========================================") 53 | logger.opt(ansi=True).info( 54 | f"Preparing data of task: {task_name} , subtask: {subtask_name} " 55 | ) 56 | test_data = load_data_to_records(test_data_path) 57 | # read fewshot data objects into lists of strings 58 | fewshot_examples_list: list[dict[list[str]]] = [] 59 | if datastore_data_path is not None: 60 | total_shots = n_fewshots * len(test_data) 61 | fewshot_examples_list = load_few_shot_data( 62 | test_set=test_data, 63 | datastore_data_path=datastore_data_path, 64 | n_fewshots=n_fewshots, 65 | total_shots=total_shots, 66 | fewshot_retrieval_method=fewshot_retrieval_method, 67 | task=task_name, 68 | datastore_index_path=datastore_index_path, 69 | fewshot_retrieval_args=fewshot_retrieval_args, 70 | ) 71 | prepared_data = apply_prompt_templates( 72 | prompt_templates=prompt_templates, 73 | prompt_args=prompt_args, 74 | data=test_data, 75 | fewshot_examples_list=fewshot_examples_list, 76 | ) 77 | 78 | output_path = output_dir / f"{task_name}/{subtask_name}/instructions.txt" 79 | write_lines(output_path, prepared_data, escape_newline=True) 80 | -------------------------------------------------------------------------------- /tower_eval/tools/logging/wandb_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from pathlib import Path 3 | from time import time 4 | from typing import Dict, Tuple 5 | 6 | import wandb 7 | from tower_eval.utils import PATTERN_SHOT_NAME, load_json_file 8 | 9 | METRICS_PER_TASK = { 10 | "mt": ["comet", "chrf", "bleu"], 11 | "ape": ["ter", "comet", "comet_kiwi", "chrf", "bleu"], 12 | "gec": ["errant", "ter"], 13 | "ner": ["f1sequence"], 14 | } 15 | 16 | PROJECT = "tower-eval" 17 | ROOT_DIR = "" 18 | 19 | 20 | def create_wandb_names_from_path(path: Path) -> Tuple[str]: 21 | model = path.parts[-2] 22 | model_type = path.parts[-3] 23 | task = path.parts[-5] 24 | shot_setting = path.parts[-6] 25 | subtask = path.parts[-4] 26 | dataset_name, lp_xor_language = subtask.split(".") 27 | shot_name_short = re.sub(PATTERN_SHOT_NAME, r"\1", shot_setting) 28 | model_name_and_setting = f"{model} ({shot_name_short} shot)" 29 | 30 | return ( 31 | model, 32 | model_type, 33 | task, 34 | shot_setting, 35 | subtask, 36 | dataset_name, 37 | lp_xor_language, 38 | model_name_and_setting, 39 | shot_name_short, 40 | ) 41 | 42 | 43 | def get_data_to_log( 44 | 
model_name_and_setting: str, 45 | shot_name_short: str, 46 | task: str, 47 | subtask: str, 48 | dataset_name: str, 49 | lp_xor_language: str, 50 | metric: str, 51 | score: str, 52 | model_type: str, 53 | model: str, 54 | shot_setting: str, 55 | ) -> Tuple[Dict, Dict, str]: 56 | data_to_log = { 57 | "model": model_name_and_setting, 58 | "shots": shot_name_short, 59 | "task": task, 60 | "subtask": subtask, 61 | "dataset": dataset_name, 62 | "lp/language": lp_xor_language, 63 | "metric": metric, 64 | "score": score, 65 | "model_type": model_type, 66 | "model_raw_name": model, 67 | "shot_setting": shot_setting, 68 | } 69 | config_to_log = {k: v for k, v in data_to_log.items() if k != "score"} 70 | table_name = f"table_{task}_{dataset_name}_{metric}" 71 | 72 | return data_to_log, config_to_log, table_name 73 | 74 | 75 | def log_one_entry( 76 | project: str, 77 | model_name_and_setting: str, 78 | shot_name_short: str, 79 | task: str, 80 | subtask: str, 81 | dataset_name: str, 82 | lp_xor_language: str, 83 | metric: str, 84 | score: str, 85 | model_type: str, 86 | model: str, 87 | shot_setting: str, 88 | ) -> None: 89 | data_to_log, config_to_log, table_name = get_data_to_log( 90 | model_name_and_setting, 91 | shot_name_short, 92 | task, 93 | subtask, 94 | dataset_name, 95 | lp_xor_language, 96 | metric, 97 | score, 98 | model_type, 99 | model, 100 | shot_setting, 101 | ) 102 | wandb.init( 103 | project=project, name=table_name.split("table_")[-1], config=config_to_log 104 | ) 105 | wandb.log( 106 | { 107 | table_name: wandb.Table( 108 | columns=list(data_to_log.keys()), data=[list(data_to_log.values())] 109 | ) 110 | } 111 | ) 112 | wandb.finish(quiet=True) 113 | 114 | 115 | def log_from_existing_repo(): 116 | wandb_tables = {} 117 | 118 | evaluation_paths = [p for p in Path(ROOT_DIR).rglob("*.json")] 119 | 120 | for p in evaluation_paths: 121 | ( 122 | model, 123 | model_type, 124 | task, 125 | shot_setting, 126 | subtask, 127 | dataset_name, 128 | lp_xor_language, 129 | model_name_and_setting, 130 | shot_name_short, 131 | ) = create_wandb_names_from_path(p) 132 | metrics_to_log = METRICS_PER_TASK[task] 133 | 134 | evaluations = load_json_file(p) 135 | for metric in metrics_to_log: 136 | try: 137 | score = evaluations[metric] 138 | except KeyError: 139 | score = None 140 | data_to_log, config_to_log, table_name = get_data_to_log( 141 | model_name_and_setting, 142 | shot_name_short, 143 | task, 144 | subtask, 145 | dataset_name, 146 | lp_xor_language, 147 | metric, 148 | score, 149 | model_type, 150 | model, 151 | shot_setting, 152 | ) 153 | if table_name not in wandb_tables: 154 | wandb_tables[table_name] = { 155 | "columns": list(data_to_log.keys()), 156 | "data": [], 157 | "config": config_to_log, 158 | } 159 | # Append the current row regardless of whether the table was just created 160 | wandb_tables[table_name]["data"].append(list(data_to_log.values())) 161 | 162 | for table_name, table_dict in wandb_tables.items(): 163 | s = time() 164 | wandb.init( 165 | project=PROJECT, 166 | name=table_name.split("table_")[-1], 167 | config=table_dict["config"], 168 | ) 169 | wandb.log( 170 | { 171 | table_name: wandb.Table( 172 | columns=table_dict["columns"], data=table_dict["data"] 173 | ) 174 | } 175 | ) 176 | wandb.finish(quiet=True) 177 | e = time() 178 | print(f"Took {e-s:.2f}s") 179 | -------------------------------------------------------------------------------- /tower_eval/tools/run_calame.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from datasets import load_dataset 4 | from jsonargparse import
CLI 5 | from vllm import LLM, SamplingParams 6 | 7 | 8 | def extract_first_word(input_string): 9 | # Define a regular expression pattern to match the first word 10 | pattern = r"\b\w+\b" 11 | 12 | # Use the findall function from the re module to find all matches 13 | matches = re.findall(pattern, input_string) 14 | 15 | if matches: 16 | return matches[0] 17 | else: 18 | return "" 19 | 20 | 21 | def main(model_dir: str): 22 | # Load the model 23 | model = LLM(model_dir) 24 | s = SamplingParams(temperature=0.0, max_tokens=10) 25 | for subset in ["generated", "handwritten"]: 26 | dataset = load_dataset("NOVA-vision-language/calame-pt", subset)["train"] 27 | input_lines = dataset["sentence"] 28 | gold_words = dataset["last_word"] 29 | # generate 30 | model_outputs = model.generate(input_lines, sampling_params=s, use_tqdm=True) 31 | # replicate calame script 32 | generations = [ 33 | output.outputs[0].text.replace("\n", "") for output in model_outputs 34 | ] 35 | # Extract first predicted word 36 | predicted_last_words = [extract_first_word(g).strip() for g in generations] 37 | correct_predictions = [ 38 | 1 if p.lower() == g.lower() else 0 39 | for p, g in zip(predicted_last_words, gold_words) 40 | ] 41 | accuracy = 100 * sum(correct_predictions) / len(correct_predictions) 42 | print(f"Accuracy for {subset} subset: {accuracy}") 43 | 44 | 45 | if __name__ == "__main__": 46 | CLI([main], as_positional=False) 47 | --------------------------------------------------------------------------------
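The modules above can also be driven programmatically rather than through the YAML configs. The sketch below is illustrative only: the data paths and the model identifier are hypothetical placeholders, and the only assumptions are the simple_generate signature defined in tower_eval/tasks/generate.py and the available_models registry it imports from tower_eval.models.

from tower_eval.models import available_models
from tower_eval.tasks.generate import simple_generate

# Hypothetical paths and model id, for illustration only.
simple_generate(
    input_paths=["gen_data/mt/flores.en-de/instructions.txt"],
    output_paths=["gen_outputs/mt/flores.en-de/vllm/tower/generation.txt"],
    model_path="Unbabel/TowerInstruct-7B-v0.2",  # any local path or HF model id
    model_type="vllm",  # routed to the VLLM wrapper; "model_dir" is filled in from model_path
    model_args={"max_tokens": 256, "stop_sequences": ["\n"]},
    available_models=available_models,
    overwrite_generations=False,
)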